#include #include #include #include #include #include #include #include "lex.h" #include "arena.h" #include "dynarr.h" #include "strio.h" #include "ir.h" #include "peephole.h" int no_opt = 0; typedef struct { DYNARR(Proc) procs; } Unit; void unit_init(Unit *u) { (void)u; } void unit_free(Unit *u) { for (int i = 0; i < u->procs.len; i++) { proc_free(&u->procs.data[i]); } free(u->procs.data); } /* parsing */ Node *parse_expr(Lexer *l, Proc *p, Type *twant); /* TODO: eliminate unused if-else statements at the end of compilation * they don't get pruned out by peephole optimizations if there are phi * nodes that are connected to keepalive (due to still being in scope). * so probably this will have to be done in a separate step after the * graph has been generated. */ Node *ctrl(Proc *p, Node *n) { if (!n) { p->ctrl = node_new_lit(p, (Value) { .type = { .lvl = T_XCTRL, .t = T_NONE } }); return NULL; } p->ctrl = n; return n; } void parse_return(Lexer *l, Proc *p) { lex_next(l); Node *n; if (p->ret_type.t == T_NONE) { n = node_new(p, N_RETURN, p->ctrl); } else { Node *e = parse_expr(l, p, NULL); if (!type_base_eql(&e->type, &p->ret_type)) { lex_error_at(l, e->src_pos, LE_ERROR, str_fmt(&p->arena, "incorrect return type (expected %S, got %S)", type_desc(&p->ret_type, &p->arena), type_desc(&e->type, &p->arena))); } n = node_new(p, N_RETURN, p->ctrl, e); } n = node_peephole(n, p, l); if (n->type.lvl != T_XCTRL) { node_add(p, n, p->stop); } ctrl(p, n); ctrl(p, NULL); } Type parse_type(Lexer *l, Proc *proc) { (void)proc; Type t = { .lvl = T_BOT }; if (l->tok == TOK_DEREF) { lex_next(l); t.t = T_PTR; t.next = new(&proc->arena, Type); *t.next = parse_type(l, proc); return t; } lex_expected(l, TM_IDENT); if (l->tok == TOK_IDENT) { if (str_eql(l->ident, S("i64"))) { t.t = T_INT; } else if (str_eql(l->ident, S("bool"))) { t.t = T_BOOL; } else { lex_error(l, LE_ERROR, S("unknown type")); } } lex_next(l); return t; } void parse_let(Lexer *l, Proc *p) { recurse: lex_expect(l, TM_IDENT); Str name = l->ident; LexSpan pos = l->pos; lex_next(l); Node *rhs = NULL; Type t = { .t = T_NONE }; if (l->tok != TOK_EQL) { t = parse_type(l, p); if (l->tok != TOK_EQL) { rhs = node_new(p, N_UNINIT, p->start); rhs->type = t; } } if (l->tok == TOK_EQL) { lex_next(l); rhs = parse_expr(l, p, t.t == T_NONE ? NULL : &t); if (t.t != T_NONE) type_expected(&t, rhs, l); } NameBinding *b = scope_bind(&p->scope, name, rhs, pos, p); if (b) { lex_error_at(l, pos, LE_WARN, S("shadowing previous declaration")); lex_error_at(l, b->src_pos, LE_WARN, S("declared here")); } if (l->tok == TOK_COMMA) goto recurse; } void parse_stmt(Lexer *l, Proc *p); /* TODO: return node from this! */ void parse_block(Lexer *l, Proc *p, ScopeNameList *nl) { lex_next(l); scope_push(&p->scope, p); while (l->tok != TOK_RBRACE) { lex_expected_not(l, TM_EOF); parse_stmt(l, p); } if (nl) { scope_collect(&p->scope, p, nl, &p->arena); } scope_pop(&p->scope, p); lex_expected(l, TM_RBRACE); lex_next(l); } void parse_assign(Lexer *l, Proc *p) { Str name = l->ident; LexSpan pos = l->pos; lex_expect(l, TM_ASSIGN); lex_next(l); NameBinding *b = scope_find(&p->scope, name); if (!b) { lex_error_at(l, pos, LE_ERROR, S("undeclared identifier")); } Node *e = parse_expr(l, p, &b->node->type); if (!type_base_eql(&e->type, &b->node->type)) { lex_error_at(l, pos, LE_ERROR, str_fmt(&p->arena, "tried to assign value of type %S to variable of type %S", type_desc(&e->type, &p->arena), type_desc(&b->node->type, &p->arena))); } if (node_uninit(e)) { lex_error_at(l, e->src_pos, LE_ERROR, str_fmt(&p->arena, "uninitialized %S", type_desc(&e->type, &p->arena))); } else if (node_maybe_uninit(e)) { lex_error_at(l, e->src_pos, LE_WARN, str_fmt(&p->arena, "possibly uninitialized %S", type_desc(&e->type, &p->arena))); } scope_update(b, e, p); } /* TODO: Implement a better system for this. * * Something like: * * ScopeChangeList chg_true = {0}, chg_false = {0}; * scope_track_changes(&p->scope, &chg_true, p, scratch); * parse_block(l, p); * scope_rewind_changes(&p->scope, &chg_true, p); * scope_track_changes(&p->scope, &chg_false, p, scratch); * parse_block(l, p); * scope_merge_changes(&p->scope, &chg_true, &chg_false, region, p); * * We put a flag in Scope somewhere that marks where it's tracking changes to, * then have scope_update() look up the name binding being changed in the change * list; if present, update value to the new node. If not, create it with * orig set to the previous value of the name binding, and value to the new node. * Since we're only tracking _changes_, there should be some way to make sure that * newly created let-bindings after tracking starts aren't stored in the list. * Maybe give NameBindings an index, note the index of the latest binding at the * time scope_track_changes() is called, ignore any larger than it. * * Take care with making sure changes remain attached to keepalive --- rewind * doesn't remove the new values, since it's just to revert temporarily. * * Also, probably should have a scratch arena for this sort of throwaway parsing * data, instead of putting stuff in the procedure graph arena. It can be reset * after each statement. * * Also also, switch scopes from a linked list to something a bit faster, like * a simple arena-backed hash trie. * */ void merge_scope(Lexer *l, Proc *p, Node *region, ScopeNameList *before, ScopeNameList *ntrue, ScopeNameList *nfalse) { for (int i = 0; i < before->len; i++) { int j, k; ScopeName *b4 = &before->data[i]; for (j = 0; j < ntrue->len && !str_eql(ntrue->data[j].name, b4->name); j++); for (k = 0; k < nfalse->len && !str_eql(nfalse->data[k].name, b4->name); k++); ScopeName *yes = j < ntrue->len ? &ntrue->data[j] : NULL; ScopeName *no = k < nfalse->len ? &nfalse->data[k] : NULL; if (!yes && !no) continue; /* no change */ Node *phi; if (!no) { if (yes->node == b4->node) continue; phi = node_new(p, N_PHI, region, yes->node, b4->node); } else if (!yes) { if (no->node == b4->node) continue; phi = node_new(p, N_PHI, region, b4->node, no->node); } else { if (yes->node == b4->node && no->node == b4->node) continue; phi = node_new(p, N_PHI, region, yes->node, no->node); } fprintf(stderr, "phi('%.*s', %d, %d)\n", (int)b4->name.n, b4->name.s, phi->in.data[1]->id, phi->in.data[2]->id); phi = node_peephole(phi, p, l); NameBinding *b = scope_find(&p->scope, b4->name); assert(b); scope_update(b, phi, p); } } /* TODO: find out a way to encode known relations between nodes, based on the * conditional, within the body of an if statement --- e.g., for a statement * * if a < 5 { * if a < 10 { * foo() * } * } else { * if a > 3 { * bar() * } * } * * we should be able to infer that the second condition is always true, because * it's only reachable if a < 5, and 5 < 10. conversely, in the else branch, * we know that a >= 5, and can use that to make assumptions on comparisons * made there too! * */ void parse_if(Lexer *l, Proc *p) { lex_next(l); Node *cond = parse_expr(l, p, &(Type) { .t = T_BOOL }); Node *ctrl_if = NULL, *ctrl_else = NULL; Node *if_node = node_new(p, N_IF_ELSE, p->ctrl, cond); if_node->val = (Value) { .type = { .lvl = T_TOP, .t = T_TUPLE }, .tuple = { 0 } }; ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { .lvl = T_CTRL, .t = T_NONE } }); ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { .lvl = T_CTRL, .t = T_NONE } }); if_node = node_peephole(if_node, p, l); node_add(p, if_node, p->keepalive); Node *if_true = node_new(p, N_PROJ, if_node); Node *if_false = node_new(p, N_PROJ, if_node); if_true->val.i = 0; if_false->val.i = 1; if_true = node_peephole(if_true, p, l); if_false = node_peephole(if_false, p, l); node_remove(p, if_node, p->keepalive); assert(if_true->in.len > 0); assert(if_false->in.len > 0); node_add(p, if_true, p->keepalive); node_add(p, if_false, p->keepalive); ScopeNameList scope_before = { 0 }, scope_true = { 0 }, scope_false = { 0 }; scope_collect(&p->scope, p, &scope_before, &p->arena); if (cond->type.lvl == T_CONST) { if (cond->val.i) if_false->type.lvl = T_XCTRL; else if_true->type.lvl = T_XCTRL; } ctrl(p, if_true); lex_expected(l, TM_LBRACE); int pos = l->pos.ofs; parse_block(l, p, &scope_true); if_true->src_pos = (LexSpan) { .ofs = pos, .n = l->pos.ofs - pos }; ctrl_if = p->ctrl; node_add(p, ctrl_if, p->keepalive); if (l->tok == TOK_ELSE) { for (int i = 0; i < scope_before.len; i++) { scope_update(scope_find(&p->scope, scope_before.data[i].name), scope_before.data[i].node, p); } ctrl(p, if_false); lex_expect(l, TM_LBRACE); pos = l->pos.ofs; parse_block(l, p, &scope_false); if_false->src_pos = (LexSpan) { .ofs = pos, .n = l->pos.ofs - pos }; ctrl_else = p->ctrl; node_add(p, ctrl_else, p->keepalive); } Node *region; if (ctrl_else) { assert(ctrl_if); assert(ctrl_else); region = node_new(p, N_REGION, ctrl_if, ctrl_else); node_add_out(p, region, p->keepalive); node_remove(p, ctrl_if, p->keepalive); node_remove(p, ctrl_else, p->keepalive); } else { assert(ctrl_if); assert(if_false); region = node_new(p, N_REGION, ctrl_if, if_false); node_add_out(p, region, p->keepalive); node_remove(p, ctrl_if, p->keepalive); assert(if_true->refs > 0); assert(if_false->refs > 0); } ctrl(p, region); node_remove(p, if_true, p->keepalive); node_remove(p, if_false, p->keepalive); assert(p->ctrl->in.len > 0); assert(region->in.data[0]); assert(region->in.data[1]); merge_scope(l, p, region, &scope_before, &scope_true, &scope_false); scope_uncollect(&p->scope, p, &scope_true); scope_uncollect(&p->scope, p, &scope_false); scope_uncollect(&p->scope, p, &scope_before); node_del_out(region, p->keepalive); /* make sure we're not orphaning any phi nodes*/ if (p->ctrl->out.len < 1) { ctrl(p, node_peephole(p->ctrl, p, l)); } } void parse_stmt(Lexer *l, Proc *p) { /* TODO */ (void)l; switch (l->tok) { case TOK_RETURN: parse_return(l, p); break; case TOK_LET: parse_let(l, p); break; case TOK_LBRACE: parse_block(l, p, NULL); break; case TOK_IDENT: parse_assign(l, p); break; case TOK_IF: parse_if(l, p); break; default: lex_expected(l, TM_RBRACE); break; } } void parse_args_list(Lexer *l, Proc *proc) { Node *start = proc->start; int i = 0; struct { Str name; LexSpan pos; } idbuf[32]; int id = 0; while (l->tok != TOK_RPAREN && l->tok != TOK_EOF) { lex_expect(l, TM_IDENT); if (id == sizeof idbuf / sizeof *idbuf) { lex_error(l, LE_ERROR, S("too many arguments without specifying a type")); return; } idbuf[id].name = l->ident; idbuf[id].pos = l->pos; id++; lex_next(l); if (l->tok == TOK_COMMA) continue; Value v = (Value) { .type = parse_type(l, proc) }; lex_expected(l, TM_RPAREN | TM_COMMA); for (int j = 0; j < id; j++) { Node *proj = node_new(proc, N_PROJ, proc->start); proj->type = v.type; proj->val.i = i++; ZDA_PUSH(&proc->arena, &start->val.tuple, v); scope_bind(&proc->scope, idbuf[j].name, proj, idbuf[j].pos, proc); } id = 0; } lex_expected(l, TM_RPAREN); lex_next(l); } Node *find_return(Node *n) { if (n->op == N_RETURN) return n; for (int i = 0; i < n->out.len; i++) { Node *r = find_return(n->out.data[i]); if (r) return r; } return NULL; } void proc_opt_fwd(Proc *p, Lexer *l, Node *n) { fprintf(stderr, "%d %s\n", n->id, node_type_name(n->op)); if (n->walked == 2) return; n->walked = 2; switch (n->op) { case N_START: for (int i = 0; i < n->out.len; i++) { proc_opt_fwd(p, l, n->out.data[i]); } break; case N_IF_ELSE: if (n->out.len < 2) { //lex_error_at(l, n->src_pos, LE_ERROR, S("not all codepaths return")); } for (int i = 0; i < n->out.len; i++) { Node *r = find_return(n->out.data[i]); if (!r) { lex_error_at(l, n->out.data[i]->src_pos, LE_ERROR, S("not all codepaths return")); } proc_opt_fwd(p, l, n->out.data[i]); } break; case N_PROJ: proc_opt_fwd(p, l, n->out.data[0]); break; case N_REGION: /* cull empty if else */ if (n->out.len == 1 && n->in.len == 2 && IN(n,0)->op == N_PROJ && IN(n,1)->op == N_PROJ && CTRL(IN(n,0)) == CTRL(IN(n,1)) && CTRL(IN(n,0))->op == N_IF_ELSE) { assert(n->out.data[0]->op != N_PHI); assert(n->out.data[0]->in.data[0] == n); Node *new_ctrl = CTRL(CTRL(CTRL(n))); Node *out = n->out.data[0]; node_set_in(p, out, 0, new_ctrl); proc_opt_fwd(p, l, new_ctrl); return; } for (int i = 0; i < n->out.len; i++) { if (n->out.data[i]->op != N_PHI) { proc_opt_fwd(p, l, n->out.data[i]); } } break; default: break; } } void proc_opt(Proc *p, Lexer *l) { fprintf(stderr, "%ld\n", p->stop->in.len); if (p->stop->in.len == 0) { if (p->ret_type.t != T_NONE) { lex_error(l, LE_ERROR, str_fmt(&p->arena, "no return statement in function expecting %S", type_desc(&p->ret_type, &p->arena))); } } for (int i = 0; i < p->start->out.len; i++) { Node *n = p->start->out.data[i]; if (n->op == N_LIT && n->out.len < 1) { node_kill(n, p); i--; } } proc_opt_fwd(p, l, p->start); } Proc *parse_proc(Lexer *l, Unit *u) { int has_ret = l->tok == TOK_FUNC; DA_FIT(&u->procs, u->procs.len + 1); Proc *proc = &u->procs.data[u->procs.len++]; lex_expect(l, TM_IDENT); proc_init(proc, l->ident); scope_push(&proc->scope, proc); lex_expect(l, TM_LPAREN | TM_LBRACE); if (l->tok == TOK_LPAREN) { parse_args_list(l, proc); } if (has_ret) { proc->ret_type = parse_type(l, proc); } else { proc->ret_type = (Type) { .lvl = T_BOT, .t = T_NONE }; } lex_expected(l, TM_LBRACE); lex_next(l); while (l->tok != TOK_RBRACE) { lex_expected_not(l, TM_EOF); parse_stmt(l, proc); } scope_pop(&proc->scope, proc); lex_expected(l, TM_RBRACE); lex_next(l); proc_opt(proc, l); return proc; } void uninit_check(Lexer *l, Proc *p, Node *n, LexSpan pos) { if (node_uninit(n)) { lex_error_at(l, pos, LE_ERROR, str_fmt(&p->arena, "uninitialized %S", type_desc(&n->type, &p->arena))); } else if (node_maybe_uninit(n)) { lex_error_at(l, pos, LE_WARN, str_fmt(&p->arena, "possibly uninitialized %S", type_desc(&n->type, &p->arena))); } } Node *parse_term(Lexer *l, Proc *p, Type *twant) { (void)twant; /* to be used for .ENUM_TYPE and stuff */ Node *node = NULL; NodeType op_after = N_START; if (TMASK(l->tok) & (TM_MINUS | TM_PLUS | TM_NOT)) { Token t = l->tok; lex_next(l); node = parse_term(l, p, twant); NodeType post_op = N_START; switch (t) { case TOK_MINUS: post_op = N_OP_NEG; break; case TOK_NOT: post_op = N_OP_NOT; break; default: return node; } if (post_op == N_START) return node; return node_peephole(node_new(p, post_op, NULL, node), p, l); } if (l->tok == TOK_LPAREN) { lex_next(l); node = parse_expr(l, p, NULL); lex_expected(l, TM_RPAREN); lex_next(l); node->src_pos.ofs--; node->src_pos.n += 2; } else if (l->tok == TOK_IDENT) { NameBinding *b = scope_find(&p->scope, l->ident); if (b) { node = b->node; } else { lex_error(l, LE_ERROR, S("undeclared identifier")); } uninit_check(l, p, node, l->pos); lex_next(l); } else if (TMASK(l->tok) & (TM_TRUE | TM_FALSE)) { node = node_new_lit_bool(p, l->tok == TOK_TRUE); lex_next(l); } else { lex_expected(l, TM_LIT_NUM); int64_t val = 0; for (int i = 0; i < l->ident.n; i++) { if (!(l->ident.s[i] >= '0' && l->ident.s[i] <= '9')) { lex_error(l, LE_ERROR, S("not a digit")); break; } val = (val * 10) + (l->ident.s[i] - '0'); } node = node_new_lit_i64(p, val); node->src_pos = l->pos; lex_next(l); } if (op_after != N_START) { node = node_new(p, op_after, NULL, node_peephole(node, p, l)); } return node_peephole(node, p, l); } NodeType tok_to_bin_op(Token t) { switch (t) { case TOK_PLUS: return N_OP_ADD; break; case TOK_MINUS: return N_OP_SUB; break; case TOK_ASTERISK: return N_OP_MUL; break; case TOK_SLASH: return N_OP_DIV; break; case TOK_NOT: return N_OP_NOT; break; case TOK_AND: return N_OP_AND; break; case TOK_OR: return N_OP_OR; break; case TOK_XOR: return N_OP_XOR; break; case TOK_SHL: return N_OP_SHL; break; case TOK_SHR: return N_OP_SHR; break; case TOK_EQL: return N_CMP_EQL; break; case TOK_NEQ: return N_CMP_NEQ; break; case TOK_LES: return N_CMP_LES; break; case TOK_GTR: return N_CMP_GTR; break; case TOK_LTE: return N_CMP_LTE; break; case TOK_GTE: return N_CMP_GTE; break; default: return N_START; break; } } /* TODO: operator precedence would be kinda nice actually, sad to say */ Node *parse_expr(Lexer *l, Proc *p, Type *twant) { LexSpan pos = l->pos; Node *lhs = parse_term(l, p, twant); NodeType nt = tok_to_bin_op(l->tok);; if (lhs->refs <= 0) lex_error(l, LE_ERROR, S("dead lhs")); assert(lhs->refs > 0); if (nt != N_START) { lex_next(l); /* necessary because if lhs is a deduplicated literal, it may be an input to rhs * and therefore culled by peephole optimizations */ Node *rhs; NODE_KEEP(p, lhs, { rhs = parse_expr(l, p, &lhs->type); }); lhs = node_peephole(node_new(p, nt, NULL, lhs, rhs), p, l); } lhs->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs }; if (twant) type_expected(twant, lhs, l); return lhs; } void parse_toplevel(Lexer *l, Unit *u) { switch (l->tok) { case TOK_PROC: case TOK_FUNC: parse_proc(l, u); break; default: lex_expected(l, TM_PROC | TM_FUNC); break; } } void unit_print(Unit *u); void parse_unit(Lexer *l) { Unit u = { 0 }; unit_init(&u); while (l->tok != TOK_EOF) { parse_toplevel(l, &u); } unit_print(&u); unit_free(&u); } /* graph output */ /* TODO: Print at every stage of graph compilation and generate a frame for * each, so problems can be debugged visually as they occur. */ void node_print(Node *n, Proc *p) { if (n->walked == 1) return; n->walked = 1; if (n->op == N_START) { Str s = S(""); int c = n->val.tuple.len; Value *v = n->val.tuple.data; for (int i = 0; i < c; i++) { if (i > 0) str_cat(&s, S(", "), &p->arena); str_cat(&s, type_desc(&v[i].type, &p->arena), &p->arena); } printf("\t%d [label=\"start(%.*s)\"]", n->id, (int)s.n, s.s); } else if (n->op == N_LIT) { switch (n->type.t) { case T_INT: printf("\t%d [label=\"%ld\"]", n->id, n->val.i); break; case T_BOOL: printf("\t%d [label=\"%s\"]", n->id, n->val.i ? "true" : "false"); break; case T_NONE: if (n->type.lvl == T_XCTRL) { printf("\t%d [label=\"~ctrl\"]", n->id); break; } /* fallthrough */ default: printf("\t%d [label=\"literal %d\"]", n->id, n->id); break; } } else if (n->op == N_PROJ) { Str d = type_desc(&n->in.data[0]->val.tuple.data[n->val.i].type, &p->arena); printf("\t%d [label=\"%.*s(%ld)\", shape=record]", n->id, (int)d.n, d.s, n->val.i); } else if (n->op == N_UNINIT) { Str s = type_desc(&n->type, &p->arena); printf("\t%d [label=\"uninitialized %.*s\", shape=record]", n->id, (int)s.n, s.s); } else { printf("\t%d [label=\"%s\", shape=record]", n->id, node_type_name(n->op)); } printf("\n"); for (int i = 0; i < n->out.len; i++) { Node *o = n->out.data[i]; if (o->op == N_LIT) { printf("\t%d -> %d [style=dashed]\n", n->id, o->id); } else { int j; for (j = 0; j < o->in.len && o->in.data[j] != n; j++); if (j == 0) { printf("\t%d -> %d [color=red,headlabel=%d]\n", n->id, o->id, j); } else { printf("\t%d -> %d [headlabel=%d]\n", n->id, o->id, j); } } } for (int i = 0; i < n->out.len; i++) { node_print(n->out.data[i], p); } } void proc_print(Proc *p) { if (p->start) { Str d = type_desc(&p->ret_type, &p->arena); printf("\t\"%.*s %.*s\" -> %d\n", (int)p->name.n, p->name.s, (int)d.n, d.s, p->start->id); node_print(p->start, p); if (no_opt) { for (NameBinding *b = p->scope.free_bind; b; b = b->prev) { uint64_t id = (uintptr_t)b->node; printf("\t\t%lu [label=\"%.*s\",shape=none,fontcolor=blue]\n", id, (int)b->name.n, b->name.s); printf("\t\t%lu -> %d [arrowhead=none,style=dotted,color=blue]\n", id, b->node->id); } } } } void unit_print(Unit *u) { puts("digraph {"); for (int i = 0; i < u->procs.len; i++) { proc_print(&u->procs.data[i]); } puts("}"); } /* main */ int main(int argc, const char **argv) { if (argc != 2) { fprintf(stderr, "Usage: %s FILE\n", argv[0]); return 1; } Lexer l = { 0 }; lex_start(&l, argv[1]); parse_unit(&l); lex_free(&l); return 0; }