From a49e785e8536acd6d5ff2c6bebf8d9902d2f3620 Mon Sep 17 00:00:00 2001 From: WormHeamer Date: Sun, 3 Aug 2025 21:51:29 -0400 Subject: add booleans and comparison operators --- ir.c | 54 ++++++++++++++++++++++++++++++++++++----- ir.h | 8 ++++++- lex.h | 2 ++ main.c | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++------- test.lang | 8 +------ 5 files changed, 132 insertions(+), 22 deletions(-) diff --git a/ir.c b/ir.c index 0ef19b9..8a2baf5 100644 --- a/ir.c +++ b/ir.c @@ -7,10 +7,17 @@ extern int no_opt; int type_eql(Type *a, Type *b) { if (a->t != b->t) return 0; + if (a->lvl != b->lvl) return 0; if (a->next != b->next) return 0; return a->next ? type_eql(a->next, b->next) : 1; } +int type_base_eql(Type *a, Type *b) { + if (a->t != b->t) return 0; + if (a->next != b->next) return 0; + return a->next ? type_base_eql(a->next, b->next) : 1; +} + /* nodes */ const char *node_type_name(NodeType t) { @@ -24,14 +31,20 @@ const char *node_type_name(NodeType t) { "and", "or", "xor", "lshift", "rshift", "neg", "not", + "equal", + "not-equal", + "less", + "greater", + "less-or-equal", + "greater-or-equal", "value" }; return names[t]; } void node_die(Node *n, Proc *p) { - n->prev_free = p->free_list; - p->free_list = n; + /*n->prev_free = p->free_list; + p->free_list = n;*/ } void node_del_out(Node *n, Node *p) { @@ -135,8 +148,7 @@ Node *node_dedup_lit(Proc *p, Value v) { return NULL; } -Node *node_new_lit_i64(Proc *p, int64_t i) { - Value v = (Value) { { .lvl = T_CONST, .t = T_INT }, { .i = i } }; +Node *node_new_lit(Proc *p, Value v) { Node *t = node_dedup_lit(p, v); if (t) return t; Node *n = node_new(p, N_LIT, p->start); @@ -144,6 +156,14 @@ Node *node_new_lit_i64(Proc *p, int64_t i) { return n; } +Node *node_new_lit_i64(Proc *p, int64_t i) { + return node_new_lit(p, (Value) { { .lvl = T_CONST, .t = T_INT }, { .i = i } }); +} + +Node *node_new_lit_bool(Proc *p, int b) { + return node_new_lit(p, (Value) { { .lvl = T_CONST, .t = T_BOOL }, { .i = b } }); +} + static inline int node_op_communative(NodeType t) { NodeType ops[] = { N_OP_ADD, N_OP_MUL, N_OP_AND, N_OP_XOR, N_OP_OR }; for (unsigned i = 0; i < sizeof ops / sizeof *ops; i++) { @@ -170,8 +190,12 @@ Value node_compute(Node *n, Lexer *l) { } } } - if (lit_type.lvl == T_CONST && lit_type.t == T_INT) { - Value v = { .type = lit_type }; + + if (lit_type.lvl != T_CONST) return n->val; + + Value v = { .type = lit_type }; + + if (lit_type.t == T_INT) { switch (n->type) { case N_OP_NEG: v.i = -in[0]->val.i; break; case N_OP_NOT: v.i = ~in[0]->val.i; break; @@ -189,10 +213,28 @@ Value node_compute(Node *n, Lexer *l) { case N_OP_XOR: v.i = in[0]->val.i ^ in[1]->val.i; break; case N_OP_SHL: v.i = in[0]->val.u << in[1]->val.u; break; case N_OP_SHR: v.i = in[0]->val.u >> in[1]->val.u; break; + case N_CMP_EQL: v.type.t = T_BOOL; v.i = in[0]->val.i == in[1]->val.i; break; + case N_CMP_NEQ: v.type.t = T_BOOL; v.i = in[0]->val.i != in[1]->val.i; break; + case N_CMP_LES: v.type.t = T_BOOL; v.i = in[0]->val.i < in[1]->val.i; break; + case N_CMP_GTR: v.type.t = T_BOOL; v.i = in[0]->val.i > in[1]->val.i; break; + case N_CMP_LTE: v.type.t = T_BOOL; v.i = in[0]->val.i <= in[1]->val.i; break; + case N_CMP_GTE: v.type.t = T_BOOL; v.i = in[0]->val.i >= in[1]->val.i; break; + default: return n->val; + } + return v; + } else if (lit_type.t == T_BOOL) { + switch (n->type) { + case N_OP_NOT: v.i = !in[0]->val.i; break; + case N_CMP_EQL: v.i = in[0]->val.i == in[1]->val.i; break; + case N_CMP_NEQ: v.i = in[0]->val.i != in[1]->val.i; break; + case N_OP_AND: v.i = in[0]->val.i && in[1]->val.i; break; + case N_OP_OR: v.i = in[0]->val.i || in[1]->val.i; break; + case N_OP_XOR: v.i = in[0]->val.i ^ in[1]->val.i; break; default: return n->val; } return v; } + return n->val; } diff --git a/ir.h b/ir.h index 72586f2..432c96c 100644 --- a/ir.h +++ b/ir.h @@ -16,6 +16,7 @@ typedef enum { typedef enum { T_TUPLE, + T_BOOL, T_INT } BaseType; @@ -26,6 +27,7 @@ typedef struct Type { } Type; int type_eql(Type *a, Type *b); +int type_base_eql(Type *a, Type *b); /* nodes */ @@ -39,6 +41,7 @@ typedef enum { N_OP_AND, N_OP_OR, N_OP_XOR, N_OP_SHL, N_OP_SHR, N_OP_NEG, N_OP_NOT, + N_CMP_EQL, N_CMP_NEQ, N_CMP_LES, N_CMP_GTR, N_CMP_LTE, N_CMP_GTE, N_VALUE } NodeType; @@ -103,10 +106,13 @@ void node_remove(Proc *p, Node *src, Node *dest); Node *node_new_empty(Proc *p, NodeType t); Node *node_newv(Proc *p, NodeType t, ...); Node *node_dedup_lit(Proc *p, Value v); -Node *node_new_lit_i64(Proc *p, int64_t i); Value node_compute(Node *n, Lexer *l); Node *node_peephole(Node *n, Proc *p, Lexer *l); +Node *node_new_lit(Proc *p, Value v); +Node *node_new_lit_bool(Proc *p, int b); +Node *node_new_lit_i64(Proc *p, int64_t i); + #define node_new(...) node_newv(__VA_ARGS__, NULL) void proc_init(Proc *proc, Str name); diff --git a/lex.h b/lex.h index a83905c..a9820a1 100644 --- a/lex.h +++ b/lex.h @@ -11,6 +11,8 @@ X(VAR, "var")\ X(CONST, "const")\ X(RETURN, "return")\ + X(TRUE, "true")\ + X(FALSE, "false")\ X(LBRACE, "{")\ X(RBRACE, "}")\ X(LPAREN, "(")\ diff --git a/main.c b/main.c index ee759dd..f0af5e0 100644 --- a/main.c +++ b/main.c @@ -11,7 +11,7 @@ #include "dynarr.h" #include "ir.h" -int no_opt = 1; +int no_opt = 0; typedef struct { DYNARR(Proc) procs; @@ -119,16 +119,58 @@ Proc *parse_proc(Lexer *l, Unit *u) { return proc; } +int type_check(Node *n) { + fprintf(stderr, "::\n"); + for (int i = 0; i < n->in.len; i++) { + fprintf(stderr, "%d: %d/%d\n", i, + n->in.data[i]->val.type.lvl, + n->in.data[i]->val.type.t); + } + switch (n->type) { + case N_OP_NEG: + n->val.type = (Type) { .lvl = T_TOP, .t = T_INT }; + return n->in.data[0]->val.type.t == T_INT; + case N_OP_NOT: + n->val.type = (Type) { .lvl = T_TOP, .t = n->in.data[0]->val.type.t }; + return n->in.data[0]->val.type.t == T_INT || n->in.data[0]->val.type.t == T_BOOL; + case N_OP_ADD: case N_OP_SUB: case N_OP_MUL: case N_OP_DIV: + case N_OP_AND: case N_OP_OR: case N_OP_XOR: + case N_OP_SHL: case N_OP_SHR: + n->val.type = (Type) { .lvl = T_TOP, .t = T_INT }; + return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT; + case N_CMP_LES: case N_CMP_GTR: + case N_CMP_LTE: case N_CMP_GTE: + n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL }; + return n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT; + case N_CMP_EQL: + case N_CMP_NEQ: + n->val.type = (Type) { .lvl = T_TOP, .t = T_BOOL }; + return (n->in.data[0]->val.type.t == T_INT && n->in.data[1]->val.type.t == T_INT) + || (n->in.data[0]->val.type.t == T_BOOL && n->in.data[1]->val.type.t == T_BOOL); + default: + return 1; + } +} + Node *parse_term(Lexer *l, Proc *p) { Node *node = NULL; NodeType op_after = N_START; if (TMASK(l->tok) & (TM_MINUS | TM_PLUS | TM_NOT)) { - switch (l->tok) { - case TOK_MINUS: op_after = N_OP_NEG; break; - case TOK_NOT: op_after = N_OP_NOT; break; - default: break; - } + Token t = l->tok; lex_next(l); + node = parse_term(l, p); + NodeType post_op = N_START; + switch (t) { + case TOK_MINUS: post_op = N_OP_NEG; break; + case TOK_NOT: post_op = N_OP_NOT; break; + default: return node; + } + if (post_op == N_START) return node; + Node *r = node_new(p, post_op, node); + if (!type_check(r)) { + lex_error_at(l, r->src_pos, LE_ERROR, S("type mismatch")); + } + return node_peephole(r, p, l); } if (l->tok == TOK_LPAREN) { lex_next(l); @@ -143,6 +185,9 @@ Node *parse_term(Lexer *l, Proc *p) { lex_error(l, LE_ERROR, S("undeclared identifier")); } lex_next(l); + } else if (TMASK(l->tok) & (TM_TRUE | TM_FALSE)) { + node = node_new_lit_bool(p, l->tok == TOK_TRUE); + lex_next(l); } else { lex_expected(l, TM_LIT_NUM); int64_t val = 0; @@ -167,7 +212,9 @@ Node *parse_term(Lexer *l, Proc *p) { Node *parse_expr(Lexer *l, Proc *p) { LexSpan pos = l->pos; Node *lhs = parse_term(l, p); - if (TMASK(l->tok) & (TM_PLUS | TM_MINUS | TM_ASTERISK | TM_SLASH | TM_NOT | TM_AND | TM_XOR | TM_OR | TM_SHL | TM_SHR)) { + if (TMASK(l->tok) & (TM_PLUS | TM_MINUS | TM_ASTERISK | TM_SLASH + | TM_NOT | TM_AND | TM_XOR | TM_OR | TM_SHL | TM_SHR + | TM_EQL | TM_NEQ | TM_LES | TM_GTR | TM_LTE | TM_GTE)) { Token t = l->tok; lex_next(l); Node *rhs = parse_expr(l, p); @@ -183,12 +230,21 @@ Node *parse_expr(Lexer *l, Proc *p) { case TOK_XOR: nt = N_OP_XOR; break; case TOK_SHL: nt = N_OP_SHL; break; case TOK_SHR: nt = N_OP_SHR; break; + case TOK_EQL: nt = N_CMP_EQL; break; + case TOK_NEQ: nt = N_CMP_NEQ; break; + case TOK_LES: nt = N_CMP_LES; break; + case TOK_GTR: nt = N_CMP_GTR; break; + case TOK_LTE: nt = N_CMP_LTE; break; + case TOK_GTE: nt = N_CMP_GTE; break; default: break; } lhs = node_new(p, nt, lhs, rhs); } Node *n = node_peephole(lhs, p, l); n->src_pos = (LexSpan) { pos.ofs, l->pos.ofs - pos.ofs }; + if (!type_check(n)) { + lex_error_at(l, n->src_pos, LE_ERROR, S("type mismatch")); + } return n; } @@ -219,7 +275,17 @@ void parse_unit(Lexer *l) { void node_print(Node *n) { if (n->type == N_LIT) { - printf("\t%d [label=\"%ld\"]\n", n->id, n->val.i); + switch (n->val.type.t) { + case T_INT: + printf("\t%d [label=\"%ld\"]\n", n->id, n->val.i); + break; + case T_BOOL: + printf("\t%d [label=\"%s\"]\n", n->id, n->val.i ? "true" : "false"); + break; + default: + printf("\t%d [label=\"literal %d\"]\n", n->id, n->id); + break; + } } else { printf("\t%d [label=\"%s\", shape=record]\n", n->id, node_type_name(n->type)); } diff --git a/test.lang b/test.lang index a0a325f..a1c1183 100644 --- a/test.lang +++ b/test.lang @@ -9,11 +9,5 @@ // also single-line now proc main { - let a = 3 << 2 - let b = 1 << a - let c = b xor 2381 - b := b + c - a := a + 12 - b := b + 27 - return (a + b) * c + return (5 > 4) = (4 < 5) } -- cgit v1.2.3