#include #include #include #include #include #include #include #include "lex.h" #include "arena.h" #include "dynarr.h" /* node graph */ typedef enum { N_START, N_RETURN, N_LIT, N_OP_ADD, N_OP_SUB, N_OP_MUL, N_OP_DIV, N_OP_AND, N_OP_OR, N_OP_XOR, N_OP_NEG, N_OP_NOT } NodeType; const char *node_type_name[] = { "start", "return", "integer literal", "add", "sub", "mul", "div", "and", "or", "xor", "neg", "not", }; typedef struct Node { union { struct Node *prev_free; struct { int id, refs; int walked; NodeType type; LexSpan src_pos; DYNARR(struct Node *) in, out; union { struct { int64_t i; } lit; }; }; }; } Node; typedef struct NameBinding { Str name; struct NameBinding *prev; } NameBinding; typedef struct { NameBinding *last; } Scope; typedef struct { Arena arena; Str name; Node *start, *stop; Node *free_list; } Proc; typedef struct { DYNARR(Proc) procs; } Unit; void unit_init(Unit *u) { (void)u; } void unit_fini(Unit *u) { free(u->procs.data); } void node_kill(Node *n, Proc *p); void node_die(Node *n, Proc *p) { n->prev_free = p->free_list; p->free_list = n; } void node_del_out(Node *n, Node *p) { for (int i = 0; i < n->out.len; i++) { if (n->out.data[i] == p) { p->refs--; if (i + 1 < n->out.len) { memmove(&n->out.data[i], &n->out.data[i + 1], sizeof(Node*) * (n->out.len - i - 1)); } n->out.len--; i--; } } } void node_del_in(Node *n, Node *p) { for (int i = 0; i < n->in.len; i++) { if (n->in.data[i] == p) { p->refs--; if (i + 1 < n->in.len) { memmove(&n->in.data[i], &n->in.data[i + 1], sizeof(Node*) * (n->in.len - i - 1)); } n->in.len--; i--; } } } void node_kill(Node *n, Proc *p) { for (int i = 0; i < n->in.len; i++) { node_del_out(n->in.data[i], n); n->in.data[i]->refs--; if (n->in.data[i]->out.len < 1) node_kill(n->in.data[i], p); } for (int i = 0; i < n->out.len; i++) { node_del_in(n->out.data[i], n); n->out.data[i]->refs--; if (n->out.data[i]->refs < 1) node_die(n->out.data[i], p); } n->in.len = 0; n->out.len = 0; node_die(n, p); } void node_add(Proc *p, Node *src, Node *dest) { ZDA_PUSH(&src->out, dest, &p->arena); ZDA_PUSH(&dest->in, src, &p->arena); src->refs++; dest->refs++; if (dest->src_pos.n == 0) dest->src_pos = src->src_pos; else { LexSpan *sp = &src->src_pos, *dp = &dest->src_pos; if (sp->ofs < dp->ofs) { dp->n = (dp->ofs + dp->n) - sp->ofs; dp->ofs = sp->ofs; } if ((sp->ofs + sp->n) > (dp->ofs + dp->n)) { dp->n = (sp->ofs + sp->n) - dp->ofs; } } } static int global_node_count = 0; Node *node_new_empty(Proc *p, NodeType t) { Node *n; if (p->free_list) { n = p->free_list; p->free_list = n->prev_free; memset(n, 0, sizeof(Node)); } else { n = new(&p->arena, Node); } n->type = t; n->id = global_node_count++; return n; } Node *node_new(Proc *p, NodeType t, ...) { Node *node = node_new_empty(p, t); va_list ap; va_start(ap, t); for (;;) { Node *n = va_arg(ap, Node *); if (!n) break; node_add(p, n, node); } va_end(ap); return node; } #define node_new(...) node_new(__VA_ARGS__, NULL) void node_print(Node *n) { if (n->type == N_LIT) { printf("\t%d [label=\"%ld\"]\n", n->id, n->lit.i); } else { printf("\t%d [label=\"%s\", shape=record]\n", n->id, node_type_name[n->type]); } if (n->walked) { return; } n->walked = 1; for (int i = 0; i < n->out.len; i++) { if (n->out.data[i]->type == N_LIT) { printf("\t%d -> %d [style=dashed]\n", n->id, n->out.data[i]->id); } else { printf("\t%d -> %d\n", n->id, n->out.data[i]->id); } } for (int i = 0; i < n->out.len; i++) { node_print(n->out.data[i]); } } void proc_print(Proc *p) { if (p->start) { printf("\t\"%.*s\" -> %d\n", (int)p->name.n, p->name.s, p->start->id); node_print(p->start); } } void unit_print(Unit *u) { puts("digraph {"); for (int i = 0; i < u->procs.len; i++) { proc_print(&u->procs.data[i]); } puts("}"); } Node *node_new_lit_i64(Proc *p, int64_t i) { Node *n = node_new(p, N_LIT, p->start); n->lit.i = i; return n; } static inline int node_op_communative(NodeType t) { NodeType ops[] = { N_OP_ADD, N_OP_MUL, N_OP_AND, N_OP_XOR, N_OP_OR }; for (unsigned i = 0; i < sizeof ops / sizeof *ops; i++) { if (ops[i] == t) return 1; } return 0; } /* needs lexer for error reporting */ Node *node_peephole(Node *n, Proc *p, Lexer *l) { return n; /* * Is there a method to convey these transformations a little more nicely? * * add(lit, lit) -> lit * add(lit, add(N, lit)) -> add(N, add(lit, lit)) * */ Node **in = n->in.data; Node *r = n; int all_lit = 1; for (int i = 0; i < n->in.len; i++) { if (n->in.data[i]->type != N_LIT) all_lit = 0; } if (all_lit) { switch (n->type) { case N_OP_NEG: r = node_new_lit_i64(p, -in[0]->lit.i); break; case N_OP_NOT: r = node_new_lit_i64(p, ~in[0]->lit.i); break; case N_OP_ADD: r = node_new_lit_i64(p, in[0]->lit.i + in[1]->lit.i); break; case N_OP_SUB: r = node_new_lit_i64(p, in[0]->lit.i - in[1]->lit.i); break; case N_OP_MUL: r = node_new_lit_i64(p, in[0]->lit.i * in[1]->lit.i); break; case N_OP_DIV: if (in[1]->lit.i == 0) { lex_error_at(l, n->src_pos, LE_ERROR, S("Division by zero")); } r = node_new_lit_i64(p, in[0]->lit.i / in[1]->lit.i); break; case N_OP_AND: r = node_new_lit_i64(p, in[0]->lit.i & in[1]->lit.i); break; case N_OP_OR: r = node_new_lit_i64(p, in[0]->lit.i | in[1]->lit.i); break; case N_OP_XOR: r = node_new_lit_i64(p, in[0]->lit.i ^ in[1]->lit.i); break; default: break; } } else { if (node_op_communative(n->type) && in[0]->type == N_LIT && in[1]->type == n->type && in[1]->in.data[0]->type != N_LIT && in[1]->in.data[1]->type == N_LIT) { /* op(lit, op(X, lit)) -> op(X, op(lit, lit)) */ Node *tmp = in[1]->in.data[0]; in[1]->in.data[0] = in[0]; in[0] = tmp; } } if (r != n) { r->src_pos = n->src_pos; if (n->out.len == 0) node_kill(n, p); } return r; } /* parsing */ Node *parse_expr(Lexer *l, Proc *p); Node *parse_return(Lexer *l, Proc *p) { lex_next(l); return node_new(p, N_RETURN, p->start, parse_expr(l, p)); } Node *parse_stmt(Lexer *l, Proc *p) { /* TODO */ (void)l; switch (l->tok) { case TOK_RETURN: return parse_return(l, p); break; default: lex_expected(l, TM_RBRACE); return NULL; } } Proc *parse_proc(Lexer *l, Unit *u) { DA_FIT(&u->procs, u->procs.len + 1); Proc *proc = &u->procs.data[u->procs.len++]; memset(proc, 0, sizeof(Proc)); proc->start = node_new_empty(proc, N_START); lex_expect(l, TM_IDENT); proc->name = l->ident; lex_expect(l, TM_LBRACE); lex_next(l); while (l->tok != TOK_RBRACE) { lex_expected_not(l, TM_EOF); proc->stop = parse_stmt(l, proc); } lex_expected(l, TM_RBRACE); lex_next(l); return proc; } Node *parse_term(Lexer *l, Proc *p) { Node *node = NULL; int negate_after = l->tok == TOK_MINUS; if (TMASK(l->tok) & (TM_MINUS | TM_PLUS)) { lex_next(l); } if (l->tok == TOK_LPAREN) { lex_next(l); node = parse_expr(l, p); lex_expected(l, TM_RPAREN); lex_next(l); } else { lex_expected(l, TM_LIT_NUM); int64_t val = 0; for (int i = 0; i < l->ident.n; i++) { if (!(l->ident.s[i] >= '0' && l->ident.s[i] <= '9')) { lex_error(l, LE_ERROR, S("Not a digit")); break; } val = (val * 10) + (l->ident.s[i] - '0'); } node = node_new_lit_i64(p, val); node->src_pos = l->pos; lex_next(l); } if (negate_after) { node = node_new(p, N_OP_NEG, node); } return node_peephole(node, p, l); } Node *parse_expr(Lexer *l, Proc *p) { Node *lhs = parse_term(l, p); if (l->tok == TOK_LPAREN) { lex_next(l); puts("args_start"); for (;;) { parse_expr(l, p); lex_expected(l, TM_COMMA | TM_RPAREN); if (l->tok == TOK_RPAREN) break; lex_next(l); } lex_next(l); puts("args_end"); puts("func_call"); } if (TMASK(l->tok) & (TM_PLUS | TM_MINUS | TM_ASTERISK | TM_SLASH | TM_NOT | TM_AND | TM_XOR | TM_OR)) { Token t = l->tok; lex_next(l); Node *rhs = parse_expr(l, p); NodeType nt = N_OP_ADD; switch (t) { case TOK_PLUS: nt = N_OP_ADD; break; case TOK_MINUS: nt = N_OP_SUB; break; case TOK_ASTERISK: nt = N_OP_MUL; break; case TOK_SLASH: nt = N_OP_DIV; break; case TOK_NOT: nt = N_OP_NOT; break; case TOK_AND: nt = N_OP_AND; break; case TOK_OR: nt = N_OP_OR; break; case TOK_XOR: nt = N_OP_XOR; break; default: break; } lhs = node_new(p, nt, lhs, rhs); } return node_peephole(lhs, p, l); } void parse_toplevel(Lexer *l, Unit *u) { switch (l->tok) { case TOK_PROC: parse_proc(l, u); break; default: lex_expected(l, TM_PROC); break; } } void parse_unit(Lexer *l) { Unit u = { 0 }; unit_init(&u); while (l->tok != TOK_EOF) { parse_toplevel(l, &u); } unit_print(&u); unit_fini(&u); } int main(int argc, const char **argv) { if (argc != 2) { fprintf(stderr, "Usage: %s FILE\n", argv[0]); return 1; } Lexer l = { 0 }; lex_start(&l, argv[1]); parse_unit(&l); lex_free(&l); return 0; }