summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author WormHeamer 2025-08-04 22:13:25 -0400
committer WormHeamer 2025-08-04 22:13:25 -0400
commit 88b01f43312eeceba87a1378be5cd63bb11f167f (patch)
tree c7046558223ac692958ca1c7b9da71f05493f148
parent 22aa7559a37c49def05e8a43b12561b17a4af258 (diff)
fix bug of lhs getting culled if same node optimized out of rhs
-rw-r--r-- Makefile 5
-rw-r--r-- ir.c 62
-rw-r--r-- ir.h 12
-rw-r--r-- main.c 61
-rw-r--r-- test.lang 5
5 files changed, 64 insertions, 81 deletions
diff --git a/Makefile b/Makefile
index 48e0b1c..b77cd62 100644
--- a/Makefile
+++ b/Makefile
@@ -10,9 +10,10 @@ OBJ != find -type f -name '*.c' | sed 's/\.c$$/.o/'
DEBUG = 0
GDB != which gf2 2> /dev/null || which gdb
-CFLAGS_1 = -g3 -fsanitize=undefined
+CFLAGS_0 = -Os
+CFLAGS_1 = -O0 -g3 -fsanitize=undefined
LDFLAGS_1 = -g3 -fsanitize=undefined
-LDFLAGS_0 = -Os -s
+LDFLAGS_0 = -s
PREFIX ?= ${HOME}/.local
BINDIR = ${PREFIX}/bin
diff --git a/ir.c b/ir.c
index a35c80c..62d6600 100644
--- a/ir.c
+++ b/ir.c
@@ -119,53 +119,42 @@ const char *node_type_name(NodeType t) {
}
void node_die(Node *n, Proc *p) {
+ assert(n->refs == 0);
n->prev_free = p->free_list;
p->free_list = n;
}
void node_del_out(Node *n, Node *p) {
- for (int i = 0; i < n->out.len; i++) {
+ for (int i = n->out.len - 1; i >= 0; i--) {
if (n->out.data[i] == p) {
p->refs--;
- if (i + 1 < n->out.len) {
- memmove(&n->out.data[i], &n->out.data[i + 1], sizeof(Node*) * (n->out.len - i - 1));
- }
n->out.len--;
- i--;
+ if (i < n->out.len) {
+ n->out.data[i] = n->out.data[n->out.len];
+ }
break;
}
}
}
void node_del_in(Node *n, Node *p) {
- for (int i = 0; i < n->in.len; i++) {
+ for (int i = n->in.len - 1; i >= 0; i--) {
if (n->in.data[i] == p) {
p->refs--;
- if (i + 1 < n->in.len) {
- memmove(&n->in.data[i], &n->in.data[i + 1], sizeof(Node*) * (n->in.len - i - 1));
- }
n->in.len--;
- i--;
+ if (i < n->in.len) {
+ memmove(&n->in.data[i], &n->in.data[i + 1], sizeof(Node*) * (n->in.len - i));
+ }
break;
}
}
}
void node_kill(Node *n, Proc *p) {
- if (n->refs < 1) return;
- while (n->in.len > 0) {
- int i = --n->in.len;
- node_del_out(n->in.data[i], n);
- n->in.data[i]->refs--;
- if (n->in.data[i]->out.len < 1) node_kill(n->in.data[i], p);
- }
- while (n->out.len > 0) {
- int i = --n->out.len;
- node_del_in(n->out.data[i], n);
- n->out.data[i]->refs--;
- if (n->out.data[i]->refs < 1) node_die(n->out.data[i], p);
- }
- node_die(n, p);
+ assert(n->refs > 0);
+ while (n->refs > 0 && n->in.len > 0) node_remove(p, n->in.data[0], n);
+ while (n->refs > 0 && n->out.len > 0) node_remove(p, n, n->out.data[0]);
+ assert(n->refs == 0);
}
void node_add_out(Proc *p, Node *a, Node *b) {
@@ -174,8 +163,7 @@ void node_add_out(Proc *p, Node *a, Node *b) {
}
void node_add_in(Proc *p, Node *a, Node *b) {
- assert(a->in.len < NODE_INPUT_MAX);
- a->in.data[a->in.len++] = b;
+ ZDA_PUSH(&a->in, b, &p->arena);
b->refs++;
}
@@ -194,8 +182,7 @@ void node_remove(Proc *p, Node *src, Node *dest) {
node_del_out(src, dest);
node_del_in(dest, src);
if (dest->refs < 1) node_die(dest, p);
- if (src->refs < 1) node_die(src, p);
- else if (src->out.len < 1) node_kill(src, p);
+ if (src->out.len < 1) node_kill(src, p);
}
static int global_node_count = 0;
@@ -364,14 +351,6 @@ Value node_compute(Node *n, Lexer *l) {
#include <stdio.h>
-/* replace an in[] with a peepholed ver */
-void node_peephole_in(Node *n, int idx, Proc *p, Lexer *l) {
- node_del_out(n->in.data[idx], n);
- Node *r = node_peephole(n->in.data[idx], p, l);
- node_add_out(p, r, n);
- n->in.data[idx] = r;
-}
-
#define NODE(...) node_peephole(node_new(p, __VA_ARGS__), p, l)
#define OP(...) NODE(n->type, __VA_ARGS__)
@@ -481,8 +460,8 @@ Node *node_idealize(Node *n, Proc *p, Lexer *l) {
if (T(CAR(n), N_LIT) && !T(CDR(n), N_LIT)) return OP(CDR(n), CAR(n));
/* op(X, op(Y,Z)) -> op(op(Y,Z), X) */
- if (!T(CAR(n), n->type) && T(CDR(n), n->type)
- && C(CAR(n), CDAR(n))) return OP(CDR(n), CAR(n));
+ /*if (!T(CAR(n), n->type) && T(CDR(n), n->type)
+ && C(CAR(n), CDAR(n))) return OP(CDR(n), CAR(n));*/
/* op(op(X,Y), op(Z, lit)) -> op(op(X, op(Y, Z)), lit) */
if (T(CAR(n), n->type) && T(CDR(n), n->type) && T(CDDR(n), N_LIT)
@@ -601,20 +580,23 @@ zero_no_effect: if (node_eql_i64(CAR(n), 0)) return CDR(n);
break;
}
-
-
return NULL;
}
Node *node_peephole(Node *n, Proc *p, Lexer *l) {
+ assert(n->refs > 0);
+ node_add_out(p, n, p->keepalive);
Node *r = node_idealize(n, p, l);
if (r) {
r->src_pos = n->src_pos;
/* make sure r doesn't get deleted even if connected to n */
node_add_out(p, r, p->keepalive);
+ node_del_out(n, p->keepalive);
node_kill(n, p);
node_del_out(r, p->keepalive);
n = r;
+ } else {
+ node_del_out(n, p->keepalive);
}
/* FIXME: figure out why this shows the wrong position when in an assignment */
return n;
diff --git a/ir.h b/ir.h
index 112ab8e..799a0e0 100644
--- a/ir.h
+++ b/ir.h
@@ -62,20 +62,14 @@ typedef enum {
const char *node_type_name(NodeType t);
-#define NODE_INPUT_MAX 2
-
-typedef struct {
- struct Node *data[NODE_INPUT_MAX];
- ptrdiff_t len;
-} NodeInputs;
-
-typedef DYNARR(struct Node *) NodeOutputs;
+typedef DYNARR(struct Node *) NodeList;
+typedef NodeList NodeInputs, NodeOutputs;
typedef struct Node {
+ int id, refs;
union {
struct Node *prev_free;
struct {
- int id, refs;
int walked;
NodeType type;
LexSpan src_pos;
diff --git a/main.c b/main.c
index ce8981b..8533e49 100644
--- a/main.c
+++ b/main.c
@@ -213,41 +213,46 @@ Node *parse_term(Lexer *l, Proc *p) {
return node_peephole(node, p, l);
}
+NodeType tok_to_bin_op(Token t) {
+ switch (t) {
+ case TOK_PLUS: return N_OP_ADD; break;
+ case TOK_MINUS: return N_OP_SUB; break;
+ case TOK_ASTERISK: return N_OP_MUL; break;
+ case TOK_SLASH: return N_OP_DIV; break;
+ case TOK_NOT: return N_OP_NOT; break;
+ case TOK_AND: return N_OP_AND; break;
+ case TOK_OR: return N_OP_OR; break;
+ case TOK_XOR: return N_OP_XOR; break;
+ case TOK_SHL: return N_OP_SHL; break;
+ case TOK_SHR: return N_OP_SHR; break;
+ case TOK_EQL: return N_CMP_EQL; break;
+ case TOK_NEQ: return N_CMP_NEQ; break;
+ case TOK_LES: return N_CMP_LES; break;
+ case TOK_GTR: return N_CMP_GTR; break;
+ case TOK_LTE: return N_CMP_LTE; break;
+ case TOK_GTE: return N_CMP_GTE; break;
+ default: return N_START; break;
+ }
+}
+
/* TODO: operator precedence would be kinda nice actually, sad to say */
Node *parse_expr(Lexer *l, Proc *p) {
LexSpan pos = l->pos;
Node *lhs = parse_term(l, p);
- if (TMASK(l->tok) & (TM_PLUS | TM_MINUS | TM_ASTERISK | TM_SLASH
- | TM_NOT | TM_AND | TM_XOR | TM_OR | TM_SHL | TM_SHR
- | TM_EQL | TM_NEQ | TM_LES | TM_GTR | TM_LTE | TM_GTE)) {
- Token t = l->tok;
+ NodeType nt = tok_to_bin_op(l->tok);;
+ assert(lhs->refs > 0);
+ if (nt != N_START) {
lex_next(l);
+ /* necessary because if lhs is a deduplicated literal, it may be an input to rhs
+ * and therefore culled by peephole optimizations */
+ node_add(p, lhs, p->keepalive);
Node *rhs = parse_expr(l, p);
- NodeType nt = N_OP_ADD;
- switch (t) {
- case TOK_PLUS: nt = N_OP_ADD; break;
- case TOK_MINUS: nt = N_OP_SUB; break;
- case TOK_ASTERISK: nt = N_OP_MUL; break;
- case TOK_SLASH: nt = N_OP_DIV; break;
- case TOK_NOT: nt = N_OP_NOT; break;
- case TOK_AND: nt = N_OP_AND; break;
- case TOK_OR: nt = N_OP_OR; break;
- case TOK_XOR: nt = N_OP_XOR; break;
- case TOK_SHL: nt = N_OP_SHL; break;
- case TOK_SHR: nt = N_OP_SHR; break;
- case TOK_EQL: nt = N_CMP_EQL; break;
- case TOK_NEQ: nt = N_CMP_NEQ; break;
- case TOK_LES: nt = N_CMP_LES; break;
- case TOK_GTR: nt = N_CMP_GTR; break;
- case TOK_LTE: nt = N_CMP_LTE; break;
- case TOK_GTE: nt = N_CMP_GTE; break;
- default: break;
- }
- lhs = node_new(p, nt, lhs, rhs);
+ Node *n = node_peephole(node_new(p, nt, lhs, rhs), p, l);
+ node_remove(p, lhs, p->keepalive);
+ lhs = n;
}
- Node *n = node_peephole(lhs, p, l);
- n->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs };
- return n;
+ lhs->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs };
+ return lhs;
}
void parse_toplevel(Lexer *l, Unit *u) {
diff --git a/test.lang b/test.lang
index 2277fc3..c2661a4 100644
--- a/test.lang
+++ b/test.lang
@@ -9,6 +9,7 @@
// also single-line now
proc main(a i64) {
- return a + -a
- // (true = 0) = (a xor b)
+ let x = (a + -a) = (a xor a)
+ let y = (a + a) = (a * 2)
+ return x & y & (a = a) & ((a + 2) <> a)
}