From 739b4852d2a826ba2985c7db2f5c778050f72250 Mon Sep 17 00:00:00 2001 From: WormHeamer Date: Thu, 7 Aug 2025 02:47:23 -0400 Subject: preliminary peephole optimization of if statements --- ir.c | 133 +++++++++++++++++++++++++++++++++++++++----------------------- ir.h | 8 ++-- main.c | 70 +++++++++++++++++++-------------- test.lang | 4 +- 4 files changed, 131 insertions(+), 84 deletions(-) diff --git a/ir.c b/ir.c index dd8b66c..bd3354b 100644 --- a/ir.c +++ b/ir.c @@ -11,8 +11,9 @@ extern int no_opt; #define IN(n, i) ((n)->in.data[i]) #define OUT(n, i) ((n)->out.data[i]) -#define CAR(n) IN(n, 0) -#define CDR(n) IN(n, 1) +#define CTRL(n) IN(n, 0) +#define CAR(n) IN(n, 1) +#define CDR(n) IN(n, 2) #define CAAR(n) CAR(CAR(n)) #define CADR(n) CDR(CAR(n)) #define CDAR(n) CAR(CDR(n)) @@ -50,6 +51,11 @@ int value_eql(Value *a, Value *b) { Str type_desc(Type *t, Arena *arena) { (void)arena; + switch (t->lvl) { + case T_CTRL: return S("ctrl"); + case T_XCTRL: return S("~ctrl"); + default: break; + } switch (t->t) { case T_TUPLE: return S("tuple"); case T_BOOL: return S("bool"); @@ -140,7 +146,7 @@ void node_die(Node *n, Proc *p) { void node_del_out(Node *n, Node *p) { for (int i = n->out.len - 1; i >= 0; i--) { if (n->out.data[i] == p) { - p->refs--; + if (p) p->refs--; n->out.len--; if (i < n->out.len) { n->out.data[i] = n->out.data[n->out.len]; @@ -153,7 +159,7 @@ void node_del_out(Node *n, Node *p) { void node_del_in(Node *n, Node *p) { for (int i = n->in.len - 1; i >= 0; i--) { if (n->in.data[i] == p) { - p->refs--; + if (p) p->refs--; n->in.len--; if (i < n->in.len) { memmove(&n->in.data[i], &n->in.data[i + 1], sizeof(Node*) * (n->in.len - i)); @@ -164,7 +170,6 @@ void node_del_in(Node *n, Node *p) { } void node_kill(Node *n, Proc *p) { - assert(n->refs > 0); while (n->refs > 0 && n->in.len > 0) node_remove(p, n->in.data[0], n); while (n->refs > 0 && n->out.len > 0) node_remove(p, n, n->out.data[0]); assert(n->refs == 0); @@ -172,17 +177,18 @@ void node_kill(Node *n, Proc *p) { void node_add_out(Proc *p, Node *a, Node *b) { ZDA_PUSH(&p->arena, &a->out, b); - b->refs++; + if (b) b->refs++; } void node_add_in(Proc *p, Node *a, Node *b) { ZDA_PUSH(&p->arena, &a->in, b); - b->refs++; + if (b) b->refs++; } void node_add(Proc *p, Node *src, Node *dest) { - node_add_out(p, src, dest); node_add_in(p, dest, src); + if (!src) return; + node_add_out(p, src, dest); if (dest->src_pos.n == 0) dest->src_pos = src->src_pos; else if (src->src_pos.n != 0) { int lo = dest->src_pos.ofs < src->src_pos.ofs ? dest->src_pos.ofs : src->src_pos.ofs; @@ -192,10 +198,12 @@ void node_add(Proc *p, Node *src, Node *dest) { } void node_remove(Proc *p, Node *src, Node *dest) { - node_del_out(src, dest); node_del_in(dest, src); if (dest->refs < 1) node_die(dest, p); - if (src->out.len < 1) node_kill(src, p); + if (src) { + node_del_out(src, dest); + if (src->out.len < 1) node_kill(src, p); + } } static int global_node_count = 0; @@ -215,10 +223,11 @@ Node *node_new_empty(Proc *p, NodeType t) { } int type_check(Node *); -Node *node_newv(Proc *p, NodeType t, ...) { +Node *node_newv(Proc *p, NodeType t, Node *ctrl, ...) { Node *node = node_new_empty(p, t); va_list ap; - va_start(ap, t); + va_start(ap, ctrl); + node_add(p, ctrl, node); for (;;) { Node *n = va_arg(ap, Node *); if (!n) break; @@ -244,7 +253,7 @@ Node *node_dedup_lit(Proc *p, Value v) { Node *node_new_lit(Proc *p, Value v) { Node *t = node_dedup_lit(p, v); if (t) return t; - Node *n = node_new(p, N_LIT, p->start); + Node *n = node_new(p, N_LIT, NULL, p->start); n->val = v; return n; } @@ -323,9 +332,8 @@ static inline int node_cmp_incompat(NodeType a, NodeType b) { Value node_compute(Node *n, Lexer *l) { Type lit_type = { .lvl = T_BOT }; - Node **in = n->in.data; - for (int i = 0; i < n->in.len; i++) { - Node *p = in[i]; + for (int i = 1; i < n->in.len; i++) { + Node *p = IN(n, i); if (p->val.type.lvl != T_CONST) { lit_type.lvl = T_BOT; break; @@ -346,39 +354,39 @@ Value node_compute(Node *n, Lexer *l) { if (lit_type.t == T_INT) { switch (n->type) { - case N_OP_NEG: v.i = -in[0]->val.i; break; - case N_OP_NOT: v.i = ~in[0]->val.i; break; - case N_OP_ADD: v.i = in[0]->val.i + in[1]->val.i; break; - case N_OP_SUB: v.i = in[0]->val.i - in[1]->val.i; break; - case N_OP_MUL: v.i = in[0]->val.i * in[1]->val.i; break; + case N_OP_NEG: v.i = -CAR(n)->val.i; break; + case N_OP_NOT: v.i = ~CAR(n)->val.i; break; + case N_OP_ADD: v.i = CAR(n)->val.i + CDR(n)->val.i; break; + case N_OP_SUB: v.i = CAR(n)->val.i - CDR(n)->val.i; break; + case N_OP_MUL: v.i = CAR(n)->val.i * CDR(n)->val.i; break; case N_OP_DIV: - if (in[1]->val.i == 0) { - lex_error_at(l, in[1]->src_pos, LE_ERROR, S("divisor always evaluates to zero")); + if (CDR(n)->val.i == 0) { + lex_error_at(l, CDR(n)->src_pos, LE_ERROR, S("divisor always evaluates to zero")); } - v.i = in[0]->val.i / in[1]->val.i; + v.i = CAR(n)->val.i / CDR(n)->val.i; break; - case N_OP_AND: v.i = in[0]->val.i & in[1]->val.i; break; - case N_OP_OR: v.i = in[0]->val.i | in[1]->val.i; break; - case N_OP_XOR: v.i = in[0]->val.i ^ in[1]->val.i; break; - case N_OP_SHL: v.i = in[0]->val.u << in[1]->val.u; break; - case N_OP_SHR: v.i = in[0]->val.u >> in[1]->val.u; break; - case N_CMP_EQL: v.type.t = T_BOOL; v.i = in[0]->val.i == in[1]->val.i; break; - case N_CMP_NEQ: v.type.t = T_BOOL; v.i = in[0]->val.i != in[1]->val.i; break; - case N_CMP_LES: v.type.t = T_BOOL; v.i = in[0]->val.i < in[1]->val.i; break; - case N_CMP_GTR: v.type.t = T_BOOL; v.i = in[0]->val.i > in[1]->val.i; break; - case N_CMP_LTE: v.type.t = T_BOOL; v.i = in[0]->val.i <= in[1]->val.i; break; - case N_CMP_GTE: v.type.t = T_BOOL; v.i = in[0]->val.i >= in[1]->val.i; break; + case N_OP_AND: v.i = CAR(n)->val.i & CDR(n)->val.i; break; + case N_OP_OR: v.i = CAR(n)->val.i | CDR(n)->val.i; break; + case N_OP_XOR: v.i = CAR(n)->val.i ^ CDR(n)->val.i; break; + case N_OP_SHL: v.i = CAR(n)->val.u << CDR(n)->val.u; break; + case N_OP_SHR: v.i = CAR(n)->val.u >> CDR(n)->val.u; break; + case N_CMP_EQL: v.type.t = T_BOOL; v.i = CAR(n)->val.i == CDR(n)->val.i; break; + case N_CMP_NEQ: v.type.t = T_BOOL; v.i = CAR(n)->val.i != CDR(n)->val.i; break; + case N_CMP_LES: v.type.t = T_BOOL; v.i = CAR(n)->val.i < CDR(n)->val.i; break; + case N_CMP_GTR: v.type.t = T_BOOL; v.i = CAR(n)->val.i > CDR(n)->val.i; break; + case N_CMP_LTE: v.type.t = T_BOOL; v.i = CAR(n)->val.i <= CDR(n)->val.i; break; + case N_CMP_GTE: v.type.t = T_BOOL; v.i = CAR(n)->val.i >= CDR(n)->val.i; break; default: return n->val; } return v; } else if (lit_type.t == T_BOOL) { switch (n->type) { - case N_OP_NOT: v.i = !in[0]->val.i; break; - case N_CMP_EQL: v.i = in[0]->val.i == in[1]->val.i; break; - case N_CMP_NEQ: v.i = in[0]->val.i != in[1]->val.i; break; - case N_OP_AND: v.i = in[0]->val.i && in[1]->val.i; break; - case N_OP_OR: v.i = in[0]->val.i || in[1]->val.i; break; - case N_OP_XOR: v.i = in[0]->val.i ^ in[1]->val.i; break; + case N_OP_NOT: v.i = !CAR(n)->val.i; break; + case N_CMP_EQL: v.i = CAR(n)->val.i == CDR(n)->val.i; break; + case N_CMP_NEQ: v.i = CAR(n)->val.i != CDR(n)->val.i; break; + case N_OP_AND: v.i = CAR(n)->val.i && CDR(n)->val.i; break; + case N_OP_OR: v.i = CAR(n)->val.i || CDR(n)->val.i; break; + case N_OP_XOR: v.i = CAR(n)->val.i ^ CDR(n)->val.i; break; default: return n->val; } return v; @@ -389,7 +397,7 @@ Value node_compute(Node *n, Lexer *l) { #include -#define NODE(...) node_peephole(node_new(p, __VA_ARGS__), p, l) +#define NODE(t, ...) node_peephole(node_new(p, t, CTRL(n), __VA_ARGS__), p, l) #define OP(...) NODE(n->type, __VA_ARGS__) static inline int node_eql_i64(Node *n, int64_t i) { @@ -498,7 +506,7 @@ Node *node_idealize(Node *n, Proc *p, Lexer *l) { if (v.type.lvl == T_CONST) { Node *t = node_dedup_lit(p, v); if (t) return t; - Node *r = node_new(p, N_LIT, p->start); + Node *r = node_new(p, N_LIT, NULL, p->start); r->val = v; r->src_pos = n->src_pos; return r; @@ -517,8 +525,8 @@ Node *node_idealize(Node *n, Proc *p, Lexer *l) { } } - if (n->in.len > 1 && same && !same_ptr) { - Node *r = node_new_empty(p, n->type); + if (n->in.len > 2 && same && !same_ptr) { + Node *r = node_new(p, n->type, NULL); for (int i = 0; i < n->in.len; i++) { node_add(p, CAR(n), r); } @@ -691,6 +699,27 @@ zero_no_effect: if (node_eql_i64(CAR(n), 0)) return CDR(n); case N_CMP_GTE: if (same) return node_new_lit_bool(p, 1); break; + + case N_IF_ELSE: + if (T(CAR(n), N_LIT)) { + if (CAR(n)->val.i) { + n->val.tuple.data[1].type.lvl = T_XCTRL; + } else { + n->val.tuple.data[0].type.lvl = T_XCTRL; + } + } + break; + + case N_PHI: + if (same) return CAR(n); + if (IN(CTRL(n), 0)->val.type.lvl == T_XCTRL) { + return CDR(n); + } + if (IN(CTRL(n), 1)->val.type.lvl == T_XCTRL) { + return CAR(n); + } + break; + default: break; } @@ -747,15 +776,15 @@ Node *node_peephole(Node *n, Proc *p, Lexer *l) { void proc_init(Proc *proc, Str name) { memset(proc, 0, sizeof(Proc)); - proc->start = node_new_empty(proc, N_START); + proc->start = node_new(proc, N_START, NULL); proc->start->val.type = (Type) { .lvl = T_BOT, .t = T_TUPLE, .next = NULL }; - proc->stop = node_new_empty(proc, N_STOP); + proc->stop = node_new(proc, N_STOP, NULL); proc->ctrl = proc->start; - proc->keepalive = node_new_empty(proc, N_KEEPALIVE); + proc->keepalive = node_new(proc, N_KEEPALIVE, NULL); proc->name = name; } @@ -838,8 +867,14 @@ NameBinding *scope_update(Scope *scope, Str name, Node *to, Proc *proc) { void scope_collect(Scope *scope, Proc *proc, ScopeNameList *nl, Arena *arena) { for (ScopeFrame *f = scope->tail; f; f = f->prev) { for (NameBinding *b = f->latest; b; b = b->prev) { - node_add_out(proc, b->node, proc->keepalive); + node_add(proc, b->node, proc->keepalive); ZDA_PUSH(arena, nl, (ScopeName) { b->name, b->node }); } } } + +void scope_uncollect(Scope *scope, Proc *proc, ScopeNameList *nl) { + for (int i = 0; i < nl->len; i++) { + node_remove(proc, nl->data[i].node, proc->keepalive); + } +} diff --git a/ir.h b/ir.h index ff75163..c7fa118 100644 --- a/ir.h +++ b/ir.h @@ -12,7 +12,8 @@ typedef enum { T_TOP, /* may or may not be a constant */ T_CONST, /* known compile-time constant */ T_BOT, /* known not a constant */ - T_CTRL /* bottom of control flow */ + T_CTRL, /* control flow bottom */ + T_XCTRL, /* control flow top (dead) */ } TypeLevel; typedef enum { @@ -75,7 +76,7 @@ typedef struct Node { int walked; NodeType type; LexSpan src_pos; - NodeInputs in; + NodeInputs in; /* note: index 0 used for control flow */ NodeOutputs out; Value val; }; @@ -133,7 +134,7 @@ void node_add_in(Proc *p, Node *a, Node *b); void node_remove(Proc *p, Node *src, Node *dest); Node *node_new_empty(Proc *p, NodeType t); -Node *node_newv(Proc *p, NodeType t, ...); +Node *node_newv(Proc *p, NodeType t, Node *ctrl, ...); Node *node_dedup_lit(Proc *p, Value v); Value node_compute(Node *n, Lexer *l); Node *node_peephole(Node *n, Proc *p, Lexer *l); @@ -153,5 +154,6 @@ NameBinding *scope_find(Scope *scope, Str name); NameBinding *scope_bind(Scope *scope, Str name, Node *value, LexSpan pos, Proc *proc); NameBinding *scope_update(Scope *scope, Str name, Node *to, Proc *proc); void scope_collect(Scope *scope, Proc *proc, ScopeNameList *nl, Arena *arena); +void scope_uncollect(Scope *scope, Proc *proc, ScopeNameList *nl); #endif diff --git a/main.c b/main.c index 11c4e00..8265aab 100644 --- a/main.c +++ b/main.c @@ -85,7 +85,7 @@ void parse_assign(Lexer *l, Proc *p) { } } -void merge_scope(Proc *p, ScopeNameList *na, ScopeNameList *nb) { +void merge_scope(Lexer *l, Proc *p, ScopeNameList *na, ScopeNameList *nb) { for (ScopeFrame *f = p->scope.tail; f; f = f->prev) { for (NameBinding *b = f->latest; b; b = b->prev) { int i, j; @@ -94,20 +94,21 @@ void merge_scope(Proc *p, ScopeNameList *na, ScopeNameList *nb) { if (i >= na->len && j >= nb->len) continue; /* no change */ Node *phi; if (i >= na->len) { + if (nb->data[j].node == b->node) continue; phi = node_new(p, N_PHI, p->ctrl, b->node, nb->data[j].node); } else if (j >= na->len) { - phi = node_new(p, N_PHI, p->ctrl, na->data[i].node, b->node); + if (na->data[i].node == b->node) continue; + phi = node_new(p, N_PHI, p->ctrl, b->node, na->data[i].node); } else { + if (na->data[i].node == b->node && nb->data[j].node == b->node) continue; phi = node_new(p, N_PHI, p->ctrl, na->data[i].node, nb->data[j].node); } - node_add(p, phi, p->keepalive); node_remove(p, b->node, p->keepalive); + phi = node_peephole(phi, p, l); + node_add(p, phi, p->keepalive); b->node = phi; } } - - for (int i = 0; i < na->len; i++) node_remove(p, na->data[i].node, p->keepalive); - for (int i = 0; i < nb->len; i++) node_remove(p, nb->data[i].node, p->keepalive); } void parse_if(Lexer *l, Proc *p) { @@ -117,21 +118,21 @@ void parse_if(Lexer *l, Proc *p) { Node *if_node = node_new(p, N_IF_ELSE, p->ctrl, cond); if_node->val = (Value) { .type = { T_TOP, T_TUPLE, NULL }, - .tuple = { - .len = 2, - .cap = 0, - .data = (Value[2]) { - { .type = { T_CTRL, T_NONE, NULL } }, - { .type = { T_CTRL, T_NONE, NULL } }, - } - } + .tuple = { 0 } }; - Node *if_true = node_new(p, N_PROJ, if_node); - Node *if_false = node_new(p, N_PROJ, if_node); + ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { T_CTRL, T_NONE, NULL } }); + ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { T_CTRL, T_NONE, NULL } }); + if_node = node_peephole(if_node, p, l); + Node *if_true = node_peephole(node_new(p, N_PROJ, if_node), p, l); + Node *if_false = node_peephole(node_new(p, N_PROJ, if_node), p, l); ScopeNameList scope_before = { 0 }, scope_true = { 0 }, scope_false = { 0 }; if_true->val.i = 0; if_false->val.i = 1; scope_collect(&p->scope, p, &scope_before, &p->arena); + if (cond->val.type.lvl == T_CONST) { + if (cond->val.i) if_false->val.type.lvl = T_XCTRL; + else if_true->val.type.lvl = T_XCTRL; + } NODE_KEEP(p, cond, { p->ctrl = if_true; lex_expected(l, TM_LBRACE); @@ -147,12 +148,19 @@ void parse_if(Lexer *l, Proc *p) { } }); if (ctrl_else) { - p->ctrl = node_new(p, N_REGION, ctrl_if, ctrl_else); - merge_scope(p, &scope_true, &scope_false); + p->ctrl = node_peephole(node_new(p, N_REGION, ctrl_if, ctrl_else), p, l); + node_add(p, p->ctrl, p->keepalive); + merge_scope(l, p, &scope_true, &scope_false); + node_remove(p, p->ctrl, p->keepalive); } else { - p->ctrl = node_new(p, N_REGION, ctrl_if, if_false); - merge_scope(p, &scope_before, &scope_true); + p->ctrl = node_peephole(node_new(p, N_REGION, ctrl_if, if_false), p, l); + node_add(p, p->ctrl, p->keepalive); + merge_scope(l, p, &scope_true, &scope_before); + node_remove(p, p->ctrl, p->keepalive); } + scope_uncollect(&p->scope, p, &scope_true); + scope_uncollect(&p->scope, p, &scope_false); + scope_uncollect(&p->scope, p, &scope_before); } void parse_stmt(Lexer *l, Proc *p) { @@ -214,12 +222,12 @@ void parse_args_list(Lexer *l, Proc *proc) { lex_expect(l, TM_IDENT | TM_COMMA); if (l->tok == TOK_COMMA) continue; Value v = (Value) { .type = parse_type(l, proc) }; - ZDA_PUSH(&proc->arena, &start->val.tuple, v); lex_expected(l, TM_RPAREN | TM_COMMA); for (int j = 0; j < id; j++) { Node *proj = node_new(proc, N_PROJ, proc->start); proj->val.type = v.type; proj->val.i = i++; + ZDA_PUSH(&proc->arena, &start->val.tuple, v); scope_bind(&proc->scope, idbuf[j].name, proj, idbuf[j].pos, proc); } id = 0; @@ -264,7 +272,7 @@ Node *parse_term(Lexer *l, Proc *p) { default: return node; } if (post_op == N_START) return node; - return node_peephole(node_new(p, post_op, node), p, l); + return node_peephole(node_new(p, post_op, NULL, node), p, l); } if (l->tok == TOK_LPAREN) { lex_next(l); @@ -299,7 +307,7 @@ Node *parse_term(Lexer *l, Proc *p) { lex_next(l); } if (op_after != N_START) { - node = node_new(p, op_after, node_peephole(node, p, l)); + node = node_new(p, op_after, NULL, node_peephole(node, p, l)); } return node_peephole(node, p, l); } @@ -340,7 +348,7 @@ Node *parse_expr(Lexer *l, Proc *p) { NODE_KEEP(p, lhs, { rhs = parse_expr(l, p); }); - lhs = node_peephole(node_new(p, nt, lhs, rhs), p, l); + lhs = node_peephole(node_new(p, nt, NULL, lhs, rhs), p, l); } lhs->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs }; return lhs; @@ -373,9 +381,6 @@ void parse_unit(Lexer *l) { void node_print(Node *n, Proc *p) { if (n->walked) return; - const char *colors[] = { - "red", "blue", "cyan", "green", "orange", "magenta", - }; if (n->type == N_START) { Str s = S(""); int c = n->val.tuple.len; @@ -398,11 +403,12 @@ void node_print(Node *n, Proc *p) { break; } } else if (n->type == N_PROJ) { - printf("\t%d [label=\"PROJ(%ld)\", shape=record]", n->id, n->val.i); + Str d = type_desc(&n->in.data[0]->val.tuple.data[n->val.i].type, &p->arena); + printf("\t%d [label=\"%ld | %.*s\", shape=record]", n->id, n->val.i, (int)d.n, d.s); } else { printf("\t%d [label=\"%s\", shape=record]", n->id, node_type_name(n->type)); } - printf(" [color=%s]\n", colors[n->id % (sizeof colors / sizeof *colors)]); + printf("\n"); n->walked = 1; for (int i = 0; i < n->out.len; i++) { Node *o = n->out.data[i]; @@ -411,7 +417,11 @@ void node_print(Node *n, Proc *p) { } else { int j; for (j = 0; j < o->in.len && o->in.data[j] != n; j++); - printf("\t%d -> %d [color=%s,headlabel=%d]\n", n->id, o->id, colors[o->id % (sizeof colors / sizeof *colors)], j); + if (j == 0) { + printf("\t%d -> %d [color=red,headlabel=%d]\n", n->id, o->id, j); + } else { + printf("\t%d -> %d [headlabel=%d]\n", n->id, o->id, j); + } } } for (int i = 0; i < n->out.len; i++) { diff --git a/test.lang b/test.lang index 00334c9..1f33175 100644 --- a/test.lang +++ b/test.lang @@ -1,6 +1,6 @@ proc main(a, b i64) { - if a > 5 { - a := a - 5 + if a = b { + b := 3 } return a } -- cgit v1.2.3