Skip to content

Commit

Permalink
Small code cleanup: combine related things (#318)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Mar 11, 2023
1 parent d64c460 commit ba84433
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 165 deletions.
2 changes: 1 addition & 1 deletion compare_compilers.sh
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ for action in tokenize parse run; do
# The file is skipped, so the two compilers should behave differently
if diff tmp/compare_compilers/compiler_written_in_c.txt tmp/compare_compilers/self_hosted.txt >/dev/null; then
if [ $fix = yes ]; then
delete_line $error_list_file $file
remove_line $error_list_file $file
else
echo " Error: Compilers behave the same even though the file is listed in $error_list_file."
echo " To fix this error, delete the \"$file\" line from $error_list_file (or run again with --fix)."
Expand Down
43 changes: 20 additions & 23 deletions self_hosted/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ class AstCall:

enum AstStatementKind:
ExpressionStatement # Evaluate an expression. Discard the result.
ReturnWithValue
ReturnWithoutValue
Return
If
WhileLoop
ForLoop
Expand All @@ -299,7 +298,7 @@ class AstStatement:
if_statement: AstIfStatement
while_loop: AstConditionAndBody
for_loop: AstForLoop
return_value: AstExpression # AstStatementKind::ReturnWithValue
return_value: AstExpression* # AstStatementKind::Return (can be NULL)
assignment: AstAssignment
var_declaration: AstNameTypeValue # AstStatementKind::DeclareLocalVar

Expand All @@ -308,11 +307,10 @@ class AstStatement:
if self->kind == AstStatementKind::ExpressionStatement:
printf("expression statement\n")
self->expression.print(tp.print_prefix(True))
elif self->kind == AstStatementKind::ReturnWithValue:
printf("return a value\n")
self->return_value.print(tp.print_prefix(True))
elif self->kind == AstStatementKind::ReturnWithoutValue:
printf("return without a value\n")
elif self->kind == AstStatementKind::Return:
printf("return\n")
if self->return_value != NULL:
self->return_value->print(tp.print_prefix(True))
elif self->kind == AstStatementKind::If:
printf("if\n")
self->if_statement.print(tp)
Expand Down Expand Up @@ -353,8 +351,9 @@ class AstStatement:
def free(self) -> void:
if self->kind == AstStatementKind::ExpressionStatement:
self->expression.free()
if self->kind == AstStatementKind::ReturnWithValue:
self->return_value.free()
if self->kind == AstStatementKind::Return and self->return_value != NULL:
self->return_value->free()
free(self->return_value)
if self->kind == AstStatementKind::If:
self->if_statement.free()
if self->kind == AstStatementKind::ForLoop:
Expand Down Expand Up @@ -516,16 +515,14 @@ class AstImport:

enum AstToplevelStatementKind:
Import
FunctionDeclaration
FunctionDefinition
Function
ClassDefinition
GlobalVariableDeclaration

class AstToplevelStatement:
# TODO: union
the_import: AstImport # must be first
decl_signature: AstSignature
funcdef: AstFunctionDef
function: AstFunction
classdef: AstClassDef
global_var: AstNameTypeValue

Expand All @@ -540,12 +537,12 @@ class AstToplevelStatement:
self->the_import.specified_path,
self->the_import.resolved_path,
)
elif self->kind == AstToplevelStatementKind::FunctionDeclaration:
printf("Declare a function: ")
self->decl_signature.print()
elif self->kind == AstToplevelStatementKind::FunctionDefinition:
printf("Define a function: ")
self->funcdef.print()
elif self->kind == AstToplevelStatementKind::Function:
if self->function.body.nstatements == 0:
printf("Declare a function: ")
else:
printf("Define a function: ")
self->function.print()
elif self->kind == AstToplevelStatementKind::ClassDefinition:
printf("Define a ")
self->classdef.print()
Expand Down Expand Up @@ -596,9 +593,9 @@ class AstFile:
self->body[i].free()
free(self->body)

class AstFunctionDef:
class AstFunction:
signature: AstSignature
body: AstBody
body: AstBody # empty body means declaration, otherwise it's a definition

def print(self) -> void:
self->signature.print()
Expand All @@ -613,7 +610,7 @@ class AstClassDef:
name_location: Location
fields: AstNameTypeValue*
nfields: int
methods: AstFunctionDef*
methods: AstFunction*
nmethods: int

def print(self) -> void:
Expand Down
18 changes: 12 additions & 6 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,12 @@ class AstToIR:
def do_statement(self, ast: AstStatement*) -> void:
if ast->kind == AstStatementKind::ExpressionStatement:
self->do_expression(&ast->expression)
elif ast->kind == AstStatementKind::ReturnWithValue:
return_value = self->do_expression(&ast->return_value)
LLVMBuildRet(self->builder, return_value)
elif ast->kind == AstStatementKind::Return:
if ast->return_value != NULL:
return_value = self->do_expression(ast->return_value)
LLVMBuildRet(self->builder, return_value)
else:
LLVMBuildRetVoid(self->builder)
# If more code follows, place it into a new block that never actually runs
self->new_block("after_return")
else:
Expand All @@ -98,14 +101,17 @@ class AstToIR:
self->do_statement(&body->statements[i])

# The function must already be declared.
def define_function(self, funcdef: AstFunctionDef*) -> void:
def define_function(self, funcdef: AstFunction*) -> void:
llvm_func = LLVMGetNamedFunction(self->module, &funcdef->signature.name[0])
assert(llvm_func != NULL)
assert(self->current_function == NULL)
self->current_function = llvm_func

self->new_block("start")
assert(funcdef->body.nstatements > 0) # it is a definition
self->do_body(&funcdef->body)
LLVMBuildUnreachable(self->builder)

self->current_function = NULL


Expand All @@ -124,8 +130,8 @@ def create_llvm_ir(ast: AstFile*, typectx: TypeContext*) -> LLVMModule*:
a2i.declare_function(&typectx->functions[i])

for i = 0; i < ast->body_len; i++:
if ast->body[i].kind == AstToplevelStatementKind::FunctionDefinition:
a2i.define_function(&ast->body[i].funcdef)
if ast->body[i].kind == AstToplevelStatementKind::Function and ast->body[i].function.body.nstatements > 0:
a2i.define_function(&ast->body[i].function)

LLVMDisposeBuilder(a2i.builder)
return module
21 changes: 10 additions & 11 deletions self_hosted/parser.jou
Original file line number Diff line number Diff line change
Expand Up @@ -413,11 +413,10 @@ def parse_oneline_statement(tokens: Token**) -> AstStatement:
result = AstStatement{ location = (*tokens)->location }
if (*tokens)->is_keyword("return"):
++*tokens
if (*tokens)->kind == TokenKind::Newline:
result.kind = AstStatementKind::ReturnWithoutValue
else:
result.kind = AstStatementKind::ReturnWithValue
result.return_value = parse_expression(tokens)
result.kind = AstStatementKind::Return
if (*tokens)->kind != TokenKind::Newline:
result.return_value = malloc(sizeof *result.return_value)
*result.return_value = parse_expression(tokens)
elif (*tokens)->is_keyword("break"):
++*tokens
result.kind = AstStatementKind::Break
Expand Down Expand Up @@ -554,8 +553,8 @@ def parse_body(tokens: Token**) -> AstBody:

return AstBody{ statements = result, nstatements = n }

def parse_funcdef(tokens: Token**) -> AstFunctionDef:
return AstFunctionDef{
def parse_funcdef(tokens: Token**) -> AstFunction:
return AstFunction{
signature = parse_function_signature(tokens),
body = parse_body(tokens),
}
Expand Down Expand Up @@ -594,8 +593,8 @@ def parse_toplevel_node(dest: AstFile*, tokens: Token**, stdlib_path: byte*) ->

elif (*tokens)->is_keyword("def"):
++*tokens
ts.kind = AstToplevelStatementKind::FunctionDefinition
ts.funcdef = parse_funcdef(tokens)
ts.kind = AstToplevelStatementKind::Function
ts.function = parse_funcdef(tokens)

elif (*tokens)->is_keyword("declare"):
++*tokens
Expand All @@ -609,8 +608,8 @@ def parse_toplevel_node(dest: AstFile*, tokens: Token**, stdlib_path: byte*) ->
"a value cannot be given when declaring a global variable",
)
else:
ts.kind = AstToplevelStatementKind::FunctionDeclaration
ts.decl_signature = parse_function_signature(tokens)
ts.kind = AstToplevelStatementKind::Function
ts.function.signature = parse_function_signature(tokens)
eat_newline(tokens)

elif (*tokens)->is_keyword("class"):
Expand Down
1 change: 0 additions & 1 deletion self_hosted/runs_wrong.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ tests/syntax_error/multidot_float.jou
tests/syntax_error/self_outside_class.jou
tests/syntax_error/triple_equals.jou
tests/too_long/name.jou
tests/wrong_type/arg.jou
tests/wrong_type/array_mixed_types.jou
tests/wrong_type/array_mixed_types_ptr.jou
tests/wrong_type/array_to_ptr.jou
Expand Down
67 changes: 35 additions & 32 deletions self_hosted/typecheck.jou
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,16 @@ def typecheck_stage2_signatures_globals_structbodies(ctx: TypeContext*, ast_file

for i = 0; i < ast_file->body_len; i++:
ts = &ast_file->body[i]
# TODO: get rid of declare/define copy pasta
if ts->kind == AstToplevelStatementKind::FunctionDeclaration:
if ts->kind == AstToplevelStatementKind::Function:
# TODO: terrible hack: skip functions that use FILE, such as fopen()
# Will be no longer needed once struct FILE works.
if ts->decl_signature.name[0] == 'f' or strcmp(&ts->decl_signature.name[0], "rewind") == 0:
if ts->function.body.nstatements == 0 and (
ts->function.signature.name[0] == 'f'
or strcmp(&ts->function.signature.name[0], "rewind") == 0
):
continue

sig = handle_signature(ctx, &ts->decl_signature)
ctx->functions = realloc(ctx->functions, sizeof ctx->functions[0] * (ctx->nfunctions + 1))
ctx->functions[ctx->nfunctions++] = sig.copy()
exports = realloc(exports, sizeof exports[0] * (nexports + 1))
exports[nexports++] = ExportSymbol{
kind = ExportSymbolKind::Function,
name = sig.name,
signature = sig,
}
if ts->kind == AstToplevelStatementKind::FunctionDefinition:
sig = handle_signature(ctx, &ts->funcdef.signature)
sig = handle_signature(ctx, &ts->function.signature)
ctx->functions = realloc(ctx->functions, sizeof ctx->functions[0] * (ctx->nfunctions + 1))
ctx->functions[ctx->nfunctions++] = sig.copy()
exports = realloc(exports, sizeof exports[0] * (nexports + 1))
Expand Down Expand Up @@ -324,26 +316,37 @@ def typecheck_statement(ctx: TypeContext*, statement: AstStatement*) -> void:
if statement->kind == AstStatementKind::ExpressionStatement:
typecheck_expression_maybe_void(ctx, &statement->expression)

elif statement->kind == AstStatementKind::ReturnWithValue:
elif statement->kind == AstStatementKind::Return:
name = &ctx->current_function_signature->name[0]

if ctx->current_function_signature->return_type == NULL:
msg: byte[500]
# TODO: check for noreturn functions

return_type = ctx->current_function_signature->return_type
msg: byte[500]

if statement->return_value != NULL and return_type == NULL:
snprintf(&msg[0], sizeof msg, "function '%s' cannot return a value because it was defined with '-> void'", name)
fail(statement->location, name)
if statement->return_value == NULL and return_type != NULL:
snprintf(
&msg[0], sizeof msg,
"function '%s' must return a value because it was defined with '-> %s'",
name, &return_type->name[0])
fail(statement->location, name)

cast_error_msg: byte[500]
snprintf(
&cast_error_msg[0], sizeof cast_error_msg,
"attempting to return a value of type <from> from function '%s' defined with '-> <to>'",
name,
)
typecheck_expression_with_implicit_cast(
ctx,
&statement->return_value,
ctx->current_function_signature->return_type,
&cast_error_msg[0],
)
if statement->return_value != NULL:
cast_error_msg: byte[500]
snprintf(
&cast_error_msg[0], sizeof cast_error_msg,
"attempting to return a value of type <from> from function '%s' defined with '-> <to>'",
name,
)
typecheck_expression_with_implicit_cast(
ctx,
statement->return_value,
return_type,
&cast_error_msg[0],
)

else:
assert(False)
Expand All @@ -355,13 +358,13 @@ def typecheck_body(ctx: TypeContext*, body: AstBody*) -> void:
def typecheck_stage3_function_and_method_bodies(ctx: TypeContext*, ast_file: AstFile*) -> void:
for i = 0; i < ast_file->body_len; i++:
ts = &ast_file->body[i]
if ts->kind != AstToplevelStatementKind::FunctionDefinition:
if ts->kind != AstToplevelStatementKind::Function or ts->function.body.nstatements == 0:
continue

sig = ctx->find_function(&ts->funcdef.signature.name[0])
sig = ctx->find_function(&ts->function.signature.name[0])
assert(sig != NULL)

assert(ctx->current_function_signature == NULL)
ctx->current_function_signature = sig
typecheck_body(ctx, &ts->funcdef.body)
typecheck_body(ctx, &ts->function.body)
ctx->current_function_signature = NULL
22 changes: 10 additions & 12 deletions src/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -830,15 +830,13 @@ static void build_statement(struct State *st, const AstStatement *stmt)
break;
}

case AST_STMT_RETURN_VALUE:
{
const LocalVariable *retvalue = build_expression(st, &stmt->data.expression);
const LocalVariable *retvariable = find_local_var(st, "return");
assert(retvariable);
add_unary_op(st, stmt->location, CF_VARCPY, retvalue, retvariable);
}
__attribute__((fallthrough));
case AST_STMT_RETURN_WITHOUT_VALUE:
case AST_STMT_RETURN:
if (stmt->data.returnvalue) {
const LocalVariable *retvalue = build_expression(st, stmt->data.returnvalue);
const LocalVariable *retvariable = find_local_var(st, "return");
assert(retvariable);
add_unary_op(st, stmt->location, CF_VARCPY, retvalue, retvariable);
}
st->current_block->iftrue = &st->cfg->end_block;
st->current_block->iffalse = &st->cfg->end_block;
st->current_block = add_block(st); // an unreachable block
Expand Down Expand Up @@ -908,8 +906,8 @@ CfGraphFile build_control_flow_graphs(AstToplevelNode *ast, FileTypes *filetypes
struct State st = { .filetypes = filetypes };

while (ast->kind != AST_TOPLEVEL_END_OF_FILE) {
if(ast->kind == AST_TOPLEVEL_DEFINE_FUNCTION) {
CfGraph *g = build_function_or_method(&st, NULL, ast->data.funcdef.signature.name, &ast->data.funcdef.body);
if(ast->kind == AST_TOPLEVEL_FUNCTION && ast->data.function.body.nstatements > 0) {
CfGraph *g = build_function_or_method(&st, NULL, ast->data.function.signature.name, &ast->data.function.body);
Append(&result.graphs, g);
}

Expand All @@ -923,7 +921,7 @@ CfGraphFile build_control_flow_graphs(AstToplevelNode *ast, FileTypes *filetypes
}
assert(classtype);

for (AstFunctionDef *m = ast->data.classdef.methods.ptr; m < End(ast->data.classdef.methods); m++) {
for (AstFunction *m = ast->data.classdef.methods.ptr; m < End(ast->data.classdef.methods); m++) {
CfGraph *g = build_function_or_method(&st, classtype, m->signature.name, &m->body);
Append(&result.graphs, g);
}
Expand Down
Loading

0 comments on commit ba84433

Please sign in to comment.