diff --git a/compare_compilers.sh b/compare_compilers.sh index 25a48289..41c20604 100755 --- a/compare_compilers.sh +++ b/compare_compilers.sh @@ -97,7 +97,7 @@ for action in tokenize parse run; do # The file is skipped, so the two compilers should behave differently if diff tmp/compare_compilers/compiler_written_in_c.txt tmp/compare_compilers/self_hosted.txt >/dev/null; then if [ $fix = yes ]; then - delete_line $error_list_file $file + remove_line $error_list_file $file else echo " Error: Compilers behave the same even though the file is listed in $error_list_file." echo " To fix this error, delete the \"$file\" line from $error_list_file (or run again with --fix)." diff --git a/self_hosted/ast.jou b/self_hosted/ast.jou index 35602473..42649b62 100644 --- a/self_hosted/ast.jou +++ b/self_hosted/ast.jou @@ -275,8 +275,7 @@ class AstCall: enum AstStatementKind: ExpressionStatement # Evaluate an expression. Discard the result. - ReturnWithValue - ReturnWithoutValue + Return If WhileLoop ForLoop @@ -299,7 +298,7 @@ class AstStatement: if_statement: AstIfStatement while_loop: AstConditionAndBody for_loop: AstForLoop - return_value: AstExpression # AstStatementKind::ReturnWithValue + return_value: AstExpression* # AstStatementKind::Return (can be NULL) assignment: AstAssignment var_declaration: AstNameTypeValue # AstStatementKind::DeclareLocalVar @@ -308,11 +307,10 @@ class AstStatement: if self->kind == AstStatementKind::ExpressionStatement: printf("expression statement\n") self->expression.print(tp.print_prefix(True)) - elif self->kind == AstStatementKind::ReturnWithValue: - printf("return a value\n") - self->return_value.print(tp.print_prefix(True)) - elif self->kind == AstStatementKind::ReturnWithoutValue: - printf("return without a value\n") + elif self->kind == AstStatementKind::Return: + printf("return\n") + if self->return_value != NULL: + self->return_value->print(tp.print_prefix(True)) elif self->kind == AstStatementKind::If: printf("if\n") self->if_statement.print(tp) @@ -353,8 +351,9 @@ class AstStatement: def free(self) -> void: if self->kind == AstStatementKind::ExpressionStatement: self->expression.free() - if self->kind == AstStatementKind::ReturnWithValue: - self->return_value.free() + if self->kind == AstStatementKind::Return and self->return_value != NULL: + self->return_value->free() + free(self->return_value) if self->kind == AstStatementKind::If: self->if_statement.free() if self->kind == AstStatementKind::ForLoop: @@ -516,16 +515,14 @@ class AstImport: enum AstToplevelStatementKind: Import - FunctionDeclaration - FunctionDefinition + Function ClassDefinition GlobalVariableDeclaration class AstToplevelStatement: # TODO: union the_import: AstImport # must be first - decl_signature: AstSignature - funcdef: AstFunctionDef + function: AstFunction classdef: AstClassDef global_var: AstNameTypeValue @@ -540,12 +537,12 @@ class AstToplevelStatement: self->the_import.specified_path, self->the_import.resolved_path, ) - elif self->kind == AstToplevelStatementKind::FunctionDeclaration: - printf("Declare a function: ") - self->decl_signature.print() - elif self->kind == AstToplevelStatementKind::FunctionDefinition: - printf("Define a function: ") - self->funcdef.print() + elif self->kind == AstToplevelStatementKind::Function: + if self->function.body.nstatements == 0: + printf("Declare a function: ") + else: + printf("Define a function: ") + self->function.print() elif self->kind == AstToplevelStatementKind::ClassDefinition: printf("Define a ") self->classdef.print() @@ -596,9 +593,9 @@ class AstFile: self->body[i].free() free(self->body) -class AstFunctionDef: +class AstFunction: signature: AstSignature - body: AstBody + body: AstBody # empty body means declaration, otherwise it's a definition def print(self) -> void: self->signature.print() @@ -613,7 +610,7 @@ class AstClassDef: name_location: Location fields: AstNameTypeValue* nfields: int - methods: AstFunctionDef* + methods: AstFunction* nmethods: int def print(self) -> void: diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index 826f92ad..7b257906 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -84,9 +84,12 @@ class AstToIR: def do_statement(self, ast: AstStatement*) -> void: if ast->kind == AstStatementKind::ExpressionStatement: self->do_expression(&ast->expression) - elif ast->kind == AstStatementKind::ReturnWithValue: - return_value = self->do_expression(&ast->return_value) - LLVMBuildRet(self->builder, return_value) + elif ast->kind == AstStatementKind::Return: + if ast->return_value != NULL: + return_value = self->do_expression(ast->return_value) + LLVMBuildRet(self->builder, return_value) + else: + LLVMBuildRetVoid(self->builder) # If more code follows, place it into a new block that never actually runs self->new_block("after_return") else: @@ -98,14 +101,17 @@ class AstToIR: self->do_statement(&body->statements[i]) # The function must already be declared. - def define_function(self, funcdef: AstFunctionDef*) -> void: + def define_function(self, funcdef: AstFunction*) -> void: llvm_func = LLVMGetNamedFunction(self->module, &funcdef->signature.name[0]) assert(llvm_func != NULL) assert(self->current_function == NULL) self->current_function = llvm_func + self->new_block("start") + assert(funcdef->body.nstatements > 0) # it is a definition self->do_body(&funcdef->body) LLVMBuildUnreachable(self->builder) + self->current_function = NULL @@ -124,8 +130,8 @@ def create_llvm_ir(ast: AstFile*, typectx: TypeContext*) -> LLVMModule*: a2i.declare_function(&typectx->functions[i]) for i = 0; i < ast->body_len; i++: - if ast->body[i].kind == AstToplevelStatementKind::FunctionDefinition: - a2i.define_function(&ast->body[i].funcdef) + if ast->body[i].kind == AstToplevelStatementKind::Function and ast->body[i].function.body.nstatements > 0: + a2i.define_function(&ast->body[i].function) LLVMDisposeBuilder(a2i.builder) return module diff --git a/self_hosted/parser.jou b/self_hosted/parser.jou index 1fad79be..d266d7cd 100644 --- a/self_hosted/parser.jou +++ b/self_hosted/parser.jou @@ -413,11 +413,10 @@ def parse_oneline_statement(tokens: Token**) -> AstStatement: result = AstStatement{ location = (*tokens)->location } if (*tokens)->is_keyword("return"): ++*tokens - if (*tokens)->kind == TokenKind::Newline: - result.kind = AstStatementKind::ReturnWithoutValue - else: - result.kind = AstStatementKind::ReturnWithValue - result.return_value = parse_expression(tokens) + result.kind = AstStatementKind::Return + if (*tokens)->kind != TokenKind::Newline: + result.return_value = malloc(sizeof *result.return_value) + *result.return_value = parse_expression(tokens) elif (*tokens)->is_keyword("break"): ++*tokens result.kind = AstStatementKind::Break @@ -554,8 +553,8 @@ def parse_body(tokens: Token**) -> AstBody: return AstBody{ statements = result, nstatements = n } -def parse_funcdef(tokens: Token**) -> AstFunctionDef: - return AstFunctionDef{ +def parse_funcdef(tokens: Token**) -> AstFunction: + return AstFunction{ signature = parse_function_signature(tokens), body = parse_body(tokens), } @@ -594,8 +593,8 @@ def parse_toplevel_node(dest: AstFile*, tokens: Token**, stdlib_path: byte*) -> elif (*tokens)->is_keyword("def"): ++*tokens - ts.kind = AstToplevelStatementKind::FunctionDefinition - ts.funcdef = parse_funcdef(tokens) + ts.kind = AstToplevelStatementKind::Function + ts.function = parse_funcdef(tokens) elif (*tokens)->is_keyword("declare"): ++*tokens @@ -609,8 +608,8 @@ def parse_toplevel_node(dest: AstFile*, tokens: Token**, stdlib_path: byte*) -> "a value cannot be given when declaring a global variable", ) else: - ts.kind = AstToplevelStatementKind::FunctionDeclaration - ts.decl_signature = parse_function_signature(tokens) + ts.kind = AstToplevelStatementKind::Function + ts.function.signature = parse_function_signature(tokens) eat_newline(tokens) elif (*tokens)->is_keyword("class"): diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index c99ca56c..a6a1bbbe 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -108,7 +108,6 @@ tests/syntax_error/multidot_float.jou tests/syntax_error/self_outside_class.jou tests/syntax_error/triple_equals.jou tests/too_long/name.jou -tests/wrong_type/arg.jou tests/wrong_type/array_mixed_types.jou tests/wrong_type/array_mixed_types_ptr.jou tests/wrong_type/array_to_ptr.jou diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 5b0efbc8..10db8666 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -130,24 +130,16 @@ def typecheck_stage2_signatures_globals_structbodies(ctx: TypeContext*, ast_file for i = 0; i < ast_file->body_len; i++: ts = &ast_file->body[i] - # TODO: get rid of declare/define copy pasta - if ts->kind == AstToplevelStatementKind::FunctionDeclaration: + if ts->kind == AstToplevelStatementKind::Function: # TODO: terrible hack: skip functions that use FILE, such as fopen() # Will be no longer needed once struct FILE works. - if ts->decl_signature.name[0] == 'f' or strcmp(&ts->decl_signature.name[0], "rewind") == 0: + if ts->function.body.nstatements == 0 and ( + ts->function.signature.name[0] == 'f' + or strcmp(&ts->function.signature.name[0], "rewind") == 0 + ): continue - sig = handle_signature(ctx, &ts->decl_signature) - ctx->functions = realloc(ctx->functions, sizeof ctx->functions[0] * (ctx->nfunctions + 1)) - ctx->functions[ctx->nfunctions++] = sig.copy() - exports = realloc(exports, sizeof exports[0] * (nexports + 1)) - exports[nexports++] = ExportSymbol{ - kind = ExportSymbolKind::Function, - name = sig.name, - signature = sig, - } - if ts->kind == AstToplevelStatementKind::FunctionDefinition: - sig = handle_signature(ctx, &ts->funcdef.signature) + sig = handle_signature(ctx, &ts->function.signature) ctx->functions = realloc(ctx->functions, sizeof ctx->functions[0] * (ctx->nfunctions + 1)) ctx->functions[ctx->nfunctions++] = sig.copy() exports = realloc(exports, sizeof exports[0] * (nexports + 1)) @@ -324,26 +316,37 @@ def typecheck_statement(ctx: TypeContext*, statement: AstStatement*) -> void: if statement->kind == AstStatementKind::ExpressionStatement: typecheck_expression_maybe_void(ctx, &statement->expression) - elif statement->kind == AstStatementKind::ReturnWithValue: + elif statement->kind == AstStatementKind::Return: name = &ctx->current_function_signature->name[0] - if ctx->current_function_signature->return_type == NULL: - msg: byte[500] + # TODO: check for noreturn functions + + return_type = ctx->current_function_signature->return_type + msg: byte[500] + + if statement->return_value != NULL and return_type == NULL: snprintf(&msg[0], sizeof msg, "function '%s' cannot return a value because it was defined with '-> void'", name) fail(statement->location, name) + if statement->return_value == NULL and return_type != NULL: + snprintf( + &msg[0], sizeof msg, + "function '%s' must return a value because it was defined with '-> %s'", + name, &return_type->name[0]) + fail(statement->location, name) - cast_error_msg: byte[500] - snprintf( - &cast_error_msg[0], sizeof cast_error_msg, - "attempting to return a value of type from function '%s' defined with '-> '", - name, - ) - typecheck_expression_with_implicit_cast( - ctx, - &statement->return_value, - ctx->current_function_signature->return_type, - &cast_error_msg[0], - ) + if statement->return_value != NULL: + cast_error_msg: byte[500] + snprintf( + &cast_error_msg[0], sizeof cast_error_msg, + "attempting to return a value of type from function '%s' defined with '-> '", + name, + ) + typecheck_expression_with_implicit_cast( + ctx, + statement->return_value, + return_type, + &cast_error_msg[0], + ) else: assert(False) @@ -355,13 +358,13 @@ def typecheck_body(ctx: TypeContext*, body: AstBody*) -> void: def typecheck_stage3_function_and_method_bodies(ctx: TypeContext*, ast_file: AstFile*) -> void: for i = 0; i < ast_file->body_len; i++: ts = &ast_file->body[i] - if ts->kind != AstToplevelStatementKind::FunctionDefinition: + if ts->kind != AstToplevelStatementKind::Function or ts->function.body.nstatements == 0: continue - sig = ctx->find_function(&ts->funcdef.signature.name[0]) + sig = ctx->find_function(&ts->function.signature.name[0]) assert(sig != NULL) assert(ctx->current_function_signature == NULL) ctx->current_function_signature = sig - typecheck_body(ctx, &ts->funcdef.body) + typecheck_body(ctx, &ts->function.body) ctx->current_function_signature = NULL diff --git a/src/build_cfg.c b/src/build_cfg.c index f9785e38..e61f43a5 100644 --- a/src/build_cfg.c +++ b/src/build_cfg.c @@ -830,15 +830,13 @@ static void build_statement(struct State *st, const AstStatement *stmt) break; } - case AST_STMT_RETURN_VALUE: - { - const LocalVariable *retvalue = build_expression(st, &stmt->data.expression); - const LocalVariable *retvariable = find_local_var(st, "return"); - assert(retvariable); - add_unary_op(st, stmt->location, CF_VARCPY, retvalue, retvariable); - } - __attribute__((fallthrough)); - case AST_STMT_RETURN_WITHOUT_VALUE: + case AST_STMT_RETURN: + if (stmt->data.returnvalue) { + const LocalVariable *retvalue = build_expression(st, stmt->data.returnvalue); + const LocalVariable *retvariable = find_local_var(st, "return"); + assert(retvariable); + add_unary_op(st, stmt->location, CF_VARCPY, retvalue, retvariable); + } st->current_block->iftrue = &st->cfg->end_block; st->current_block->iffalse = &st->cfg->end_block; st->current_block = add_block(st); // an unreachable block @@ -908,8 +906,8 @@ CfGraphFile build_control_flow_graphs(AstToplevelNode *ast, FileTypes *filetypes struct State st = { .filetypes = filetypes }; while (ast->kind != AST_TOPLEVEL_END_OF_FILE) { - if(ast->kind == AST_TOPLEVEL_DEFINE_FUNCTION) { - CfGraph *g = build_function_or_method(&st, NULL, ast->data.funcdef.signature.name, &ast->data.funcdef.body); + if(ast->kind == AST_TOPLEVEL_FUNCTION && ast->data.function.body.nstatements > 0) { + CfGraph *g = build_function_or_method(&st, NULL, ast->data.function.signature.name, &ast->data.function.body); Append(&result.graphs, g); } @@ -923,7 +921,7 @@ CfGraphFile build_control_flow_graphs(AstToplevelNode *ast, FileTypes *filetypes } assert(classtype); - for (AstFunctionDef *m = ast->data.classdef.methods.ptr; m < End(ast->data.classdef.methods); m++) { + for (AstFunction *m = ast->data.classdef.methods.ptr; m < End(ast->data.classdef.methods); m++) { CfGraph *g = build_function_or_method(&st, classtype, m->signature.name, &m->body); Append(&result.graphs, g); } diff --git a/src/free.c b/src/free.c index 6ba3e223..5080fdb1 100644 --- a/src/free.c +++ b/src/free.c @@ -156,9 +156,14 @@ static void free_statement(const AstStatement *stmt) free_ast_body(&stmt->data.forloop.body); break; case AST_STMT_EXPRESSION_STATEMENT: - case AST_STMT_RETURN_VALUE: free_expression(&stmt->data.expression); break; + case AST_STMT_RETURN: + if (stmt->data.returnvalue) { + free_expression(stmt->data.returnvalue); + free(stmt->data.returnvalue); + } + break; case AST_STMT_DECLARE_LOCAL_VAR: free_name_type_value(&stmt->data.vardecl); break; @@ -171,7 +176,6 @@ static void free_statement(const AstStatement *stmt) free_expression(&stmt->data.assignment.target); free_expression(&stmt->data.assignment.value); break; - case AST_STMT_RETURN_WITHOUT_VALUE: case AST_STMT_BREAK: case AST_STMT_CONTINUE: break; @@ -189,12 +193,9 @@ void free_ast(AstToplevelNode *topnodelist) { for (AstToplevelNode *t = topnodelist; t->kind != AST_TOPLEVEL_END_OF_FILE; t++) { switch(t->kind) { - case AST_TOPLEVEL_DECLARE_FUNCTION: - free_ast_signature(&t->data.funcdef.signature); - break; - case AST_TOPLEVEL_DEFINE_FUNCTION: - free_ast_signature(&t->data.funcdef.signature); - free_ast_body(&t->data.funcdef.body); + case AST_TOPLEVEL_FUNCTION: + free_ast_signature(&t->data.function.signature); + free_ast_body(&t->data.function.body); break; case AST_TOPLEVEL_DECLARE_GLOBAL_VARIABLE: case AST_TOPLEVEL_DEFINE_GLOBAL_VARIABLE: @@ -204,7 +205,7 @@ void free_ast(AstToplevelNode *topnodelist) for (const AstNameTypeValue *ntv = t->data.classdef.fields.ptr; ntv < End(t->data.classdef.fields); ntv++) free_name_type_value(ntv); free(t->data.classdef.fields.ptr); - for (const AstFunctionDef *m = t->data.classdef.methods.ptr; m < End(t->data.classdef.methods); m++) { + for (const AstFunction *m = t->data.classdef.methods.ptr; m < End(t->data.classdef.methods); m++) { free_ast_signature(&m->signature); free_ast_body(&m->body); } diff --git a/src/jou_compiler.h b/src/jou_compiler.h index 5f0f3e65..db0cc96c 100644 --- a/src/jou_compiler.h +++ b/src/jou_compiler.h @@ -28,7 +28,7 @@ typedef struct AstNameTypeValue AstNameTypeValue; typedef struct AstIfStatement AstIfStatement; typedef struct AstStatement AstStatement; typedef struct AstToplevelNode AstToplevelNode; -typedef struct AstFunctionDef AstFunctionDef; +typedef struct AstFunction AstFunction; typedef struct AstClassDef AstClassDef; typedef struct AstEnumDef AstEnumDef; typedef struct AstImport AstImport; @@ -269,8 +269,7 @@ struct AstAssignment { struct AstStatement { Location location; enum AstStatementKind { - AST_STMT_RETURN_VALUE, - AST_STMT_RETURN_WITHOUT_VALUE, + AST_STMT_RETURN, AST_STMT_IF, AST_STMT_WHILE, AST_STMT_FOR, @@ -286,7 +285,8 @@ struct AstStatement { AST_STMT_EXPRESSION_STATEMENT, // Evaluate an expression and discard the result. } kind; union { - AstExpression expression; // for AST_STMT_EXPRESSION_STATEMENT, AST_STMT_RETURN + AstExpression expression; // AST_STMT_EXPRESSION_STATEMENT + AstExpression *returnvalue; // AST_STMT_RETURN (can be NULL) AstConditionAndBody whileloop; AstIfStatement ifstatement; AstForLoop forloop; @@ -295,15 +295,15 @@ struct AstStatement { } data; }; -struct AstFunctionDef { +struct AstFunction { AstSignature signature; - AstBody body; + AstBody body; // empty body means declaration, otherwise it's definition }; struct AstClassDef { char name[100]; List(AstNameTypeValue) fields; - List(AstFunctionDef) methods; + List(AstFunction) methods; }; struct AstEnumDef { @@ -323,9 +323,8 @@ struct AstToplevelNode { Location location; enum AstToplevelNodeKind { AST_TOPLEVEL_END_OF_FILE, // indicates end of array of AstToplevelNodeKind - AST_TOPLEVEL_DECLARE_FUNCTION, + AST_TOPLEVEL_FUNCTION, AST_TOPLEVEL_DECLARE_GLOBAL_VARIABLE, - AST_TOPLEVEL_DEFINE_FUNCTION, AST_TOPLEVEL_DEFINE_GLOBAL_VARIABLE, AST_TOPLEVEL_DEFINE_CLASS, AST_TOPLEVEL_DEFINE_ENUM, @@ -333,7 +332,7 @@ struct AstToplevelNode { } kind; union { AstNameTypeValue globalvar; // AST_TOPLEVEL_DECLARE_GLOBAL_VARIABLE - AstFunctionDef funcdef; // AST_TOPLEVEL_DECLARE_FUNCTION, AST_TOPLEVEL_DEFINE_FUNCTION (body is empty for declaring) + AstFunction function; AstClassDef classdef; // AST_TOPLEVEL_DEFINE_CLASS AstEnumDef enumdef; // AST_TOPLEVEL_DEFINE_ENUM AstImport import; // AST_TOPLEVEL_IMPORT diff --git a/src/main.c b/src/main.c index 1a1cab32..1ac21045 100644 --- a/src/main.c +++ b/src/main.c @@ -306,9 +306,8 @@ static char *find_stdlib() static bool astnode_conflicts_with_an_import(const AstToplevelNode *astnode, const ExportSymbol *import) { switch(astnode->kind) { - case AST_TOPLEVEL_DECLARE_FUNCTION: - case AST_TOPLEVEL_DEFINE_FUNCTION: - return import->kind == EXPSYM_FUNCTION && !strcmp(import->name, astnode->data.funcdef.signature.name); + case AST_TOPLEVEL_FUNCTION: + return import->kind == EXPSYM_FUNCTION && !strcmp(import->name, astnode->data.function.signature.name); case AST_TOPLEVEL_DECLARE_GLOBAL_VARIABLE: case AST_TOPLEVEL_DEFINE_GLOBAL_VARIABLE: return import->kind == EXPSYM_GLOBAL_VAR && !strcmp(import->name, astnode->data.globalvar.name); diff --git a/src/parse.c b/src/parse.c index 6c72d00d..0fadada9 100644 --- a/src/parse.c +++ b/src/parse.c @@ -698,11 +698,10 @@ static AstStatement parse_oneline_statement(const Token **tokens) AstStatement result = { .location = (*tokens)->location }; if (is_keyword(*tokens, "return")) { ++*tokens; - if ((*tokens)->type == TOKEN_NEWLINE) { - result.kind = AST_STMT_RETURN_WITHOUT_VALUE; - } else { - result.kind = AST_STMT_RETURN_VALUE; - result.data.expression = parse_expression(tokens); + result.kind = AST_STMT_RETURN; + if ((*tokens)->type != TOKEN_NEWLINE) { + result.data.returnvalue = malloc(sizeof *result.data.returnvalue); + *result.data.returnvalue = parse_expression(tokens); } } else if (is_keyword(*tokens, "break")) { ++*tokens; @@ -791,12 +790,12 @@ static AstBody parse_body(const Token **tokens) return (AstBody){ .statements=result.ptr, .nstatements=result.len }; } -static AstFunctionDef parse_funcdef(const Token **tokens, bool is_method) +static AstFunction parse_funcdef(const Token **tokens, bool is_method) { assert(is_keyword(*tokens, "def")); ++*tokens; - struct AstFunctionDef funcdef = {0}; + struct AstFunction funcdef = {0}; funcdef.signature = parse_function_signature(tokens, is_method); if (funcdef.signature.takes_varargs) { // TODO: support "def foo(x: str, ...)" in some way @@ -918,13 +917,13 @@ static AstToplevelNode parse_toplevel_node(const Token **tokens, const char *std eat_newline(tokens); } else if (is_keyword(*tokens, "def")) { ++*tokens; // skip 'def' keyword - result.kind = AST_TOPLEVEL_DEFINE_FUNCTION; - result.data.funcdef.signature = parse_function_signature(tokens, false); - if (result.data.funcdef.signature.takes_varargs) { + result.kind = AST_TOPLEVEL_FUNCTION; + result.data.function.signature = parse_function_signature(tokens, false); + if (result.data.function.signature.takes_varargs) { // TODO: support "def foo(x: str, ...)" in some way fail_with_error((*tokens)->location, "functions with variadic arguments cannot be defined yet"); } - result.data.funcdef.body = parse_body(tokens); + result.data.function.body = parse_body(tokens); } else if (is_keyword(*tokens, "declare")) { ++*tokens; if (is_keyword(*tokens, "global")) { @@ -937,8 +936,8 @@ static AstToplevelNode parse_toplevel_node(const Token **tokens, const char *std "a value cannot be given when declaring a global variable"); } } else { - result.kind = AST_TOPLEVEL_DECLARE_FUNCTION; - result.data.funcdef.signature = parse_function_signature(tokens, false); + result.kind = AST_TOPLEVEL_FUNCTION; + result.data.function.signature = parse_function_signature(tokens, false); } eat_newline(tokens); } else if (is_keyword(*tokens, "global")) { diff --git a/src/print.c b/src/print.c index 9c0b67d7..9d702faf 100644 --- a/src/print.c +++ b/src/print.c @@ -300,12 +300,10 @@ static void print_ast_statement(const AstStatement *stmt, struct TreePrinter tp) printf("expression statement\n"); print_ast_expression(&stmt->data.expression, print_tree_prefix(tp, true)); break; - case AST_STMT_RETURN_VALUE: - printf("return a value\n"); - print_ast_expression(&stmt->data.expression, print_tree_prefix(tp, true)); - break; - case AST_STMT_RETURN_WITHOUT_VALUE: - printf("return without a value\n"); + case AST_STMT_RETURN: + printf("return\n"); + if (stmt->data.returnvalue) + print_ast_expression(stmt->data.returnvalue, print_tree_prefix(tp, true)); break; case AST_STMT_IF: printf("if\n"); @@ -410,14 +408,10 @@ void print_ast(const AstToplevelNode *topnodelist) print_ast_type(&t->data.globalvar.type); printf("\n"); break; - case AST_TOPLEVEL_DECLARE_FUNCTION: - printf("Declare a function: "); - print_ast_function_signature(&t->data.funcdef.signature); - break; - case AST_TOPLEVEL_DEFINE_FUNCTION: - printf("Define a function: "); - print_ast_function_signature(&t->data.funcdef.signature); - print_ast_body(&t->data.funcdef.body, (struct TreePrinter){0}); + case AST_TOPLEVEL_FUNCTION: + printf("%s a function: ", t->data.function.body.nstatements == 0 ? "Declare" : "Define"); + print_ast_function_signature(&t->data.function.signature); + print_ast_body(&t->data.function.body, (struct TreePrinter){0}); break; case AST_TOPLEVEL_DEFINE_CLASS: printf("Define a class \"%s\" with %d fields and %d methods:\n", @@ -429,7 +423,7 @@ void print_ast(const AstToplevelNode *topnodelist) print_ast_type(&ntv->type); printf("\n"); } - for (const AstFunctionDef *m = t->data.classdef.methods.ptr; m < End(t->data.classdef.methods); m++) { + for (const AstFunction *m = t->data.classdef.methods.ptr; m < End(t->data.classdef.methods); m++) { printf(" method "); print_ast_function_signature(&m->signature); print_ast_body(&m->body, (struct TreePrinter){.prefix=" "}); diff --git a/src/typecheck.c b/src/typecheck.c index 8ec4baaa..92f31a1b 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -246,7 +246,7 @@ static const Type *handle_class_members_stage2(FileTypes *ft, const AstClassDef Append(&type->data.classdata.fields, f); } - for (const AstFunctionDef *m = classdef->methods.ptr; m < End(classdef->methods); m++) { + for (const AstFunction *m = classdef->methods.ptr; m < End(classdef->methods); m++) { // Don't handle the method body yet: that is a part of stage 3, not stage 2 Signature sig = handle_signature(ft, &m->signature, type); Append(&type->data.classdata.methods, sig); @@ -267,10 +267,9 @@ ExportSymbol *typecheck_stage2_signatures_globals_structbodies(FileTypes *ft, co case AST_TOPLEVEL_DEFINE_GLOBAL_VARIABLE: Append(&exports, handle_global_var(ft, &ast->data.globalvar, true)); break; - case AST_TOPLEVEL_DECLARE_FUNCTION: - case AST_TOPLEVEL_DEFINE_FUNCTION: + case AST_TOPLEVEL_FUNCTION: { - Signature sig = handle_signature(ft, &ast->data.funcdef.signature, NULL); + Signature sig = handle_signature(ft, &ast->data.function.signature, NULL); ExportSymbol es = { .kind = EXPSYM_FUNCTION, .data.funcsignature = sig }; safe_strcpy(es.name, sig.name); Append(&exports, es); @@ -1135,44 +1134,37 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt) break; } - case AST_STMT_RETURN_VALUE: - { + case AST_STMT_RETURN: if (ft->current_fom_types->signature.is_noreturn) fail_with_error( stmt->location, "function '%s' cannot return because it was defined with '-> noreturn'", ft->current_fom_types->signature.name); - if(!ft->current_fom_types->signature.returntype){ + const Type *returntype = ft->current_fom_types->signature.returntype; + if (stmt->data.returnvalue && !returntype) { fail_with_error( stmt->location, "function '%s' cannot return a value because it was defined with '-> void'", ft->current_fom_types->signature.name); } - - char msg[200]; - snprintf(msg, sizeof msg, - "attempting to return a value of type FROM from function '%s' defined with '-> TO'", - ft->current_fom_types->signature.name); - typecheck_expression_with_implicit_cast( - ft, &stmt->data.expression, find_local_var(ft, "return")->type, msg); - break; - } - - case AST_STMT_RETURN_WITHOUT_VALUE: - if (ft->current_fom_types->signature.is_noreturn) - fail_with_error( - stmt->location, - "function '%s' cannot return because it was defined with '-> noreturn'", - ft->current_fom_types->signature.name); - - if (ft->current_fom_types->signature.returntype) { + if (returntype && !stmt->data.returnvalue) { fail_with_error( stmt->location, "a return value is needed, because the return type of function '%s' is %s", ft->current_fom_types->signature.name, ft->current_fom_types->signature.returntype->name); } + + if (stmt->data.returnvalue) { + char msg[200]; + snprintf(msg, sizeof msg, + "attempting to return a value of type FROM from function '%s' defined with '-> TO'", + ft->current_fom_types->signature.name); + typecheck_expression_with_implicit_cast( + ft, stmt->data.returnvalue, find_local_var(ft, "return")->type, msg); + } + break; case AST_STMT_DECLARE_LOCAL_VAR: @@ -1215,16 +1207,16 @@ static void typecheck_function_or_method_body(FileTypes *ft, const Signature *si void typecheck_stage3_function_and_method_bodies(FileTypes *ft, const AstToplevelNode *ast) { for (; ast->kind != AST_TOPLEVEL_END_OF_FILE; ast++) { - if (ast->kind == AST_TOPLEVEL_DEFINE_FUNCTION) { + if (ast->kind == AST_TOPLEVEL_FUNCTION && ast->data.function.body.nstatements > 0) { const Signature *sig = NULL; for (struct SignatureAndUsedPtr *f = ft->functions.ptr; f < End(ft->functions); f++) { - if (!strcmp(f->signature.name, ast->data.funcdef.signature.name)) { + if (!strcmp(f->signature.name, ast->data.function.signature.name)) { sig = &f->signature; break; } } assert(sig); - typecheck_function_or_method_body(ft, sig, &ast->data.funcdef.body); + typecheck_function_or_method_body(ft, sig, &ast->data.function.body); } if (ast->kind == AST_TOPLEVEL_DEFINE_CLASS) { @@ -1237,7 +1229,7 @@ void typecheck_stage3_function_and_method_bodies(FileTypes *ft, const AstTopleve } assert(classtype); - for (AstFunctionDef *m = ast->data.classdef.methods.ptr; m < End(ast->data.classdef.methods); m++) { + for (AstFunction *m = ast->data.classdef.methods.ptr; m < End(ast->data.classdef.methods); m++) { Signature *sig = NULL; for (Signature *s = classtype->data.classdata.methods.ptr; s < End(classtype->data.classdata.methods); s++) { if (!strcmp(s->name, m->signature.name)) {