#include #include #include "ir.h" #include "strio.h" /* nodes */ const char *node_type_name(NodeType t) { static const char *names[] = { #define X(n, s) s, NODE_TYPE_LIST #undef X }; return names[t]; } void node_die(Node *n, Proc *p) { assert(n->refs == 0); assert(n->op != N_DEAD); n->op = N_DEAD; n->prev_free = p->free_list; p->free_list = n; } void node_del_out(Node *n, Node *p) { for (int i = n->out.len - 1; i >= 0; i--) { if (n->out.data[i] == p) { if (p) p->refs--; n->out.len--; if (i < n->out.len) { n->out.data[i] = n->out.data[n->out.len]; } break; } } } void node_del_in(Node *n, Node *p) { for (int i = n->in.len - 1; i >= 0; i--) { if (n->in.data[i] == p) { if (p) p->refs--; n->in.len--; if (i < n->in.len) { memmove(&n->in.data[i], &n->in.data[i + 1], sizeof(Node*) * (n->in.len - i)); } break; } } } void node_kill(Node *n, Proc *p) { if (p->ctrl == n) { /* probably this is fine */ p->ctrl = CTRL(n); } while (n->refs > 0 && n->in.len > 0) node_remove(p, n->in.data[0], n); while (n->refs > 0 && n->out.len > 0) node_remove(p, n, n->out.data[0]); assert(n->refs == 0); } void node_add_out(Proc *p, Node *a, Node *b) { ZDA_PUSH(&p->arena, &a->out, b); if (b) b->refs++; } void node_add_in(Proc *p, Node *a, Node *b) { ZDA_PUSH(&p->arena, &a->in, b); if (b) b->refs++; } void node_set_in(Proc *p, Node *n, int idx, Node *to) { Node *in = n->in.data[idx]; if (in) in->refs--; node_add_out(p, to, n); node_del_out(in, n); n->in.data[0] = to; if (in->out.len < 1) node_kill(in, p); } void node_add(Proc *p, Node *src, Node *dest) { if (src) assert(src->op != N_DEAD); if (dest) assert(dest->op != N_DEAD); node_add_in(p, dest, src); if (!src) return; node_add_out(p, src, dest); if (dest->src_pos.n == 0) dest->src_pos = src->src_pos; else if (src->src_pos.n != 0) { int lo = dest->src_pos.ofs < src->src_pos.ofs ? dest->src_pos.ofs : src->src_pos.ofs; int hi = dest->src_pos.ofs + dest->src_pos.n > src->src_pos.ofs + src->src_pos.n ? dest->src_pos.ofs + dest->src_pos.n : src->src_pos.ofs + src->src_pos.n; dest->src_pos = (LexSpan) { lo, hi - lo }; } } void node_remove(Proc *p, Node *src, Node *dest) { assert(dest->op != N_DEAD); node_del_in(dest, src); if (dest->refs < 1) node_die(dest, p); if (src) { assert(src->op != N_DEAD); node_del_out(src, dest); if (src->out.len < 1) node_kill(src, p); } } static int global_node_count = 0; Node *node_new_empty(Proc *p, NodeType t) { Node *n; if (p->free_list) { n = p->free_list; assert(n->op == N_DEAD); p->free_list = n->prev_free; memset(n, 0, sizeof(Node)); } else { n = new(&p->arena, Node); } n->op = t; n->id = global_node_count++; return n; } Node *node_newv(Proc *p, NodeType t, Node *ctrl, ...) { Node *node = node_new_empty(p, t); va_list ap; va_start(ap, ctrl); node_add(p, ctrl, node); for (;;) { Node *n = va_arg(ap, Node *); if (!n) break; node_add(p, n, node); } va_end(ap); return node; } Node *node_dedup_lit(Proc *p, Value v) { /* TODO: this is probably real inefficient for large procedure graphs, * but does it matter? how many nodes are direct children of the start node? * how many literals even usually occur in a procedure? */ for (int i = 0; i < p->start->out.len; i++) { Node *t = p->start->out.data[i]; if (t->op == N_LIT && type_eql(&t->type, &v.type) && t->val.i == v.i) { return t; } } return NULL; } Node *node_new_lit(Proc *p, Value v) { Node *t = node_dedup_lit(p, v); if (t) return t; Node *n = node_new(p, N_LIT, NULL, p->start); n->val = v; return n; } Node *node_new_lit_i64(Proc *p, int64_t i) { return node_new_lit(p, (Value) { { .lvl = T_CONST, .t = T_INT }, { .i = i } }); } Node *node_new_lit_bool(Proc *p, int b) { return node_new_lit(p, (Value) { { .lvl = T_CONST, .t = T_BOOL }, { .i = b } }); } /* procedures */ void proc_init(Proc *proc, Str name) { memset(proc, 0, sizeof(Proc)); proc->start = node_new(proc, N_START, NULL); proc->start->type = (Type) { .lvl = T_BOT, .t = T_TUPLE, .next = NULL }; proc->stop = node_new_empty(proc, N_STOP); proc->ctrl = proc->start; proc->keepalive = node_new(proc, N_KEEPALIVE, NULL); proc->name = name; } void proc_free(Proc *proc) { arena_free(&proc->arena); } /* scope */ NameBinding *scope_find(Scope *scope, Str name) { for (ScopeFrame *f = scope->tail; f; f = f->prev) { for (NameBinding *b = f->latest; b; b = b->prev) { if (str_eql(b->name, name)) { return b; } } } return NULL; } ScopeFrame *scope_push(Scope *scope, Proc *proc) { ScopeFrame *f; if (scope->free_scope) { f = scope->free_scope; *f = (ScopeFrame) { 0 }; scope->free_scope = f->prev; } else { f = new(&proc->arena, ScopeFrame); } f->prev = scope->tail; scope->tail = f; return f; } ScopeFrame *scope_pop(Scope *scope, Proc *proc) { ScopeFrame *f = scope->tail; scope->tail = f->prev; f->prev = scope->free_scope; scope->free_scope = f; for (NameBinding *b = f->latest; b; ) { NameBinding *p = b->prev; b->prev = scope->free_bind; scope->free_bind = b; node_remove(proc, b->node, proc->keepalive); b = p; } return scope->tail; } /* returns previous value */ NameBinding *scope_bind(Scope *scope, Str name, Node *value, LexSpan pos, Proc *proc) { NameBinding *prev = scope_find(scope, name); NameBinding *b; if (scope->free_bind) { b = scope->free_bind; *b = (NameBinding) { 0 }; scope->free_bind = b->prev; } else { b = new(&proc->arena, NameBinding); } b->name = name; b->prev = scope->tail->latest; scope->tail->latest = b; b->node = value; b->src_pos = pos; node_add(proc, value, proc->keepalive); return prev; } NameBinding *scope_update(NameBinding *b, Node *to, Proc *proc) { Node *n = b->node; node_add(proc, to, proc->keepalive); b->node = to; node_remove(proc, n, proc->keepalive); return b; } /* adds to keepalive so these aren't invalidated */ void scope_collect(Scope *scope, Proc *proc, ScopeNameList *nl, Arena *arena) { for (ScopeFrame *f = scope->tail; f; f = f->prev) { for (NameBinding *b = f->latest; b; b = b->prev) { node_add(proc, b->node, proc->keepalive); ZDA_PUSH(arena, nl, (ScopeName) { b->name, b->node }); } } } void scope_uncollect(Scope *scope, Proc *proc, ScopeNameList *nl) { for (int i = 0; i < nl->len; i++) { node_remove(proc, nl->data[i].node, proc->keepalive); } } /* types */ int type_eql(Type *a, Type *b) { if (a->t != b->t) return 0; if (a->lvl != b->lvl) return 0; if (a->next != b->next) return 0; return a->next ? type_eql(a->next, b->next) : 1; } int type_base_eql(Type *a, Type *b) { if (a->t != b->t) return 0; if (a->next != b->next) return 0; return a->next ? type_base_eql(a->next, b->next) : 1; } int value_eql(Value *a, Value *b) { if (!type_eql(&a->type, &b->type)) return 0; return a->i == b->i; } Str type_desc(Type *t, Arena *arena) { (void)arena; switch (t->lvl) { case T_CTRL: return S("ctrl"); case T_XCTRL: return S("~ctrl"); default: break; } switch (t->t) { case T_TUPLE: return S("tuple"); case T_BOOL: return S("bool"); case T_INT: return S("i64"); case T_PTR: return str_fmt(arena, "^%S", type_desc(t->next, arena)); default: return S("N/A"); } } void type_err(Node *n, Lexer *l) { Str s = S(""); for (int i = 0; i < n->in.len; i++) { if (i > 0) str_cat(&s, S(", "), &l->arena); str_cat(&s, type_desc(&IN(n, i)->type, &l->arena), &l->arena); } lex_error_at(l, n->src_pos, LE_ERROR, str_fmt(&l->arena, "type error %s (%S)", node_type_name(n->op), s)); } void type_expected(Type *want, Node *n, Lexer *l) { if (type_base_eql(want, &n->type)) return; lex_error_at(l, n->src_pos, LE_ERROR, str_fmt(&l->arena, "type error: expected %S, but got %S", type_desc(want, &l->arena), type_desc(&n->type, &l->arena))); } static int type_ok(Node *n) { switch (n->op) { case N_PHI: n->type = (Type) { .lvl = T_TOP, .t = IN(n, 1)->type.t }; for (int i = 2; i < n->in.len; i++) { if (!type_base_eql(&IN(n, i)->type, &n->type)) { return 0; } } return 1; case N_OP_NEG: n->type = (Type) { .lvl = T_TOP, .t = T_INT }; return CAR(n)->type.t == T_INT; case N_OP_NOT: n->type = (Type) { .lvl = T_TOP, .t = CAR(n)->type.t }; return CAR(n)->type.t == T_INT || CAR(n)->type.t == T_BOOL; case N_OP_AND: case N_OP_OR: case N_OP_XOR: n->type = (Type) { .lvl = T_TOP, .t = CAR(n)->type.t }; return (CAR(n)->type.t == T_INT && CDR(n)->type.t == T_INT) || (CAR(n)->type.t == T_BOOL && CDR(n)->type.t == T_BOOL); case N_OP_ADD: case N_OP_SUB: case N_OP_MUL: case N_OP_DIV: case N_OP_SHL: case N_OP_SHR: n->type = (Type) { .lvl = T_TOP, .t = T_INT }; return CAR(n)->type.t == T_INT && CDR(n)->type.t == T_INT; case N_CMP_LES: case N_CMP_GTR: case N_CMP_LTE: case N_CMP_GTE: n->type = (Type) { .lvl = T_TOP, .t = T_BOOL }; return CAR(n)->type.t == T_INT && CDR(n)->type.t == T_INT; case N_CMP_EQL: case N_CMP_NEQ: n->type = (Type) { .lvl = T_TOP, .t = T_BOOL }; return type_base_eql(&CAR(n)->type, &CDR(n)->type); /* (CAR(n)->type.t == T_INT && CDR(n)->type.t == T_INT) || (CAR(n)->type.t == T_BOOL && CDR(n)->type.t == T_BOOL); */ default: return 1; } } /* TODO: make it so * func foo(a, b i64) i64 { * let x i64 * if a < b { * return 0 * } else { * x := 3 * } * return x * } * doesn't throw warnings (e.g. don't generate phi nodes if one scope is * guaranteed to return early) */ int node_uninit(Node *n) { return n->op == N_UNINIT; } int node_maybe_uninit(Node *n) { if (node_uninit(n)) return 1; for (int i = 0; i < n->in.len; i++) { if (IN(n,i) && node_maybe_uninit(IN(n,i))) { return 1; } } return 0; } void uninit_check(Node *n, Lexer *l) { if (NMASK(n->op) & ~NM_PHI) { for (int i = 0; i < n->in.len; i++) { Node *o = IN(n, i); if (!o) continue; if (node_uninit(o)) { fprintf(stderr, "%s\n", node_type_name(n->op)); lex_error_at(l, o->src_pos, LE_WARN, str_fmt(&l->arena, "uninitialized %S", type_desc(&IN(n,i)->type, &l->arena))); } else if (node_maybe_uninit(o)) { lex_error_at(l, o->src_pos, LE_WARN, str_fmt(&l->arena, "possibly uninitialized %S", type_desc(&o->type, &l->arena))); } } } } void type_check(Node *n, Lexer *l) { uninit_check(n, l); if (!type_ok(n)) type_err(n, l); }