From a5c2de5fee67e35c8173b7051675d49648086cbb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 28 Jan 2016 20:26:40 -0700 Subject: ability to specify function type closes #14 --- src/analyze.cpp | 164 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 114 insertions(+), 50 deletions(-) (limited to 'src/analyze.cpp') diff --git a/src/analyze.cpp b/src/analyze.cpp index c3d66a26f8..38b2a8faca 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -451,44 +451,20 @@ static TypeTableEntry *analyze_type_expr(CodeGen *g, ImportTableEntry *import, B return resolve_type(g, *node_ptr); } -static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry, - ImportTableEntry *import) +static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *import, BlockContext *context, + TypeTableEntry *expected_type, AstNode *node, bool is_naked) { assert(node->type == NodeTypeFnProto); AstNodeFnProto *fn_proto = &node->data.fn_proto; if (fn_proto->skip) { - return; + return g->builtin_types.entry_invalid; } TypeTableEntry *fn_type = new_type_table_entry(TypeTableEntryIdFn); - fn_table_entry->type_entry = fn_type; - fn_type->data.fn.calling_convention = fn_table_entry->internal_linkage ? LLVMFastCallConv : LLVMCCallConv; - - for (int i = 0; i < fn_proto->directives->length; i += 1) { - AstNode *directive_node = fn_proto->directives->at(i); - Buf *name = &directive_node->data.directive.name; - - if (buf_eql_str(name, "attribute")) { - Buf *attr_name = &directive_node->data.directive.param; - if (fn_table_entry->fn_def_node) { - if (buf_eql_str(attr_name, "naked")) { - fn_type->data.fn.is_naked = true; - } else if (buf_eql_str(attr_name, "inline")) { - fn_table_entry->is_inline = true; - } else { - add_node_error(g, directive_node, - buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); - } - } else { - add_node_error(g, directive_node, - buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); - } - } else { - add_node_error(g, directive_node, - buf_sprintf("invalid directive: '%s'", buf_ptr(name))); - } - } + fn_type->data.fn.is_extern = fn_proto->is_extern || (fn_proto->visib_mod == VisibModExport); + fn_type->data.fn.is_naked = is_naked; + fn_type->data.fn.calling_convention = fn_proto->is_extern ? LLVMCCallConv : LLVMFastCallConv; int src_param_count = node->data.fn_proto.params.length; fn_type->size_in_bits = g->pointer_size_bytes * 8; @@ -499,10 +475,9 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t // first, analyze the parameters and return type in order they appear in // source code in order for error messages to be in the best order. buf_resize(&fn_type->name, 0); - const char *export_str = fn_table_entry->internal_linkage ? "" : "export "; - const char *inline_str = fn_table_entry->is_inline ? "inline " : ""; + const char *extern_str = fn_type->data.fn.is_extern ? "extern " : ""; const char *naked_str = fn_type->data.fn.is_naked ? "naked " : ""; - buf_appendf(&fn_type->name, "%s%s%sfn(", export_str, inline_str, naked_str); + buf_appendf(&fn_type->name, "%s%sfn(", extern_str, naked_str); for (int i = 0; i < src_param_count; i += 1) { AstNode *child = node->data.fn_proto.params.at(i); assert(child->type == NodeTypeParamDecl); @@ -525,10 +500,9 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t const char *comma = (src_param_count == 0) ? "" : ", "; buf_appendf(&fn_type->name, "%s...", comma); } - buf_appendf(&fn_type->name, ")"); if (return_type->id != TypeTableEntryIdVoid) { - buf_appendf(&fn_type->name, " %s", buf_ptr(&return_type->name)); + buf_appendf(&fn_type->name, " -> %s", buf_ptr(&return_type->name)); } @@ -593,13 +567,12 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t fn_type->data.fn.gen_param_count = gen_param_index; if (fn_proto->skip) { - return; + return g->builtin_types.entry_invalid; } auto table_entry = import->fn_type_table.maybe_get(&fn_type->name); if (table_entry) { - fn_type = table_entry->value; - fn_table_entry->type_entry = fn_type; + return table_entry->value; } else { fn_type->data.fn.raw_type_ref = LLVMFunctionType(gen_return_type->type_ref, gen_param_types, gen_param_index, fn_type->data.fn.is_var_args); @@ -608,8 +581,56 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t param_di_types, gen_param_index + 1, 0); import->fn_type_table.put(&fn_type->name, fn_type); + + return fn_type; + } +} + + +static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry, + ImportTableEntry *import) +{ + assert(node->type == NodeTypeFnProto); + AstNodeFnProto *fn_proto = &node->data.fn_proto; + + if (fn_proto->skip) { + return; } + bool is_naked = false; + for (int i = 0; i < fn_proto->directives->length; i += 1) { + AstNode *directive_node = fn_proto->directives->at(i); + Buf *name = &directive_node->data.directive.name; + + if (buf_eql_str(name, "attribute")) { + Buf *attr_name = &directive_node->data.directive.param; + if (fn_table_entry->fn_def_node) { + if (buf_eql_str(attr_name, "naked")) { + is_naked = true; + } else if (buf_eql_str(attr_name, "inline")) { + fn_table_entry->is_inline = true; + } else { + add_node_error(g, directive_node, + buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); + } + } else { + add_node_error(g, directive_node, + buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); + } + } else { + add_node_error(g, directive_node, + buf_sprintf("invalid directive: '%s'", buf_ptr(name))); + } + } + + TypeTableEntry *fn_type = analyze_fn_proto_type(g, import, import->block_context, nullptr, node, is_naked); + + if (fn_type->id == TypeTableEntryIdInvalid) { + fn_proto->skip = true; + return; + } + + fn_table_entry->type_entry = fn_type; fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(&fn_table_entry->symbol_name), fn_type->data.fn.raw_type_ref); @@ -624,7 +645,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t LLVMSetLinkage(fn_table_entry->fn_value, fn_table_entry->internal_linkage ? LLVMInternalLinkage : LLVMExternalLinkage); - if (return_type->id == TypeTableEntryIdUnreachable) { + if (fn_type->data.fn.src_return_type->id == TypeTableEntryIdUnreachable) { LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMNoReturnAttribute); } LLVMSetFunctionCallConv(fn_table_entry->fn_value, fn_type->data.fn.calling_convention); @@ -1353,7 +1374,29 @@ static bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTable if (expected_type->id == TypeTableEntryIdFn && actual_type->id == TypeTableEntryIdFn) { - zig_panic("TODO types_match_const_cast_only for fns"); + if (expected_type->data.fn.is_extern != actual_type->data.fn.is_extern) { + return false; + } + if (expected_type->data.fn.is_naked != actual_type->data.fn.is_naked) { + return false; + } + if (!types_match_const_cast_only(expected_type->data.fn.src_return_type, + actual_type->data.fn.src_return_type)) + { + return false; + } + if (expected_type->data.fn.src_param_count != actual_type->data.fn.src_param_count) { + return false; + } + for (int i = 0; i < expected_type->data.fn.src_param_count; i += 1) { + // note it's reversed for parameters + if (types_match_const_cast_only(actual_type->data.fn.param_types[i], + expected_type->data.fn.param_types[i])) + { + return false; + } + } + return true; } @@ -2902,6 +2945,18 @@ static TypeTableEntry *analyze_array_type(CodeGen *g, ImportTableEntry *import, } } +static TypeTableEntry *analyze_fn_proto_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, + TypeTableEntry *expected_type, AstNode *node) +{ + TypeTableEntry *type_entry = analyze_fn_proto_type(g, import, context, expected_type, node, false); + + if (type_entry->id == TypeTableEntryIdInvalid) { + return type_entry; + } + + return resolve_expr_const_val_as_type(g, node, type_entry); +} + static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { @@ -4240,6 +4295,9 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import, case NodeTypeArrayType: return_type = analyze_array_type(g, import, context, expected_type, node); break; + case NodeTypeFnProto: + return_type = analyze_fn_proto_expr(g, import, context, expected_type, node); + break; case NodeTypeErrorType: return_type = resolve_expr_const_val_as_type(g, node, g->builtin_types.entry_pure_error); break; @@ -4250,7 +4308,6 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import, case NodeTypeSwitchRange: case NodeTypeDirective: case NodeTypeFnDecl: - case NodeTypeFnProto: case NodeTypeParamDecl: case NodeTypeRoot: case NodeTypeRootExportDecl: @@ -4555,13 +4612,23 @@ static void collect_expr_decl_deps(CodeGen *g, ImportTableEntry *import, AstNode collect_expr_decl_deps(g, import, node->data.switch_range.start, decl_node); collect_expr_decl_deps(g, import, node->data.switch_range.end, decl_node); break; - case NodeTypeVariableDeclaration: case NodeTypeFnProto: + // remember that fn proto node is used for function definitions as well + // as types + for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { + AstNode *param = node->data.fn_proto.params.at(i); + collect_expr_decl_deps(g, import, param, decl_node); + } + collect_expr_decl_deps(g, import, node->data.fn_proto.return_type, decl_node); + break; + case NodeTypeParamDecl: + collect_expr_decl_deps(g, import, node->data.param_decl.type, decl_node); + break; + case NodeTypeVariableDeclaration: case NodeTypeRootExportDecl: case NodeTypeFnDef: case NodeTypeRoot: case NodeTypeFnDecl: - case NodeTypeParamDecl: case NodeTypeDirective: case NodeTypeImport: case NodeTypeCImport: @@ -4705,12 +4772,8 @@ static void detect_top_level_decl_deps(CodeGen *g, ImportTableEntry *import, Ast // determine which other top level declarations this function prototype depends on. TopLevelDecl *decl_node = &node->data.fn_proto.top_level_decl; decl_node->deps.init(1); - for (int i = 0; i < node->data.fn_proto.params.length; i += 1) { - AstNode *param_node = node->data.fn_proto.params.at(i); - assert(param_node->type == NodeTypeParamDecl); - collect_expr_decl_deps(g, import, param_node->data.param_decl.type, decl_node); - } - collect_expr_decl_deps(g, import, node->data.fn_proto.return_type, decl_node); + + collect_expr_decl_deps(g, import, node, decl_node); decl_node->name = name; decl_node->import = import; @@ -4999,11 +5062,12 @@ Expr *get_resolved_expr(AstNode *node) { return &node->data.error_type.resolved_expr; case NodeTypeSwitchExpr: return &node->data.switch_expr.resolved_expr; + case NodeTypeFnProto: + return &node->data.fn_proto.resolved_expr; case NodeTypeSwitchProng: case NodeTypeSwitchRange: case NodeTypeRoot: case NodeTypeRootExportDecl: - case NodeTypeFnProto: case NodeTypeFnDef: case NodeTypeFnDecl: case NodeTypeParamDecl: -- cgit v1.2.3