#include #include #include #include #include #include #include #include "lex.h" #include "arena.h" #include "dynarr.h" int no_opt = 1; /* node graph */ typedef enum { N_START, N_RETURN, N_KEEPALIVE, N_LIT, N_OP_ADD, N_OP_SUB, N_OP_MUL, N_OP_DIV, N_OP_AND, N_OP_OR, N_OP_XOR, N_OP_SHL, N_OP_SHR, N_OP_NEG, N_OP_NOT, N_VALUE } NodeType; const char *node_type_name[] = { "start", "return", "keepalive", "literal", "add", "sub", "mul", "div", "and", "or", "xor", "lshift", "rshift", "neg", "not", "value" }; typedef enum { T_BOT, T_TOP, T_CONST, T_INT } Type; typedef struct { Type type; union { int64_t i; uint64_t u; }; } Value; typedef struct Node { union { struct Node *prev_free; struct { int id, refs; int walked; NodeType type; LexSpan src_pos; DYNARR(struct Node *) in, out; Value val; }; }; } Node; typedef struct NameBinding { struct NameBinding *prev; LexSpan src_pos; Str name; Node *node; } NameBinding; typedef struct ScopeFrame { struct ScopeFrame *prev; NameBinding *latest; } ScopeFrame; typedef struct { ScopeFrame *tail, *free_scope; NameBinding *free_bind; } Scope; typedef struct { Arena arena; Str name; Node *start, *stop, *keepalive; Node *free_list; Scope scope; } Proc; typedef struct { DYNARR(Proc) procs; } Unit; void unit_init(Unit *u) { (void)u; } void unit_fini(Unit *u) { free(u->procs.data); } void node_kill(Node *n, Proc *p); void node_die(Node *n, Proc *p) { /*n->prev_free = p->free_list; p->free_list = n;*/ } void node_del_out(Node *n, Node *p) { for (int i = 0; i < n->out.len; i++) { if (n->out.data[i] == p) { p->refs--; if (i + 1 < n->out.len) { memmove(&n->out.data[i], &n->out.data[i + 1], sizeof(Node*) * (n->out.len - i - 1)); } n->out.len--; i--; } } } void node_del_in(Node *n, Node *p) { for (int i = 0; i < n->in.len; i++) { if (n->in.data[i] == p) { p->refs--; if (i + 1 < n->in.len) { memmove(&n->in.data[i], &n->in.data[i + 1], sizeof(Node*) * (n->in.len - i - 1)); } n->in.len--; i--; } } } void node_kill(Node *n, Proc *p) { for (int i = 0; i < n->in.len; i++) { node_del_out(n->in.data[i], n); n->in.data[i]->refs--; if (n->in.data[i]->out.len < 1) node_kill(n->in.data[i], p); } for (int i = 0; i < n->out.len; i++) { node_del_in(n->out.data[i], n); n->out.data[i]->refs--; if (n->out.data[i]->refs < 1) node_die(n->out.data[i], p); } n->in.len = 0; n->out.len = 0; node_die(n, p); } void node_add(Proc *p, Node *src, Node *dest) { ZDA_PUSH(&src->out, dest, &p->arena); ZDA_PUSH(&dest->in, src, &p->arena); src->refs++; dest->refs++; if (dest->src_pos.n == 0) dest->src_pos = src->src_pos; } void node_remove(Proc *p, Node *src, Node *dest) { node_del_out(src, dest); node_del_in(dest, src); if (dest->refs < 1) node_die(dest, p); if (src->refs < 1) node_die(src, p); else if (src->out.len < 1) node_kill(src, p); } static int global_node_count = 0; Node *node_new_empty(Proc *p, NodeType t) { Node *n; if (p->free_list) { n = p->free_list; p->free_list = n->prev_free; memset(n, 0, sizeof(Node)); } else { n = new(&p->arena, Node); } n->type = t; n->id = global_node_count++; return n; } Node *node_new(Proc *p, NodeType t, ...) { Node *node = node_new_empty(p, t); va_list ap; va_start(ap, t); for (;;) { Node *n = va_arg(ap, Node *); if (!n) break; node_add(p, n, node); } va_end(ap); return node; } #define node_new(...) node_new(__VA_ARGS__, NULL) void node_print(Node *n) { if (n->type == N_LIT) { printf("\t%d [label=\"%ld\"]\n", n->id, n->val.i); } else { printf("\t%d [label=\"%s\", shape=record]\n", n->id, node_type_name[n->type]); } if (n->walked) { return; } n->walked = 1; for (int i = 0; i < n->out.len; i++) { if (n->out.data[i]->type == N_LIT) { printf("\t%d -> %d [style=dashed]\n", n->id, n->out.data[i]->id); } else { printf("\t%d -> %d\n", n->id, n->out.data[i]->id); } } for (int i = 0; i < n->out.len; i++) { node_print(n->out.data[i]); } } void proc_print(Proc *p) { if (p->start) { printf("\t\"%.*s\" -> %d\n", (int)p->name.n, p->name.s, p->start->id); node_print(p->start); if (no_opt) { for (NameBinding *b = p->scope.free_bind; b; b = b->prev) { uint64_t id = (uintptr_t)b->node; printf("\t\t%lu [label=\"%.*s\",shape=none,fontcolor=blue]\n", id, (int)b->name.n, b->name.s); printf("\t\t%lu -> %d [arrowhead=none,style=dotted,color=blue]\n", id, b->node->id); } } } } void unit_print(Unit *u) { puts("digraph {"); for (int i = 0; i < u->procs.len; i++) { proc_print(&u->procs.data[i]); } puts("}"); } Node *node_dedup_lit(Proc *p, Value v) { /* TODO: this is probably real inefficient for large procedure graphs, * but does it matter? how many nodes are direct children of the start node? * how many literals even usually occur in a procedure? */ for (int i = 0; i < p->start->out.len; i++) { Node *t = p->start->out.data[i]; if (t->type == N_LIT && t->val.type == v.type && t->val.i == v.i) { fprintf(stderr, "deduplicated a node\n"); return t; } } return NULL; } Node *node_new_lit_i64(Proc *p, int64_t i) { Value v = (Value) { T_INT, { .i = i } }; Node *t = node_dedup_lit(p, v); if (t) return t; Node *n = node_new(p, N_LIT, p->start); n->val = v; return n; } static inline int node_op_communative(NodeType t) { NodeType ops[] = { N_OP_ADD, N_OP_MUL, N_OP_AND, N_OP_XOR, N_OP_OR }; for (unsigned i = 0; i < sizeof ops / sizeof *ops; i++) { if (ops[i] == t) return 1; } return 0; } Value node_compute(Node *n, Lexer *l) { Type lit_type = T_BOT; Node **in = n->in.data; for (int i = 0; i < n->in.len; i++) { Node *p = in[i]; if (p->type != N_LIT) break; if (p->val.type != lit_type) { if (lit_type == T_BOT) { lit_type = p->val.type; } else { lit_type = T_BOT; break; } } } if (lit_type == T_INT) { Value v = { .type = lit_type }; switch (n->type) { case N_OP_NEG: v.i = -in[0]->val.i; break; case N_OP_NOT: v.i = ~in[0]->val.i; break; case N_OP_ADD: v.i = in[0]->val.i + in[1]->val.i; break; case N_OP_SUB: v.i = in[0]->val.i - in[1]->val.i; break; case N_OP_MUL: v.i = in[0]->val.i * in[1]->val.i; break; case N_OP_DIV: if (in[1]->val.i == 0) { lex_error_at(l, in[1]->src_pos, LE_ERROR, S("divisor always evaluates to zero")); } v.i = in[0]->val.i / in[1]->val.i; break; case N_OP_AND: v.i = in[0]->val.i & in[1]->val.i; break; case N_OP_OR: v.i = in[0]->val.i | in[1]->val.i; break; case N_OP_XOR: v.i = in[0]->val.i ^ in[1]->val.i; break; case N_OP_SHL: v.i = in[0]->val.u << in[1]->val.u; break; case N_OP_SHR: v.i = in[0]->val.u >> in[1]->val.u; break; default: return n->val; } return v; } return n->val; } /* needs lexer for error reporting */ Node *node_peephole(Node *n, Proc *p, Lexer *l) { if (no_opt) return n; if (n->type != N_LIT) { Value v = node_compute(n, l); if (v.type > T_CONST) { node_kill(n, p); Node *t = node_dedup_lit(p, v); if (t) return t; Node *r = node_new(p, N_LIT, p->start); r->val = v; r->src_pos = n->src_pos; return r; } } Node **in = n->in.data; /* TODO: figure out to do peepholes recursively, without fucking up the graph or having to clone everything */ if (node_op_communative(n->type)) { /* transformations to help encourage constant folding */ /* the overall trend is to move them rightwards */ if (in[0]->type == N_LIT && in[1]->type == n->type && in[1]->in.data[0]->type != N_LIT && in[1]->in.data[1]->type == N_LIT) { /* op(lit, op(X, lit)) -> op(X, op(lit, lit)) */ Node *tmp = in[1]->in.data[0]; in[1]->in.data[0] = in[0]; in[0] = tmp; /* TODO: ...would it break anything at all to just do in[1] = node_peephole(in[1], p, l)? * probably not, right? */ } else if (in[0]->type == n->type && in[0]->in.data[0]->type != N_LIT && in[0]->in.data[1]->type == N_LIT && in[1]->type != N_LIT) { /* op(op(X, lit), Y) -> op(op(X, Y), lit) */ Node *tmp = in[0]->in.data[1]; in[0]->in.data[1] = in[1]; in[1] = tmp; } else if (in[0]->type == N_LIT && in[1]->type != N_LIT) { /* op(lit, X) -> op(X, lit) */ Node *tmp = in[0]; in[0] = in[1]; in[1] = tmp; } } return n; } /* scope */ NameBinding *scope_find(Scope *scope, Str name) { for (ScopeFrame *f = scope->tail; f; f = f->prev) { for (NameBinding *b = f->latest; b; b = b->prev) { if (str_eql(b->name, name)) { return b; } } } return NULL; } ScopeFrame *scope_push(Scope *scope, Proc *proc) { ScopeFrame *f; if (scope->free_scope) { f = scope->free_scope; *f = (ScopeFrame) { 0 }; scope->free_scope = f->prev; } else { f = new(&proc->arena, ScopeFrame); } f->prev = scope->tail; scope->tail = f; return f; } ScopeFrame *scope_pop(Scope *scope, Proc *proc) { ScopeFrame *f = scope->tail; scope->tail = f->prev; f->prev = scope->free_scope; scope->free_scope = f; for (NameBinding *b = f->latest; b; ) { NameBinding *p = b->prev; b->prev = scope->free_bind; scope->free_bind = b; node_remove(proc, b->node, proc->keepalive); b = p; } return scope->tail; } /* returns previous value */ NameBinding *scope_bind(Scope *scope, Str name, Node *value, LexSpan pos, Proc *proc) { NameBinding *prev = scope_find(scope, name); NameBinding *b; if (scope->free_bind) { b = scope->free_bind; *b = (NameBinding) { 0 }; scope->free_bind = b->prev; } else { b = new(&proc->arena, NameBinding); } b->name = name; b->prev = scope->tail->latest; scope->tail->latest = b; b->node = value; b->src_pos = pos; node_add(proc, value, proc->keepalive); return prev; } /* parsing */ Node *parse_expr(Lexer *l, Proc *p); void parse_return(Lexer *l, Proc *p) { lex_next(l); p->stop = node_new(p, N_RETURN, p->start, parse_expr(l, p)); } void parse_let(Lexer *l, Proc *p) { recurse: lex_expect(l, TM_IDENT); Str name = l->ident; LexSpan pos = l->pos; lex_expect(l, TM_EQUALS); lex_next(l); Node *rhs = parse_expr(l, p); NameBinding *b = scope_bind(&p->scope, name, rhs, pos, p); if (b) { lex_error_at(l, pos, LE_WARN, S("shadowing previous declaration")); lex_error_at(l, b->src_pos, LE_WARN, S("declared here")); } if (l->tok == TOK_COMMA) goto recurse; } void parse_stmt(Lexer *l, Proc *p); /* TODO: return node from this! */ void parse_block(Lexer *l, Proc *p) { lex_next(l); scope_push(&p->scope, p); while (l->tok != TOK_RBRACE) { lex_expected_not(l, TM_EOF); parse_stmt(l, p); } scope_pop(&p->scope, p); lex_expected(l, TM_RBRACE); lex_next(l); } void parse_stmt(Lexer *l, Proc *p) { /* TODO */ (void)l; switch (l->tok) { case TOK_RETURN: parse_return(l, p); break; case TOK_LET: parse_let(l, p); break; case TOK_LBRACE: parse_block(l, p); break; default: lex_expected(l, TM_RBRACE); break; } } Proc *parse_proc(Lexer *l, Unit *u) { DA_FIT(&u->procs, u->procs.len + 1); Proc *proc = &u->procs.data[u->procs.len++]; memset(proc, 0, sizeof(Proc)); proc->start = node_new_empty(proc, N_START); proc->keepalive = node_new_empty(proc, N_KEEPALIVE); lex_expect(l, TM_IDENT); proc->name = l->ident; lex_expect(l, TM_LBRACE); lex_next(l); scope_push(&proc->scope, proc); while (l->tok != TOK_RBRACE) { lex_expected_not(l, TM_EOF); parse_stmt(l, proc); } scope_pop(&proc->scope, proc); lex_expected(l, TM_RBRACE); lex_next(l); return proc; } Node *parse_term(Lexer *l, Proc *p) { Node *node = NULL; NodeType op_after = N_START; if (TMASK(l->tok) & (TM_MINUS | TM_PLUS | TM_NOT)) { switch (l->tok) { case TOK_MINUS: op_after = N_OP_NEG; break; case TOK_NOT: op_after = N_OP_NOT; break; default: break; } lex_next(l); } if (l->tok == TOK_LPAREN) { lex_next(l); node = parse_expr(l, p); lex_expected(l, TM_RPAREN); lex_next(l); } else if (l->tok == TOK_IDENT) { NameBinding *b = scope_find(&p->scope, l->ident); if (b) { node = b->node; } else { lex_error(l, LE_ERROR, S("undeclared identifier")); } lex_next(l); } else { lex_expected(l, TM_LIT_NUM); int64_t val = 0; for (int i = 0; i < l->ident.n; i++) { if (!(l->ident.s[i] >= '0' && l->ident.s[i] <= '9')) { lex_error(l, LE_ERROR, S("not a digit")); break; } val = (val * 10) + (l->ident.s[i] - '0'); } node = node_new_lit_i64(p, val); node->src_pos = l->pos; lex_next(l); } if (op_after != N_START) { node = node_new(p, op_after, node_peephole(node, p, l)); } return node_peephole(node, p, l); } /* TODO: operator precedence would be kinda nice actually, sad to say */ Node *parse_expr(Lexer *l, Proc *p) { LexSpan pos = l->pos; Node *lhs = parse_term(l, p); if (TMASK(l->tok) & (TM_PLUS | TM_MINUS | TM_ASTERISK | TM_SLASH | TM_NOT | TM_AND | TM_XOR | TM_OR | TM_SHL | TM_SHR)) { Token t = l->tok; lex_next(l); Node *rhs = parse_expr(l, p); NodeType nt = N_OP_ADD; switch (t) { case TOK_PLUS: nt = N_OP_ADD; break; case TOK_MINUS: nt = N_OP_SUB; break; case TOK_ASTERISK: nt = N_OP_MUL; break; case TOK_SLASH: nt = N_OP_DIV; break; case TOK_NOT: nt = N_OP_NOT; break; case TOK_AND: nt = N_OP_AND; break; case TOK_OR: nt = N_OP_OR; break; case TOK_XOR: nt = N_OP_XOR; break; case TOK_SHL: nt = N_OP_SHL; break; case TOK_SHR: nt = N_OP_SHR; break; default: break; } lhs = node_new(p, nt, lhs, rhs); } Node *n = node_peephole(lhs, p, l); n->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs }; return n; } void parse_toplevel(Lexer *l, Unit *u) { switch (l->tok) { case TOK_PROC: parse_proc(l, u); break; default: lex_expected(l, TM_PROC); break; } } void parse_unit(Lexer *l) { Unit u = { 0 }; unit_init(&u); while (l->tok != TOK_EOF) { parse_toplevel(l, &u); } unit_print(&u); unit_fini(&u); } int main(int argc, const char **argv) { if (argc != 2) { fprintf(stderr, "Usage: %s FILE\n", argv[0]); return 1; } Lexer l = { 0 }; lex_start(&l, argv[1]); parse_unit(&l); lex_free(&l); return 0; }