summaryrefslogtreecommitdiff
path: root/ir.c
diff options
context:
space:
mode:
Diffstat (limited to 'ir.c')
-rw-r--r--ir.c159
1 files changed, 131 insertions, 28 deletions
diff --git a/ir.c b/ir.c
index 8a2baf5..255d578 100644
--- a/ir.c
+++ b/ir.c
@@ -1,5 +1,8 @@
#include <stdarg.h>
+#include <assert.h>
+
#include "ir.h"
+#include "strio.h"
extern int no_opt;
@@ -18,6 +21,55 @@ int type_base_eql(Type *a, Type *b) {
return a->next ? type_base_eql(a->next, b->next) : 1;
}
+Str type_desc(Type *t, Arena *arena) {
+ (void)arena;
+ switch (t->t) {
+ case T_TUPLE: return S("tuple");
+ case T_BOOL: return S("bool");
+ case T_INT: return S("i64");
+ default: return S("N/A");
+ }
+}
+
+void type_err(Node *n, Lexer *l) {
+ Str s = S("");
+ for (int i = 0; i < n->in.len; i++) {
+ if (i > 0) str_cat(&s, S(", "), &l->arena);
+ str_cat(&s, type_desc(&n->in.data[i]->val.type, &l->arena), &l->arena);
+ }
+ lex_error_at(l, n->src_pos, LE_ERROR, str_fmt(&l->arena, "type error (%S)", s));
+}
+
+int type_check(Node *n) {
+ switch (n->type) {
+ case N_OP_NEG:
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_INT };
+ return n->in.data[0]->val.type.t == T_INT;
+ case N_OP_NOT:
+ n->val.type = (Type) { .lvl = T_TOP, .t = n->in.data[0]->val.type.t };
+ return n->in.data[0]->val.type.t == T_INT || n->in.data[0]->val.type.t == T_BOOL;
+ case N_OP_AND: case N_OP_OR: case N_OP_XOR:
+ n->val.type = (Type) { .lvl = T_TOP, .t = n->in.data[0]->val.type.t };
+ return (n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT)
+ || (n->in.data[0]->val.type.t == T_BOOL && n->in.data[1]->val.type.t == T_BOOL);
+ case N_OP_ADD: case N_OP_SUB: case N_OP_MUL: case N_OP_DIV:
+ case N_OP_SHL: case N_OP_SHR:
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_INT };
+ return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT;
+ case N_CMP_LES: case N_CMP_GTR:
+ case N_CMP_LTE: case N_CMP_GTE:
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL };
+ return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT;
+ case N_CMP_EQL:
+ case N_CMP_NEQ:
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL };
+ return (n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT)
+ || (n->in.data[0]->val.type.t == T_BOOL && n->in.data[1]->val.type.t == T_BOOL);
+ default:
+ return 1;
+ }
+}
+
/* nodes */
const char *node_type_name(NodeType t) {
@@ -90,11 +142,19 @@ void node_kill(Node *n, Proc *p) {
node_die(n, p);
}
+void node_add_out(Proc *p, Node *a, Node *b) {
+ ZDA_PUSH(&a->out, b, &p->arena);
+ b->refs++;
+}
+
+void node_add_in(Proc *p, Node *a, Node *b) {
+ ZDA_PUSH(&a->in, b, &p->arena);
+ b->refs++;
+}
+
void node_add(Proc *p, Node *src, Node *dest) {
- ZDA_PUSH(&src->out, dest, &p->arena);
- ZDA_PUSH(&dest->in, src, &p->arena);
- src->refs++;
- dest->refs++;
+ node_add_out(p, src, dest);
+ node_add_in(p, dest, src);
if (dest->src_pos.n == 0) dest->src_pos = src->src_pos;
}
@@ -122,6 +182,7 @@ Node *node_new_empty(Proc *p, NodeType t) {
return n;
}
+int type_check(Node *);
Node *node_newv(Proc *p, NodeType t, ...) {
Node *node = node_new_empty(p, t);
va_list ap;
@@ -136,6 +197,7 @@ Node *node_newv(Proc *p, NodeType t, ...) {
}
Node *node_dedup_lit(Proc *p, Value v) {
+ return NULL;
/* TODO: this is probably real inefficient for large procedure graphs,
* but does it matter? how many nodes are direct children of the start node?
* how many literals even usually occur in a procedure? */
@@ -238,14 +300,25 @@ Value node_compute(Node *n, Lexer *l) {
return n->val;
}
+#include <stdio.h>
+
+/* replace an in[] with a peepholed ver */
+void node_peephole_in(Node *n, int idx, Proc *p, Lexer *l) {
+ node_del_out(n->in.data[idx], n);
+ Node *r = node_peephole(n->in.data[idx], p, l);
+ node_add_out(p, r, n);
+ n->in.data[idx] = r;
+}
+
+#define NODE(...) node_peephole(node_new(p, __VA_ARGS__), p, l)
+
/* needs lexer for error reporting */
-Node *node_peephole(Node *n, Proc *p, Lexer *l) {
- if (no_opt) return n;
+Node *node_idealize(Node *n, Proc *p, Lexer *l) {
+ if (no_opt) return NULL;
if (n->type != N_LIT) {
Value v = node_compute(n, l);
if (v.type.lvl == T_CONST) {
- node_kill(n, p);
Node *t = node_dedup_lit(p, v);
if (t) return t;
Node *r = node_new(p, N_LIT, p->start);
@@ -260,33 +333,58 @@ Node *node_peephole(Node *n, Proc *p, Lexer *l) {
if (node_op_communative(n->type)) {
/* transformations to help encourage constant folding */
/* the overall trend is to move them rightwards */
- if (in[0]->type == N_LIT
- && in[1]->type == n->type
- && in[1]->in.data[0]->type != N_LIT
- && in[1]->in.data[1]->type == N_LIT) {
- /* op(lit, op(X, lit)) -> op(X, op(lit, lit)) */
- Node *tmp = in[1]->in.data[0];
- in[1]->in.data[0] = in[0];
- in[0] = tmp;
- /* TODO: ...would it break anything at all to just do in[1] = node_peephole(in[1], p, l)?
- * probably not, right? */
- } else if (in[0]->type == n->type
+
+ if ((in[0]->type == N_LIT && in[1]->type != N_LIT)
+ || (in[1]->type == n->type && in[0]->type != n->type)) {
+
+ if (in[1]->type == n->type) {
+ fprintf(stderr, "op(X, op(Y, Z)) -> op(op(Y, Z), X)\n");
+ } else {
+ fprintf(stderr, "op(lit, X) -> op(X, lit)\n");
+ }
+
+ return NODE(n->type, in[1], in[0]);
+ }
+
+ if (in[1]->type == N_LIT
+ && in[0]->type == n->type
+ && in[0]->in.data[0]->type != N_LIT
+ && in[0]->in.data[1]->type == N_LIT) {
+
+ fprintf(stderr, "op(op(X, lit), lit) -> op(X, op(lit, lit))\n");
+
+ return NODE(n->type,
+ in[0]->in.data[0],
+ NODE(n->type, in[0]->in.data[1], in[1]));
+ }
+
+ /* op(op(X, lit), Y) -> op(op(X, Y), lit) */
+ if (in[0]->type == n->type
&& in[0]->in.data[0]->type != N_LIT
&& in[0]->in.data[1]->type == N_LIT
&& in[1]->type != N_LIT) {
- /* op(op(X, lit), Y) -> op(op(X, Y), lit) */
- Node *tmp = in[0]->in.data[1];
- in[0]->in.data[1] = in[1];
- in[1] = tmp;
- } else if (in[0]->type == N_LIT && in[1]->type != N_LIT) {
- /* op(lit, X) -> op(X, lit) */
- Node *tmp = in[0];
- in[0] = in[1];
- in[1] = tmp;
+ fprintf(stderr, "op(op(X, lit), Y) -> op(op(X, Y), lit)\n");
+
+ return NODE(n->type,
+ NODE(n->type, in[0]->in.data[0], in[1]),
+ in[0]->in.data[1]);
}
}
- return n;
+ return NULL;
+}
+
+Node *node_peephole(Node *n, Proc *p, Lexer *l) {
+ Node *r = node_idealize(n, p, l);
+ if (r) {
+ /* make sure r doesn't get deleted even if connected to n */
+ node_add_out(p, r, p->keepalive);
+ node_kill(n, p);
+ node_del_out(r, p->keepalive);
+ return r;
+ } else {
+ return n;
+ }
}
/* procedures */
@@ -294,6 +392,11 @@ Node *node_peephole(Node *n, Proc *p, Lexer *l) {
void proc_init(Proc *proc, Str name) {
memset(proc, 0, sizeof(Proc));
proc->start = node_new_empty(proc, N_START);
+ proc->start->val.type = (Type) {
+ .lvl = T_BOT,
+ .t = T_TUPLE,
+ .next = NULL
+ };
proc->keepalive = node_new_empty(proc, N_KEEPALIVE);
proc->name = name;
}