From 027b6d8f62281f5b513a2b7bfcc02ea833c2cfd2 Mon Sep 17 00:00:00 2001 From: WormHeamer Date: Thu, 7 Aug 2025 22:19:31 -0400 Subject: i thiiiink if statement peepholes work now? --- ir.c | 29 +++++++++++++++++++++------- main.c | 66 +++++++++++++++++++++++++++++++++++++++++++++++---------------- test.lang | 10 +++++----- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/ir.c b/ir.c index d527bfe..a9fcd44 100644 --- a/ir.c +++ b/ir.c @@ -171,6 +171,10 @@ void node_del_in(Node *n, Node *p) { } void node_kill(Node *n, Proc *p) { + if (p->ctrl == n) { + /* probably this is fine */ + p->ctrl = CTRL(n); + } while (n->refs > 0 && n->in.len > 0) node_remove(p, n->in.data[0], n); while (n->refs > 0 && n->out.len > 0) node_remove(p, n, n->out.data[0]); assert(n->refs == 0); @@ -462,6 +466,7 @@ static inline int node_known_not_equiv(Node *a, Node *b) { static int node_equiv_input(Node *a, Node *b) { if (a->in.len != b->in.len) return 0; + if (CTRL(a) != CTRL(b)) return 0; /* note that this means the order of inputs isn't guaranteed, so be * careful what you use this procedure for */ if ((node_op_communative(a->type) || node_op_communative(b->type)) @@ -470,7 +475,7 @@ static int node_equiv_input(Node *a, Node *b) { /* assuming input count is 2 */ return 1; } - for (int i = 0; i < a->in.len; i++) { + for (int i = 1; i < a->in.len; i++) { if (!node_equiv(IN(a, i), IN(b, i))) return 0; } return 1; @@ -548,6 +553,12 @@ Node *node_idealize(Node *n, Proc *p, Lexer *l) { if (T(CDR(n), N_OP_NOT) && !T(CAR(n), N_OP_NOT)) { return NODE(n->type, CDR(n), CAR(n)); } + + if (T(CAR(n), N_PHI) && T(CDR(n), N_PHI) && CTRL(CAR(n)) == CTRL(CDR(n)) + && ((node_equiv(CAAR(n), CDAR(n)) && node_equiv(CADR(n), CDDR(n))) + || (node_equiv(CADR(n), CDAR(n)) && node_equiv(CAAR(n), CDDR(n))))) { + return OP(CAAR(n), CDAR(n)); + } } if (node_op_associative(n->type)) { @@ -711,14 +722,18 @@ zero_no_effect: if (node_eql_i64(CAR(n), 0)) return CDR(n); } break; + case N_PROJ: + if (T(CTRL(n), N_IF_ELSE) && CTRL(n)->val.type.t == T_TUPLE) { + if (CTRL(n)->val.tuple.data[(n->val.i + 1) % CTRL(n)->val.tuple.len].type.lvl == T_XCTRL) { + return CTRL(CTRL(n)); + } + } + break; + case N_PHI: if (same) return CAR(n); - if (IN(CTRL(n), 1)->val.type.lvl == T_XCTRL) { - return CAR(n); - } - if (IN(CTRL(n), 0)->val.type.lvl == T_XCTRL) { - return CDR(n); - } + if (IN(CTRL(n), 1)->val.type.lvl == T_XCTRL) return CAR(n); + if (IN(CTRL(n), 0)->val.type.lvl == T_XCTRL) return CDR(n); break; default: diff --git a/main.c b/main.c index 2341c8f..4a7a3d4 100644 --- a/main.c +++ b/main.c @@ -141,41 +141,67 @@ void parse_if(Lexer *l, Proc *p) { Node *ctrl_if = NULL, *ctrl_else = NULL; Node *if_node = node_new(p, N_IF_ELSE, p->ctrl, cond); if_node->val = (Value) { - .type = { T_TOP, T_TUPLE, NULL }, + .type = { .lvl = T_TOP, .t = T_TUPLE }, .tuple = { 0 } }; ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { .lvl = T_CTRL, .t = T_NONE } }); ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { .lvl = T_CTRL, .t = T_NONE } }); if_node = node_peephole(if_node, p, l); - Node *if_true = node_peephole(node_new(p, N_PROJ, if_node), p, l); - Node *if_false = node_peephole(node_new(p, N_PROJ, if_node), p, l); - ScopeNameList scope_before = { 0 }, scope_true = { 0 }, scope_false = { 0 }; + node_add(p, if_node, p->keepalive); + Node *if_true = node_new(p, N_PROJ, if_node); + Node *if_false = node_new(p, N_PROJ, if_node); if_true->val.i = 0; if_false->val.i = 1; + if_true = node_peephole(if_true, p, l); + if_false = node_peephole(if_false, p, l); + assert(if_true->in.len > 0); + assert(if_false->in.len > 0); + node_add(p, if_true, p->keepalive); + node_add(p, if_false, p->keepalive); + node_remove(p, if_node, p->keepalive); + ScopeNameList scope_before = { 0 }, scope_true = { 0 }, scope_false = { 0 }; scope_collect(&p->scope, p, &scope_before, &p->arena); if (cond->val.type.lvl == T_CONST) { if (cond->val.i) if_false->val.type.lvl = T_XCTRL; else if_true->val.type.lvl = T_XCTRL; } - NODE_KEEP(p, cond, { - p->ctrl = if_true; - lex_expected(l, TM_LBRACE); - parse_block(l, p, &scope_true); - ctrl_if = p->ctrl; - if (l->tok == TOK_ELSE) { - NODE_KEEP(p, ctrl_if, { - p->ctrl = if_false; - lex_expect(l, TM_LBRACE); - parse_block(l, p, &scope_false); - ctrl_else = p->ctrl; - }); + p->ctrl = if_true; + fprintf(stderr, ":+ %d\n", p->ctrl->id); + //assert(if_true->in.len > 0); + lex_expected(l, TM_LBRACE); + parse_block(l, p, &scope_true); + ctrl_if = p->ctrl; + node_add(p, ctrl_if, p->keepalive); + fprintf(stderr, ":- %d\n", p->ctrl->id); + //assert(ctrl_if->in.len > 0); + if (l->tok == TOK_ELSE) { + for (int i = 0; i < scope_before.len; i++) { + scope_update(scope_find(&p->scope, scope_before.data[i].name), scope_before.data[i].node, p); } - }); + p->ctrl = if_false; + lex_expect(l, TM_LBRACE); + parse_block(l, p, &scope_false); + ctrl_else = p->ctrl; + node_add(p, ctrl_else, p->keepalive); + } if (ctrl_else) { + //assert(ctrl_if->in.len > 0); + //assert(ctrl_else->in.len > 0); p->ctrl = node_peephole(node_new(p, N_REGION, ctrl_if, ctrl_else), p, l); + node_remove(p, ctrl_if, p->keepalive); + node_remove(p, ctrl_else, p->keepalive); } else { + //assert(ctrl_if->in.len > 0); + //assert(if_false->in.len > 0); p->ctrl = node_peephole(node_new(p, N_REGION, ctrl_if, if_false), p, l); + node_remove(p, ctrl_if, p->keepalive); + assert(if_true->refs > 0); + assert(if_false->refs > 0); } + node_remove(p, if_true, p->keepalive); + node_remove(p, if_false, p->keepalive); + //p->ctrl = node_peephole(node_new(p, N_REGION, if_true, if_false), p, l); + assert(p->ctrl->in.len > 0); merge_scope(l, p, &scope_before, &scope_true, &scope_false); scope_uncollect(&p->scope, p, &scope_true); scope_uncollect(&p->scope, p, &scope_false); @@ -434,6 +460,12 @@ void node_print(Node *n, Proc *p) { case T_BOOL: printf("\t%d [label=\"%s\"]", n->id, n->val.i ? "true" : "false"); break; + case T_NONE: + if (n->val.type.lvl == T_XCTRL) { + printf("\t%d [label=\"~ctrl\"]", n->id); + break; + } + /* fallthrough */ default: printf("\t%d [label=\"literal %d\"]", n->id, n->id); break; diff --git a/test.lang b/test.lang index 24a931d..2d3c42d 100644 --- a/test.lang +++ b/test.lang @@ -1,8 +1,8 @@ func main(a, b i64) i64 { - if a = b { - let t = a - a := b - b := t + if true { + a := 3 + } else { + a := 5 } - return a + b + return a } -- cgit v1.2.3