summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: WormHeamer, 2025-08-04 00:43:51 -0400
committer: WormHeamer, 2025-08-04 00:43:51 -0400
commit: a5e5749e41de721c2e982f42f6ba27fc2b6d69c1 (patch)
tree: b35b1934468b49384cb185a058e4a1098cb9379e
parent: 487e48e985c6fa6762454af661f666fbe77fcdd1 (diff)
add projection nodes, fix peephole optimization
-rw-r--r-- dynarr.h | 1
-rw-r--r-- ir.c | 159
-rw-r--r-- ir.h | 28
-rw-r--r-- main.c | 114
-rw-r--r-- strio.h | 1
-rw-r--r-- test.lang | 4
6 files changed, 218 insertions, 89 deletions
diff --git a/dynarr.h b/dynarr.h
index fb577ef..b5c52b6 100644
--- a/dynarr.h
+++ b/dynarr.h
@@ -38,6 +38,7 @@ void *zda_fit(void **data, ptrdiff_t size, ptrdiff_t align, ptrdiff_t fit, ptrdi
#include <stdio.h>
#include <stdlib.h>
+#include <stdio.h>
void *da_fit(void **data, ptrdiff_t size, ptrdiff_t fit, ptrdiff_t *cap) {
ptrdiff_t c = *cap;
diff --git a/ir.c b/ir.c
index 8a2baf5..255d578 100644
--- a/ir.c
+++ b/ir.c
@@ -1,5 +1,8 @@
#include <stdarg.h>
+#include <assert.h>
+
#include "ir.h"
+#include "strio.h"
extern int no_opt;
@@ -18,6 +21,55 @@ int type_base_eql(Type *a, Type *b) {
return a->next ? type_base_eql(a->next, b->next) : 1;
}
+Str type_desc(Type *t, Arena *arena) { /* Map t's base type (t->t only; lvl and next are not inspected) to a short display name. */
+ (void)arena; /* arena is accepted but not used by this function */
+ switch (t->t) {
+ case T_TUPLE: return S("tuple");
+ case T_BOOL: return S("bool");
+ case T_INT: return S("i64");
+ default: return S("N/A"); /* any other base type, e.g. T_NONE */
+ }
+}
+
+void type_err(Node *n, Lexer *l) { /* Report an LE_ERROR at n->src_pos whose message lists the type name of each input of n. */
+ Str s = S("");
+ for (int i = 0; i < n->in.len; i++) {
+ if (i > 0) str_cat(&s, S(", "), &l->arena); /* separator before every entry except the first */
+ str_cat(&s, type_desc(&n->in.data[i]->val.type, &l->arena), &l->arena);
+ }
+ lex_error_at(l, n->src_pos, LE_ERROR, str_fmt(&l->arena, "type error (%S)", s)); /* str_fmt is handed the lexer's arena; %S presumably formats a Str -- confirm in strio.h */
+}
+
+int type_check(Node *n) { /* Assign n->val.type from n's node type, then return nonzero iff the inputs' base types fit; the assignment happens even when the check fails. */
+ switch (n->type) {
+ case N_OP_NEG: /* result int; in[0] must be int */
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_INT };
+ return n->in.data[0]->val.type.t == T_INT;
+ case N_OP_NOT: /* result takes in[0]'s base type; in[0] must be int or bool */
+ n->val.type = (Type) { .lvl = T_TOP, .t = n->in.data[0]->val.type.t };
+ return n->in.data[0]->val.type.t == T_INT || n->in.data[0]->val.type.t == T_BOOL;
+ case N_OP_AND: case N_OP_OR: case N_OP_XOR: /* result takes in[0]'s base type; operands must be int/int or bool/bool */
+ n->val.type = (Type) { .lvl = T_TOP, .t = n->in.data[0]->val.type.t };
+ return (n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT)
+ || (n->in.data[0]->val.type.t == T_BOOL && n->in.data[1]->val.type.t == T_BOOL);
+ case N_OP_ADD: case N_OP_SUB: case N_OP_MUL: case N_OP_DIV:
+ case N_OP_SHL: case N_OP_SHR: /* arithmetic and shifts: result int; both operands must be int */
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_INT };
+ return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT;
+ case N_CMP_LES: case N_CMP_GTR:
+ case N_CMP_LTE: case N_CMP_GTE: /* ordering comparisons: result bool; both operands must be int */
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL };
+ return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT;
+ case N_CMP_EQL:
+ case N_CMP_NEQ: /* equality comparisons: result bool; operands must be int/int or bool/bool */
+ n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL };
+ return (n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT)
+ || (n->in.data[0]->val.type.t == T_BOOL && n->in.data[1]->val.type.t == T_BOOL);
+ default: /* every other node type passes and n->val.type is left as it was */
+ return 1;
+ }
+}
+
/* nodes */
const char *node_type_name(NodeType t) {
@@ -90,11 +142,19 @@ void node_kill(Node *n, Proc *p) {
node_die(n, p);
}
+void node_add_out(Proc *p, Node *a, Node *b) { /* Push b onto a's out list, passing p->arena to ZDA_PUSH. */
+ ZDA_PUSH(&a->out, b, &p->arena);
+ b->refs++; /* the count bumped is b's (the node being pointed at), not a's */
+}
+
+void node_add_in(Proc *p, Node *a, Node *b) { /* Push b onto a's in list, passing p->arena to ZDA_PUSH. */
+ ZDA_PUSH(&a->in, b, &p->arena);
+ b->refs++; /* the count bumped is b's (the input node), not a's */
+}
+
void node_add(Proc *p, Node *src, Node *dest) {
- ZDA_PUSH(&src->out, dest, &p->arena);
- ZDA_PUSH(&dest->in, src, &p->arena);
- src->refs++;
- dest->refs++;
+ node_add_out(p, src, dest); /* src->out gains dest; dest->refs++ */
+ node_add_in(p, dest, src); /* dest->in gains src; src->refs++ -- together the same pushes and increments as the four removed lines */
if (dest->src_pos.n == 0) dest->src_pos = src->src_pos;
}
@@ -122,6 +182,7 @@ Node *node_new_empty(Proc *p, NodeType t) {
return n;
}
+int type_check(Node *);
Node *node_newv(Proc *p, NodeType t, ...) {
Node *node = node_new_empty(p, t);
va_list ap;
@@ -136,6 +197,7 @@ Node *node_newv(Proc *p, NodeType t, ...) {
}
Node *node_dedup_lit(Proc *p, Value v) {
+ return NULL;
/* TODO: this is probably real inefficient for large procedure graphs,
* but does it matter? how many nodes are direct children of the start node?
* how many literals even usually occur in a procedure? */
@@ -238,14 +300,25 @@ Value node_compute(Node *n, Lexer *l) {
return n->val;
}
+#include <stdio.h>
+
+/* replace input idx of n with the result of peepholing it, and re-link n to that result */
+void node_peephole_in(Node *n, int idx, Proc *p, Lexer *l) {
+ node_del_out(n->in.data[idx], n); /* presumably removes n from the old input's out list -- confirm against node_del_out */
+ Node *r = node_peephole(n->in.data[idx], p, l); /* r is either the same input node or its replacement */
+ node_add_out(p, r, n); /* r->out gains n; this also increments n->refs */
+ n->in.data[idx] = r; /* overwrite the slot in place; n's in list length is unchanged */
+}
+
+#define NODE(...) node_peephole(node_new(p, __VA_ARGS__), p, l) /* create a node and immediately peephole it; expands to uses of `p` and `l`, so both must be in scope at the call site */
+
/* needs lexer for error reporting */
-Node *node_peephole(Node *n, Proc *p, Lexer *l) {
- if (no_opt) return n;
+Node *node_idealize(Node *n, Proc *p, Lexer *l) {
+ if (no_opt) return NULL;
if (n->type != N_LIT) {
Value v = node_compute(n, l);
if (v.type.lvl == T_CONST) {
- node_kill(n, p);
Node *t = node_dedup_lit(p, v);
if (t) return t;
Node *r = node_new(p, N_LIT, p->start);
@@ -260,33 +333,58 @@ Node *node_peephole(Node *n, Proc *p, Lexer *l) {
if (node_op_communative(n->type)) {
/* transformations to help encourage constant folding */
/* the overall trend is to move them rightwards */
- if (in[0]->type == N_LIT
- && in[1]->type == n->type
- && in[1]->in.data[0]->type != N_LIT
- && in[1]->in.data[1]->type == N_LIT) {
- /* op(lit, op(X, lit)) -> op(X, op(lit, lit)) */
- Node *tmp = in[1]->in.data[0];
- in[1]->in.data[0] = in[0];
- in[0] = tmp;
- /* TODO: ...would it break anything at all to just do in[1] = node_peephole(in[1], p, l)?
- * probably not, right? */
- } else if (in[0]->type == n->type
+
+ if ((in[0]->type == N_LIT && in[1]->type != N_LIT)
+ || (in[1]->type == n->type && in[0]->type != n->type)) {
+
+ if (in[1]->type == n->type) {
+ fprintf(stderr, "op(X, op(Y, Z)) -> op(op(Y, Z), X)\n");
+ } else {
+ fprintf(stderr, "op(lit, X) -> op(X, lit)\n");
+ }
+
+ return NODE(n->type, in[1], in[0]);
+ }
+
+ if (in[1]->type == N_LIT
+ && in[0]->type == n->type
+ && in[0]->in.data[0]->type != N_LIT
+ && in[0]->in.data[1]->type == N_LIT) {
+
+ fprintf(stderr, "op(op(X, lit), lit) -> op(X, op(lit, lit))\n");
+
+ return NODE(n->type,
+ in[0]->in.data[0],
+ NODE(n->type, in[0]->in.data[1], in[1]));
+ }
+
+ /* op(op(X, lit), Y) -> op(op(X, Y), lit) */
+ if (in[0]->type == n->type
&& in[0]->in.data[0]->type != N_LIT
&& in[0]->in.data[1]->type == N_LIT
&& in[1]->type != N_LIT) {
- /* op(op(X, lit), Y) -> op(op(X, Y), lit) */
- Node *tmp = in[0]->in.data[1];
- in[0]->in.data[1] = in[1];
- in[1] = tmp;
- } else if (in[0]->type == N_LIT && in[1]->type != N_LIT) {
- /* op(lit, X) -> op(X, lit) */
- Node *tmp = in[0];
- in[0] = in[1];
- in[1] = tmp;
+ fprintf(stderr, "op(op(X, lit), Y) -> op(op(X, Y), lit)\n");
+
+ return NODE(n->type,
+ NODE(n->type, in[0]->in.data[0], in[1]),
+ in[0]->in.data[1]);
}
}
- return n;
+ return NULL;
+}
+
+Node *node_peephole(Node *n, Proc *p, Lexer *l) { /* Return node_idealize's replacement for n (after killing n), or n itself when node_idealize returns NULL. */
+ Node *r = node_idealize(n, p, l);
+ if (r) {
+ /* make sure r doesn't get deleted even if connected to n */
+ node_add_out(p, r, p->keepalive); /* temporary r -> keepalive edge held across the kill below */
+ node_kill(n, p);
+ node_del_out(r, p->keepalive); /* take the temporary edge back off r */
+ return r;
+ } else {
+ return n; /* nothing to rewrite; n is left alive and unchanged here */
+ }
+}
/* procedures */
@@ -294,6 +392,11 @@ Node *node_peephole(Node *n, Proc *p, Lexer *l) {
void proc_init(Proc *proc, Str name) {
memset(proc, 0, sizeof(Proc));
proc->start = node_new_empty(proc, N_START);
+ proc->start->val.type = (Type) { /* type the start node as a tuple at level T_BOT with no next link */
+ .lvl = T_BOT,
+ .t = T_TUPLE,
+ .next = NULL
+ };
proc->keepalive = node_new_empty(proc, N_KEEPALIVE);
proc->name = name;
}
diff --git a/ir.h b/ir.h
index 432c96c..0a8a9c2 100644
--- a/ir.h
+++ b/ir.h
@@ -15,19 +15,33 @@ typedef enum {
} TypeLevel;
typedef enum {
+ T_NONE, /* first enumerator (value 0), so a zero-initialized Type has t == T_NONE */
T_TUPLE,
T_BOOL,
T_INT
} BaseType;
typedef struct Type {
- BaseType t;
TypeLevel lvl;
+ BaseType t; /* base kind (one of the BaseType enumerators); now declared after lvl */
struct Type *next;
} Type;
+typedef struct Value { /* A Type plus one payload; the struct now has a tag so the tuple member can refer to struct Value. */
+ Type type;
+ union { /* anonymous union: i, u and tuple share storage, so only one is meaningful at a time */
+ int64_t i;
+ uint64_t u;
+ DYNARR(struct Value) tuple; /* list of Values; presumably a data/len style array -- confirm against DYNARR in dynarr.h */
+ };
+} Value;
+
+struct Node; /* forward declaration so the prototypes below can take struct Node * before Node is defined */
int type_eql(Type *a, Type *b);
int type_base_eql(Type *a, Type *b);
+int type_check(struct Node *n); /* sets n->val.type and returns nonzero when n's input types fit its operator */
+Str type_desc(Type *t, Arena *arena); /* short display name for t's base type */
+void type_err(struct Node *n, Lexer *l); /* reports a "type error (...)" diagnostic at n->src_pos */
/* nodes */
@@ -47,14 +61,6 @@ typedef enum {
const char *node_type_name(NodeType t);
-typedef struct {
- Type type;
- union {
- int64_t i;
- uint64_t u;
- };
-} Value;
-
typedef struct Node {
union {
struct Node *prev_free;
@@ -101,7 +107,11 @@ void node_die(Node *n, Proc *p);
void node_del_out(Node *n, Node *p);
void node_del_in(Node *n, Node *p);
void node_kill(Node *n, Proc *p);
+
void node_add(Proc *p, Node *src, Node *dest);
+void node_add_out(Proc *p, Node *a, Node *b);
+void node_add_in(Proc *p, Node *a, Node *b);
+
void node_remove(Proc *p, Node *src, Node *dest);
Node *node_new_empty(Proc *p, NodeType t);
Node *node_newv(Proc *p, NodeType t, ...);
diff --git a/main.c b/main.c
index 1a4a6d1..f1665e5 100644
--- a/main.c
+++ b/main.c
@@ -9,6 +9,7 @@
#include "lex.h"
#include "arena.h"
#include "dynarr.h"
+#include "strio.h"
#include "ir.h"
int no_opt = 0;
@@ -101,14 +102,53 @@ void parse_stmt(Lexer *l, Proc *p) {
}
}
+
+Type parse_type(Lexer *l, Proc *proc) { /* Build a Type at level T_BOT from the identifier in l->ident, then advance the lexer. */
+ (void)proc; /* proc is accepted but not used by this function */
+ Type t = { .lvl = T_BOT }; /* t.t starts at 0, i.e. T_NONE */
+ if (str_eql(l->ident, S("i64"))) {
+ t.t = T_INT;
+ } else if (str_eql(l->ident, S("bool"))) {
+ t.t = T_BOOL;
+ } else {
+ lex_error(l, LE_ERROR, S("unknown type")); /* if lex_error returns, t.t is still T_NONE */
+ }
+ lex_next(l); /* move past the type name on every path */
+ return t;
+}
+
+void parse_args_list(Lexer *l, Proc *proc) { /* Parse `name type` parameter pairs up to ')', recording each type on the start node and binding each name to a projection node. */
+ Node *start = proc->start;
+ int i = 0; /* index handed to the next projection */
+ while (l->tok != TOK_RPAREN && l->tok != TOK_EOF) {
+ lex_expect(l, TM_IDENT); /* presumably advances and requires an identifier: the parameter name -- confirm in lex.h */
+ Str name = l->ident;
+ LexSpan pos = l->pos;
+ lex_expect(l, TM_IDENT); /* second identifier: the type name that parse_type reads from l->ident */
+ Value v = (Value) { .type = parse_type(l, proc) }; /* only the type is set; the payload union is zero-initialized */
+ ZDA_PUSH(&start->val.tuple, v, &proc->arena); /* append this parameter to the start node's tuple */
+ lex_expected(l, TM_RPAREN | TM_COMMA); /* presumably checks the current token without advancing -- confirm in lex.h */
+ Node *proj = node_new(proc, N_PROJ, proc->start); /* projection node created with the start node as its argument */
+ proj->val.type = v.type;
+ proj->val.i = i++; /* store this parameter's position, then advance the index */
+ scope_bind(&proc->scope, name, proj, pos, proc); /* the parameter name now refers to the projection */
+ }
+ lex_expected(l, TM_RPAREN); /* the loop can also exit on TOK_EOF, so the ')' is checked explicitly */
+ lex_next(l);
+}
+
Proc *parse_proc(Lexer *l, Unit *u) {
DA_FIT(&u->procs, u->procs.len + 1);
Proc *proc = &u->procs.data[u->procs.len++];
lex_expect(l, TM_IDENT);
proc_init(proc, l->ident);
- lex_expect(l, TM_LBRACE);
- lex_next(l);
scope_push(&proc->scope, proc);
+ lex_expect(l, TM_LPAREN | TM_LBRACE);
+ if (l->tok == TOK_LPAREN) {
+ parse_args_list(l, proc);
+ }
+ lex_expected(l, TM_LBRACE);
+ lex_next(l);
while (l->tok != TOK_RBRACE) {
lex_expected_not(l, TM_EOF);
parse_stmt(l, proc);
@@ -119,39 +159,6 @@ Proc *parse_proc(Lexer *l, Unit *u) {
return proc;
}
-int type_check(Node *n) {
- /*fprintf(stderr, "::\n");
- for (int i = 0; i < n->in.len; i++) {
- fprintf(stderr, "%d: %d/%d\n", i,
- n->in.data[i]->val.type.lvl,
- n->in.data[i]->val.type.t);
- }*/
- switch (n->type) {
- case N_OP_NEG:
- n->val.type = (Type) { .lvl = T_TOP, .t = T_INT };
- return n->in.data[0]->val.type.t == T_INT;
- case N_OP_NOT:
- n->val.type = (Type) { .lvl = T_TOP, .t = n->in.data[0]->val.type.t };
- return n->in.data[0]->val.type.t == T_INT || n->in.data[0]->val.type.t == T_BOOL;
- case N_OP_ADD: case N_OP_SUB: case N_OP_MUL: case N_OP_DIV:
- case N_OP_AND: case N_OP_OR: case N_OP_XOR:
- case N_OP_SHL: case N_OP_SHR:
- n->val.type = (Type) { .lvl = T_TOP, .t = T_INT };
- return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT;
- case N_CMP_LES: case N_CMP_GTR:
- case N_CMP_LTE: case N_CMP_GTE:
- n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL };
- return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT;
- case N_CMP_EQL:
- case N_CMP_NEQ:
- n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL };
- return (n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT)
- || (n->in.data[0]->val.type.t == T_BOOL && n->in.data[1]->val.type.t == T_BOOL);
- default:
- return 1;
- }
-}
-
Node *parse_term(Lexer *l, Proc *p) {
Node *node = NULL;
NodeType op_after = N_START;
@@ -166,11 +173,7 @@ Node *parse_term(Lexer *l, Proc *p) {
default: return node;
}
if (post_op == N_START) return node;
- Node *r = node_new(p, post_op, node);
- if (!type_check(r)) {
- lex_error_at(l, r->src_pos, LE_ERROR, S("type mismatch"));
- }
- return node_peephole(r, p, l);
+ return node_peephole(node_new(p, post_op, node), p, l);
}
if (l->tok == TOK_LPAREN) {
lex_next(l);
@@ -242,9 +245,6 @@ Node *parse_expr(Lexer *l, Proc *p) {
}
Node *n = node_peephole(lhs, p, l);
n->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs };
- if (!type_check(n)) {
- lex_error_at(l, n->src_pos, LE_ERROR, S("type mismatch"));
- }
return n;
}
@@ -273,8 +273,17 @@ void parse_unit(Lexer *l) {
/* graph output */
-void node_print(Node *n) {
- if (n->type == N_LIT) {
+void node_print(Node *n, Proc *p) {
+ if (n->type == N_START) {
+ Str s = S("");
+ int c = n->val.tuple.len;
+ Value *v = n->val.tuple.data;
+ for (int i = 0; i < c; i++) {
+ if (i > 0) str_cat(&s, S(", "), &p->arena);
+ str_cat(&s, type_desc(&v[i].type, &p->arena), &p->arena);
+ }
+ printf("\t%d [label=\"start(%.*s)\"]\n", n->id, (int)s.n, s.s);
+ } else if (n->type == N_LIT) {
switch (n->val.type.t) {
case T_INT:
printf("\t%d [label=\"%ld\"]\n", n->id, n->val.i);
@@ -286,6 +295,8 @@ void node_print(Node *n) {
printf("\t%d [label=\"literal %d\"]\n", n->id, n->id);
break;
}
+ } else if (n->type == N_PROJ) {
+ printf("\t%d [label=\"PROJ(%ld)\", shape=record]\n", n->id, n->val.i);
} else {
printf("\t%d [label=\"%s\", shape=record]\n", n->id, node_type_name(n->type));
}
@@ -294,21 +305,24 @@ void node_print(Node *n) {
}
n->walked = 1;
for (int i = 0; i < n->out.len; i++) {
- if (n->out.data[i]->type == N_LIT) {
- printf("\t%d -> %d [style=dashed]\n", n->id, n->out.data[i]->id);
+ Node *o = n->out.data[i];
+ if (o->type == N_LIT) {
+ printf("\t%d -> %d [style=dashed]\n", n->id, o->id);
} else {
- printf("\t%d -> %d\n", n->id, n->out.data[i]->id);
+ int j;
+ for (j = 0; j < o->in.len && o->in.data[j] != n; j++);
+ printf("\t%d -> %d [label=%d]\n", n->id, o->id, j);
}
}
for (int i = 0; i < n->out.len; i++) {
- node_print(n->out.data[i]);
+ node_print(n->out.data[i], p);
}
}
void proc_print(Proc *p) {
if (p->start) {
printf("\t\"%.*s\" -> %d\n", (int)p->name.n, p->name.s, p->start->id);
- node_print(p->start);
+ node_print(p->start, p);
if (no_opt) {
for (NameBinding *b = p->scope.free_bind; b; b = b->prev) {
uint64_t id = (uintptr_t)b->node;
diff --git a/strio.h b/strio.h
index 0d6c6f5..51cf44a 100644
--- a/strio.h
+++ b/strio.h
@@ -2,6 +2,7 @@
#define STRIO_H
#include <stdarg.h>
+#include <stdio.h>
#include "str.h"
#include "arena.h"
diff --git a/test.lang b/test.lang
index a1c1183..0bfcd0a 100644
--- a/test.lang
+++ b/test.lang
@@ -8,6 +8,6 @@
// also single-line now
-proc main {
- return (5 > 4) = (4 < 5)
+proc main(a i64, b i64) {
+ return (4 + (3 * (a + 2 + b + 4) * 5) + 8) < 3
}