summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ir.c14
-rw-r--r--ir.h6
-rw-r--r--lex.h5
-rw-r--r--main.c119
-rw-r--r--test.lang8
5 files changed, 98 insertions, 54 deletions
diff --git a/ir.c b/ir.c
index bd3354b..d527bfe 100644
--- a/ir.c
+++ b/ir.c
@@ -60,6 +60,7 @@ Str type_desc(Type *t, Arena *arena) {
case T_TUPLE: return S("tuple");
case T_BOOL: return S("bool");
case T_INT: return S("i64");
+ case T_PTR: return str_fmt(arena, "^%S", type_desc(t->next, arena));
default: return S("N/A");
}
}
@@ -712,12 +713,12 @@ zero_no_effect: if (node_eql_i64(CAR(n), 0)) return CDR(n);
case N_PHI:
if (same) return CAR(n);
- if (IN(CTRL(n), 0)->val.type.lvl == T_XCTRL) {
- return CDR(n);
- }
if (IN(CTRL(n), 1)->val.type.lvl == T_XCTRL) {
return CAR(n);
}
+ if (IN(CTRL(n), 0)->val.type.lvl == T_XCTRL) {
+ return CDR(n);
+ }
break;
default:
@@ -854,12 +855,11 @@ NameBinding *scope_bind(Scope *scope, Str name, Node *value, LexSpan pos, Proc *
return prev;
}
-NameBinding *scope_update(Scope *scope, Str name, Node *to, Proc *proc) {
- NameBinding *b = scope_find(scope, name);
- if (!b) return NULL;
+NameBinding *scope_update(NameBinding *b, Node *to, Proc *proc) {
+ Node *n = b->node;
node_add(proc, to, proc->keepalive);
- node_remove(proc, b->node, proc->keepalive);
b->node = to;
+ node_remove(proc, n, proc->keepalive);
return b;
}
diff --git a/ir.h b/ir.h
index c7fa118..cba0bcc 100644
--- a/ir.h
+++ b/ir.h
@@ -20,7 +20,8 @@ typedef enum {
T_NONE,
T_TUPLE,
T_BOOL,
- T_INT
+ T_INT,
+ T_PTR
} BaseType;
typedef struct Type {
@@ -114,6 +115,7 @@ typedef struct {
Str name;
Node *start, *stop, *ctrl, *keepalive;
Node *free_list;
+ Type ret_type;
Scope scope;
} Proc;
@@ -152,7 +154,7 @@ ScopeFrame *scope_push(Scope *scope, Proc *proc);
ScopeFrame *scope_pop(Scope *scope, Proc *proc);
NameBinding *scope_find(Scope *scope, Str name);
NameBinding *scope_bind(Scope *scope, Str name, Node *value, LexSpan pos, Proc *proc);
-NameBinding *scope_update(Scope *scope, Str name, Node *to, Proc *proc);
+NameBinding *scope_update(NameBinding *b, Node *to, Proc *proc);
void scope_collect(Scope *scope, Proc *proc, ScopeNameList *nl, Arena *arena);
void scope_uncollect(Scope *scope, Proc *proc, ScopeNameList *nl);
diff --git a/lex.h b/lex.h
index dc23296..e80edcb 100644
--- a/lex.h
+++ b/lex.h
@@ -7,6 +7,7 @@
X(EOF, "end-of-file")\
X(IDENT, "identifier")\
X(PROC, "proc")\
+ X(FUNC, "func")\
X(LET, "let")\
X(VAR, "var")\
X(CONST, "const")\
@@ -38,6 +39,7 @@
X(GTR, ">")\
X(LTE, "<=")\
X(GTE, ">=")\
+ X(DEREF, "^")\
X(LIT_STR, "string")\
X(LIT_CHAR, "character")\
X(LIT_NUM, "number")
@@ -55,7 +57,8 @@
X(TOK_COMMA, ',')\
X(TOK_NOT, '~')\
X(TOK_AND, '&')\
- X(TOK_OR, '|')
+ X(TOK_OR, '|')\
+ X(TOK_DEREF, '^')
typedef enum {
diff --git a/main.c b/main.c
index 8265aab..2341c8f 100644
--- a/main.c
+++ b/main.c
@@ -35,7 +35,19 @@ Node *parse_expr(Lexer *l, Proc *p);
void parse_return(Lexer *l, Proc *p) {
lex_next(l);
- Node *n = node_new(p, N_RETURN, p->ctrl, parse_expr(l, p));
+ Node *n;
+ if (p->ret_type.t == T_NONE) {
+ n = node_new(p, N_RETURN, p->ctrl);
+ } else {
+ Node *e = parse_expr(l, p);
+ if (!type_base_eql(&e->val.type, &p->ret_type)) {
+ lex_error_at(l, e->src_pos, LE_ERROR,
+ str_fmt(&p->arena, "incorrect return type (expected %S, got %S)",
+ type_desc(&p->ret_type, &p->arena),
+ type_desc(&e->val.type, &p->arena)));
+ }
+ n = node_new(p, N_RETURN, p->ctrl, e);
+ }
node_add(p, n, p->stop);
p->ctrl = n;
}
@@ -79,36 +91,48 @@ void parse_assign(Lexer *l, Proc *p) {
LexSpan pos = l->pos;
lex_expect(l, TM_ASSIGN);
lex_next(l);
- Node *e = parse_expr(l, p);
- if (!scope_update(&p->scope, name, e, p)) {
+ NameBinding *b = scope_find(&p->scope, name);
+ if (!b) {
lex_error_at(l, pos, LE_ERROR, S("undeclared identifier"));
}
+ Node *e = parse_expr(l, p);
+ if (!type_base_eql(&e->val.type, &b->node->val.type)) {
+ lex_error_at(l, pos, LE_ERROR,
+ str_fmt(&p->arena, "tried to assign value of type %S to variable of type %S",
+ type_desc(&e->val.type, &p->arena),
+ type_desc(&b->node->val.type, &p->arena)));
+ }
+ scope_update(b, e, p);
}
-void merge_scope(Lexer *l, Proc *p, ScopeNameList *na, ScopeNameList *nb) {
- for (ScopeFrame *f = p->scope.tail; f; f = f->prev) {
- for (NameBinding *b = f->latest; b; b = b->prev) {
- int i, j;
- for (i = 0; i < na->len && !str_eql(na->data[i].name, b->name); i++);
- for (j = 0; j < nb->len && !str_eql(nb->data[j].name, b->name); j++);
- if (i >= na->len && j >= nb->len) continue; /* no change */
- Node *phi;
- if (i >= na->len) {
- if (nb->data[j].node == b->node) continue;
- phi = node_new(p, N_PHI, p->ctrl, b->node, nb->data[j].node);
- } else if (j >= na->len) {
- if (na->data[i].node == b->node) continue;
- phi = node_new(p, N_PHI, p->ctrl, b->node, na->data[i].node);
- } else {
- if (na->data[i].node == b->node && nb->data[j].node == b->node) continue;
- phi = node_new(p, N_PHI, p->ctrl, na->data[i].node, nb->data[j].node);
- }
- node_remove(p, b->node, p->keepalive);
- phi = node_peephole(phi, p, l);
- node_add(p, phi, p->keepalive);
- b->node = phi;
+void merge_scope(Lexer *l, Proc *p, ScopeNameList *before, ScopeNameList *ntrue, ScopeNameList *nfalse) {
+ node_add(p, p->ctrl, p->keepalive);
+ for (int i = 0; i < before->len; i++) {
+ int j, k;
+ ScopeName *b4 = &before->data[i];
+ for (j = 0; j < ntrue->len && !str_eql(ntrue->data[j].name, b4->name); j++);
+ for (k = 0; k < nfalse->len && !str_eql(nfalse->data[k].name, b4->name); k++);
+ ScopeName *yes = j < ntrue->len ? &ntrue->data[j] : NULL;
+ ScopeName *no = k < nfalse->len ? &nfalse->data[k] : NULL;
+ if (!yes && !no) continue; /* no change */
+ Node *phi;
+ if (!no) {
+ if (yes->node == b4->node) continue;
+ phi = node_new(p, N_PHI, p->ctrl, yes->node, b4->node);
+ } else if (!yes) {
+ if (no->node == b4->node) continue;
+ phi = node_new(p, N_PHI, p->ctrl, b4->node, no->node);
+ } else {
+ if (yes->node == b4->node && no->node == b4->node) continue;
+ phi = node_new(p, N_PHI, p->ctrl, yes->node, no->node);
}
+ fprintf(stderr, "phi('%.*s', %d, %d)\n", (int)b4->name.n, b4->name.s, phi->in.data[1]->id, phi->in.data[2]->id);
+ phi = node_peephole(phi, p, l);
+ NameBinding *b = scope_find(&p->scope, b4->name);
+ assert(b);
+ scope_update(b, phi, p);
}
+ node_remove(p, p->ctrl, p->keepalive);
}
void parse_if(Lexer *l, Proc *p) {
@@ -120,8 +144,8 @@ void parse_if(Lexer *l, Proc *p) {
.type = { T_TOP, T_TUPLE, NULL },
.tuple = { 0 }
};
- ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { T_CTRL, T_NONE, NULL } });
- ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { T_CTRL, T_NONE, NULL } });
+ ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { .lvl = T_CTRL, .t = T_NONE } });
+ ZDA_PUSH(&p->arena, &if_node->val.tuple, (Value) { .type = { .lvl = T_CTRL, .t = T_NONE } });
if_node = node_peephole(if_node, p, l);
Node *if_true = node_peephole(node_new(p, N_PROJ, if_node), p, l);
Node *if_false = node_peephole(node_new(p, N_PROJ, if_node), p, l);
@@ -149,15 +173,10 @@ void parse_if(Lexer *l, Proc *p) {
});
if (ctrl_else) {
p->ctrl = node_peephole(node_new(p, N_REGION, ctrl_if, ctrl_else), p, l);
- node_add(p, p->ctrl, p->keepalive);
- merge_scope(l, p, &scope_true, &scope_false);
- node_remove(p, p->ctrl, p->keepalive);
} else {
p->ctrl = node_peephole(node_new(p, N_REGION, ctrl_if, if_false), p, l);
- node_add(p, p->ctrl, p->keepalive);
- merge_scope(l, p, &scope_true, &scope_before);
- node_remove(p, p->ctrl, p->keepalive);
}
+ merge_scope(l, p, &scope_before, &scope_true, &scope_false);
scope_uncollect(&p->scope, p, &scope_true);
scope_uncollect(&p->scope, p, &scope_false);
scope_uncollect(&p->scope, p, &scope_before);
@@ -191,10 +210,19 @@ void parse_stmt(Lexer *l, Proc *p) {
Type parse_type(Lexer *l, Proc *proc) {
(void)proc;
Type t = { .lvl = T_BOT };
- if (str_eql(l->ident, S("i64"))) {
- t.t = T_INT;
- } else if (str_eql(l->ident, S("bool"))) {
- t.t = T_BOOL;
+ if (l->tok == TOK_DEREF) {
+ lex_next(l);
+ t.t = T_PTR;
+ t.next = new(&proc->arena, Type);
+ *t.next = parse_type(l, proc);
+ return t;
+ }
+ if (l->tok == TOK_IDENT) {
+ if (str_eql(l->ident, S("i64"))) {
+ t.t = T_INT;
+ } else if (str_eql(l->ident, S("bool"))) {
+ t.t = T_BOOL;
+ }
} else {
lex_error(l, LE_ERROR, S("unknown type"));
}
@@ -219,7 +247,7 @@ void parse_args_list(Lexer *l, Proc *proc) {
idbuf[id].name = l->ident;
idbuf[id].pos = l->pos;
id++;
- lex_expect(l, TM_IDENT | TM_COMMA);
+ lex_next(l);
if (l->tok == TOK_COMMA) continue;
Value v = (Value) { .type = parse_type(l, proc) };
lex_expected(l, TM_RPAREN | TM_COMMA);
@@ -237,6 +265,7 @@ void parse_args_list(Lexer *l, Proc *proc) {
}
Proc *parse_proc(Lexer *l, Unit *u) {
+ int has_ret = l->tok == TOK_FUNC;
DA_FIT(&u->procs, u->procs.len + 1);
Proc *proc = &u->procs.data[u->procs.len++];
lex_expect(l, TM_IDENT);
@@ -246,6 +275,11 @@ Proc *parse_proc(Lexer *l, Unit *u) {
if (l->tok == TOK_LPAREN) {
parse_args_list(l, proc);
}
+ if (has_ret) {
+ proc->ret_type = parse_type(l, proc);
+ } else {
+ proc->ret_type = (Type) { .lvl = T_BOT, .t = T_NONE };
+ }
lex_expected(l, TM_LBRACE);
lex_next(l);
while (l->tok != TOK_RBRACE) {
@@ -339,6 +373,7 @@ Node *parse_expr(Lexer *l, Proc *p) {
LexSpan pos = l->pos;
Node *lhs = parse_term(l, p);
NodeType nt = tok_to_bin_op(l->tok);;
+ if (lhs->refs <= 0) lex_error(l, LE_ERROR, S("dead lhs"));
assert(lhs->refs > 0);
if (nt != N_START) {
lex_next(l);
@@ -357,10 +392,11 @@ Node *parse_expr(Lexer *l, Proc *p) {
void parse_toplevel(Lexer *l, Unit *u) {
switch (l->tok) {
case TOK_PROC:
+ case TOK_FUNC:
parse_proc(l, u);
break;
default:
- lex_expected(l, TM_PROC);
+ lex_expected(l, TM_PROC | TM_FUNC);
break;
}
}
@@ -404,7 +440,7 @@ void node_print(Node *n, Proc *p) {
}
} else if (n->type == N_PROJ) {
Str d = type_desc(&n->in.data[0]->val.tuple.data[n->val.i].type, &p->arena);
- printf("\t%d [label=\"%ld | %.*s\", shape=record]", n->id, n->val.i, (int)d.n, d.s);
+ printf("\t%d [label=\"%.*s(%ld)\", shape=record]", n->id, (int)d.n, d.s, n->val.i);
} else {
printf("\t%d [label=\"%s\", shape=record]", n->id, node_type_name(n->type));
}
@@ -431,7 +467,8 @@ void node_print(Node *n, Proc *p) {
void proc_print(Proc *p) {
if (p->start) {
- printf("\t\"%.*s\" -> %d\n", (int)p->name.n, p->name.s, p->start->id);
+ Str d = type_desc(&p->ret_type, &p->arena);
+ printf("\t\"%.*s %.*s\" -> %d\n", (int)p->name.n, p->name.s, (int)d.n, d.s, p->start->id);
node_print(p->start, p);
if (no_opt) {
for (NameBinding *b = p->scope.free_bind; b; b = b->prev) {
diff --git a/test.lang b/test.lang
index 1f33175..24a931d 100644
--- a/test.lang
+++ b/test.lang
@@ -1,6 +1,8 @@
-proc main(a, b i64) {
+func main(a, b i64) i64 {
if a = b {
- b := 3
+ let t = a
+ a := b
+ b := t
}
- return a
+ return a + b
}