From 6ab8b2aab4b146a7d1d882686199eace19989011 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 30 Aug 2019 20:06:02 -0400 Subject: support recursive async and non-async functions which heap allocate their own frames related: #1006 --- test/compile_errors.zig | 2 ++ 1 file changed, 2 insertions(+) (limited to 'test/compile_errors.zig') diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 91916e6f38..a9e99f4799 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -1051,6 +1051,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\const Foo = struct {}; \\export fn a() void { \\ const T = [*c]Foo; + \\ var t: T = undefined; \\} , "tmp.zig:3:19: error: C pointers cannot point to non-C-ABI-compatible type 'Foo'", @@ -2290,6 +2291,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { "error union operator with non error set LHS", \\comptime { \\ const z = i32!i32; + \\ var x: z = undefined; \\} , "tmp.zig:2:15: error: expected error set type, found type 'i32'", -- cgit v1.2.3 From 5c3a9a1a3eef82ffad17bc295da05ecccd9006a5 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 31 Aug 2019 18:50:16 -0400 Subject: improvements to `@asyncCall` * `await @asyncCall` generates better code. See #3065 * `@asyncCall` works with a real `@Frame(func)` in addition to a byte slice. Closes #3072 * `@asyncCall` allows passing `{}` (a void value) as the result pointer, which uses the result location inside the frame. Closes #3068 * support `await @asyncCall` on a non-async function. This is in preparation for safe recursion (#1006). --- src/all_types.hpp | 2 + src/analyze.cpp | 4 + src/codegen.cpp | 61 +++++++---- src/ir.cpp | 225 ++++++++++++++++++++++++-------------- test/compile_errors.zig | 16 +++ test/stage1/behavior/async_fn.zig | 105 +++++++++++++++++- 6 files changed, 308 insertions(+), 105 deletions(-) (limited to 'test/compile_errors.zig') diff --git a/src/all_types.hpp b/src/all_types.hpp index aee6d3994f..d9e1dc44ca 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2719,6 +2719,7 @@ struct IrInstructionCallSrc { IrInstruction *new_stack; FnInline fn_inline; bool is_async; + bool is_async_call_builtin; bool is_comptime; }; @@ -2735,6 +2736,7 @@ struct IrInstructionCallGen { IrInstruction *new_stack; FnInline fn_inline; bool is_async; + bool is_async_call_builtin; }; struct IrInstructionConst { diff --git a/src/analyze.cpp b/src/analyze.cpp index df5b27784a..dfdf06aa5a 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5727,6 +5727,10 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) { for (size_t i = 0; i < fn->call_list.length; i += 1) { IrInstructionCallGen *call = fn->call_list.at(i); + if (call->new_stack != nullptr) { + // don't need to allocate a frame for this + continue; + } ZigFn *callee = call->fn_entry; if (callee == nullptr) { add_node_error(g, call->base.source_node, diff --git a/src/codegen.cpp b/src/codegen.cpp index 33713a9b30..890724d950 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3826,17 +3826,18 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMValueRef awaiter_init_val; LLVMValueRef ret_ptr; if (callee_is_async) { - if (instruction->is_async) { - if (instruction->new_stack == nullptr) { - awaiter_init_val = zero; + if (instruction->new_stack == nullptr) { + if (instruction->is_async) { frame_result_loc = result_loc; - - if (ret_has_bits) { - // Use the result location which is inside the frame if this is an async call. - ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, ""); - } - } else if (cc == CallingConventionAsync) { - awaiter_init_val = zero; + } else { + frame_result_loc = ir_llvm_value(g, instruction->frame_result_loc); + } + } else { + if (instruction->new_stack->value.type->id == ZigTypeIdPointer && + instruction->new_stack->value.type->data.pointer.child_type->id == ZigTypeIdFnFrame) + { + frame_result_loc = ir_llvm_value(g, instruction->new_stack); + } else { LLVMValueRef frame_slice_ptr = ir_llvm_value(g, instruction->new_stack); if (ir_want_runtime_safety(g, &instruction->base)) { LLVMValueRef given_len_ptr = LLVMBuildStructGEP(g->builder, frame_slice_ptr, slice_len_index, ""); @@ -3856,15 +3857,37 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr } LLVMValueRef frame_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_slice_ptr, slice_ptr_index, ""); LLVMValueRef frame_ptr = LLVMBuildLoad(g->builder, frame_ptr_ptr, ""); - frame_result_loc = LLVMBuildBitCast(g->builder, frame_ptr, - get_llvm_type(g, instruction->base.value.type), ""); + if (instruction->fn_entry == nullptr) { + ZigType *anyframe_type = get_any_frame_type(g, src_return_type); + frame_result_loc = LLVMBuildBitCast(g->builder, frame_ptr, get_llvm_type(g, anyframe_type), ""); + } else { + ZigType *ptr_frame_type = get_pointer_to_type(g, + get_fn_frame_type(g, instruction->fn_entry), false); + frame_result_loc = LLVMBuildBitCast(g->builder, frame_ptr, + get_llvm_type(g, ptr_frame_type), ""); + } + } + } + if (instruction->is_async) { + if (instruction->new_stack == nullptr) { + awaiter_init_val = zero; if (ret_has_bits) { - // Use the result location provided to the @asyncCall builtin - ret_ptr = result_loc; + // Use the result location which is inside the frame if this is an async call. + ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, ""); } } else { - zig_unreachable(); + awaiter_init_val = zero; + + if (ret_has_bits) { + if (result_loc != nullptr) { + // Use the result location provided to the @asyncCall builtin + ret_ptr = result_loc; + } else { + // no result location provided to @asyncCall - use the one inside the frame. + ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, ""); + } + } } // even if prefix_arg_err_ret_stack is true, let the async function do its own @@ -3872,7 +3895,6 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr } else { // async function called as a normal function - frame_result_loc = ir_llvm_value(g, instruction->frame_result_loc); awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, ""); // caller's own frame pointer if (ret_has_bits) { if (result_loc == nullptr) { @@ -3988,7 +4010,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr uint32_t arg_start_i = frame_index_arg(g, fn_type->data.fn.fn_type_id.return_type); LLVMValueRef casted_frame; - if (instruction->new_stack != nullptr) { + if (instruction->new_stack != nullptr && instruction->fn_entry == nullptr) { // We need the frame type to be a pointer to a struct that includes the args size_t field_count = arg_start_i + gen_param_values.length; LLVMTypeRef *field_types = allocate_nonzero(field_count); @@ -4014,7 +4036,8 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr if (instruction->is_async) { gen_resume(g, fn_val, frame_result_loc, ResumeIdCall); if (instruction->new_stack != nullptr) { - return frame_result_loc; + return LLVMBuildBitCast(g->builder, frame_result_loc, + get_llvm_type(g, instruction->base.value.type), ""); } return nullptr; } else { @@ -4041,7 +4064,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr } } - if (instruction->new_stack == nullptr) { + if (instruction->new_stack == nullptr || instruction->is_async_call_builtin) { result = ZigLLVMBuildCall(g->builder, fn_val, gen_param_values.items, (unsigned)gen_param_values.length, llvm_cc, fn_inline, ""); } else if (instruction->is_async) { diff --git a/src/ir.cpp b/src/ir.cpp index ad81b27a93..b8a81ba5c9 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1382,7 +1382,7 @@ static IrInstruction *ir_build_union_field_ptr(IrBuilder *irb, Scope *scope, Ast static IrInstruction *ir_build_call_src(IrBuilder *irb, Scope *scope, AstNode *source_node, ZigFn *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args, - bool is_comptime, FnInline fn_inline, bool is_async, + bool is_comptime, FnInline fn_inline, bool is_async, bool is_async_call_builtin, IrInstruction *new_stack, ResultLoc *result_loc) { IrInstructionCallSrc *call_instruction = ir_build_instruction(irb, scope, source_node); @@ -1393,6 +1393,7 @@ static IrInstruction *ir_build_call_src(IrBuilder *irb, Scope *scope, AstNode *s call_instruction->args = args; call_instruction->arg_count = arg_count; call_instruction->is_async = is_async; + call_instruction->is_async_call_builtin = is_async_call_builtin; call_instruction->new_stack = new_stack; call_instruction->result_loc = result_loc; @@ -1410,7 +1411,7 @@ static IrInstruction *ir_build_call_src(IrBuilder *irb, Scope *scope, AstNode *s static IrInstructionCallGen *ir_build_call_gen(IrAnalyze *ira, IrInstruction *source_instruction, ZigFn *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args, - FnInline fn_inline, bool is_async, IrInstruction *new_stack, + FnInline fn_inline, bool is_async, IrInstruction *new_stack, bool is_async_call_builtin, IrInstruction *result_loc, ZigType *return_type) { IrInstructionCallGen *call_instruction = ir_build_instruction(&ira->new_irb, @@ -1422,6 +1423,7 @@ static IrInstructionCallGen *ir_build_call_gen(IrAnalyze *ira, IrInstruction *so call_instruction->args = args; call_instruction->arg_count = arg_count; call_instruction->is_async = is_async; + call_instruction->is_async_call_builtin = is_async_call_builtin; call_instruction->new_stack = new_stack; call_instruction->result_loc = result_loc; @@ -4351,6 +4353,54 @@ static IrInstruction *ir_gen_this(IrBuilder *irb, Scope *orig_scope, AstNode *no zig_unreachable(); } +static IrInstruction *ir_gen_async_call(IrBuilder *irb, Scope *scope, AstNode *await_node, AstNode *call_node, + LVal lval, ResultLoc *result_loc) +{ + size_t arg_offset = 3; + if (call_node->data.fn_call_expr.params.length < arg_offset) { + add_node_error(irb->codegen, call_node, + buf_sprintf("expected at least %" ZIG_PRI_usize " arguments, found %" ZIG_PRI_usize, + arg_offset, call_node->data.fn_call_expr.params.length)); + return irb->codegen->invalid_instruction; + } + + AstNode *bytes_node = call_node->data.fn_call_expr.params.at(0); + IrInstruction *bytes = ir_gen_node(irb, bytes_node, scope); + if (bytes == irb->codegen->invalid_instruction) + return bytes; + + AstNode *ret_ptr_node = call_node->data.fn_call_expr.params.at(1); + IrInstruction *ret_ptr = ir_gen_node(irb, ret_ptr_node, scope); + if (ret_ptr == irb->codegen->invalid_instruction) + return ret_ptr; + + AstNode *fn_ref_node = call_node->data.fn_call_expr.params.at(2); + IrInstruction *fn_ref = ir_gen_node(irb, fn_ref_node, scope); + if (fn_ref == irb->codegen->invalid_instruction) + return fn_ref; + + size_t arg_count = call_node->data.fn_call_expr.params.length - arg_offset; + + // last "arg" is return pointer + IrInstruction **args = allocate(arg_count + 1); + + for (size_t i = 0; i < arg_count; i += 1) { + AstNode *arg_node = call_node->data.fn_call_expr.params.at(i + arg_offset); + IrInstruction *arg = ir_gen_node(irb, arg_node, scope); + if (arg == irb->codegen->invalid_instruction) + return arg; + args[i] = arg; + } + + args[arg_count] = ret_ptr; + + bool is_async = await_node == nullptr; + bool is_async_call_builtin = true; + IrInstruction *call = ir_build_call_src(irb, scope, call_node, nullptr, fn_ref, arg_count, args, false, + FnInlineAuto, is_async, is_async_call_builtin, bytes, result_loc); + return ir_lval_wrap(irb, scope, call, lval, result_loc); +} + static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNode *node, LVal lval, ResultLoc *result_loc) { @@ -4360,7 +4410,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo Buf *name = fn_ref_expr->data.symbol_expr.symbol; auto entry = irb->codegen->builtin_fn_table.maybe_get(name); - if (!entry) { // new built in not found + if (!entry) { add_node_error(irb->codegen, node, buf_sprintf("invalid builtin function: '%s'", buf_ptr(name))); return irb->codegen->invalid_instruction; @@ -5224,7 +5274,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo FnInline fn_inline = (builtin_fn->id == BuiltinFnIdInlineCall) ? FnInlineAlways : FnInlineNever; IrInstruction *call = ir_build_call_src(irb, scope, node, nullptr, fn_ref, arg_count, args, false, - fn_inline, false, nullptr, result_loc); + fn_inline, false, false, nullptr, result_loc); return ir_lval_wrap(irb, scope, call, lval, result_loc); } case BuiltinFnIdNewStackCall: @@ -5257,53 +5307,11 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo } IrInstruction *call = ir_build_call_src(irb, scope, node, nullptr, fn_ref, arg_count, args, false, - FnInlineAuto, false, new_stack, result_loc); + FnInlineAuto, false, false, new_stack, result_loc); return ir_lval_wrap(irb, scope, call, lval, result_loc); } case BuiltinFnIdAsyncCall: - { - size_t arg_offset = 3; - if (node->data.fn_call_expr.params.length < arg_offset) { - add_node_error(irb->codegen, node, - buf_sprintf("expected at least %" ZIG_PRI_usize " arguments, found %" ZIG_PRI_usize, - arg_offset, node->data.fn_call_expr.params.length)); - return irb->codegen->invalid_instruction; - } - - AstNode *bytes_node = node->data.fn_call_expr.params.at(0); - IrInstruction *bytes = ir_gen_node(irb, bytes_node, scope); - if (bytes == irb->codegen->invalid_instruction) - return bytes; - - AstNode *ret_ptr_node = node->data.fn_call_expr.params.at(1); - IrInstruction *ret_ptr = ir_gen_node(irb, ret_ptr_node, scope); - if (ret_ptr == irb->codegen->invalid_instruction) - return ret_ptr; - - AstNode *fn_ref_node = node->data.fn_call_expr.params.at(2); - IrInstruction *fn_ref = ir_gen_node(irb, fn_ref_node, scope); - if (fn_ref == irb->codegen->invalid_instruction) - return fn_ref; - - size_t arg_count = node->data.fn_call_expr.params.length - arg_offset; - - // last "arg" is return pointer - IrInstruction **args = allocate(arg_count + 1); - - for (size_t i = 0; i < arg_count; i += 1) { - AstNode *arg_node = node->data.fn_call_expr.params.at(i + arg_offset); - IrInstruction *arg = ir_gen_node(irb, arg_node, scope); - if (arg == irb->codegen->invalid_instruction) - return arg; - args[i] = arg; - } - - args[arg_count] = ret_ptr; - - IrInstruction *call = ir_build_call_src(irb, scope, node, nullptr, fn_ref, arg_count, args, false, - FnInlineAuto, true, bytes, result_loc); - return ir_lval_wrap(irb, scope, call, lval, result_loc); - } + return ir_gen_async_call(irb, scope, nullptr, node, lval, result_loc); case BuiltinFnIdTypeId: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); @@ -5607,7 +5615,7 @@ static IrInstruction *ir_gen_fn_call(IrBuilder *irb, Scope *scope, AstNode *node bool is_async = node->data.fn_call_expr.is_async; IrInstruction *fn_call = ir_build_call_src(irb, scope, node, nullptr, fn_ref, arg_count, args, false, - FnInlineAuto, is_async, nullptr, result_loc); + FnInlineAuto, is_async, false, nullptr, result_loc); return ir_lval_wrap(irb, scope, fn_call, lval, result_loc); } @@ -7900,6 +7908,19 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *scope, AstNode *n { assert(node->type == NodeTypeAwaitExpr); + AstNode *expr_node = node->data.await_expr.expr; + if (expr_node->type == NodeTypeFnCallExpr && expr_node->data.fn_call_expr.is_builtin) { + AstNode *fn_ref_expr = expr_node->data.fn_call_expr.fn_ref_expr; + Buf *name = fn_ref_expr->data.symbol_expr.symbol; + auto entry = irb->codegen->builtin_fn_table.maybe_get(name); + if (entry != nullptr) { + BuiltinFnEntry *builtin_fn = entry->value; + if (builtin_fn->id == BuiltinFnIdAsyncCall) { + return ir_gen_async_call(irb, scope, node, expr_node, lval, result_loc); + } + } + } + ZigFn *fn_entry = exec_fn_entry(irb->exec); if (!fn_entry) { add_node_error(irb->codegen, node, buf_sprintf("await outside function definition")); @@ -7915,7 +7936,7 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *scope, AstNode *n return irb->codegen->invalid_instruction; } - IrInstruction *target_inst = ir_gen_node_extra(irb, node->data.await_expr.expr, scope, LValPtr, nullptr); + IrInstruction *target_inst = ir_gen_node_extra(irb, expr_node, scope, LValPtr, nullptr); if (target_inst == irb->codegen->invalid_instruction) return irb->codegen->invalid_instruction; @@ -15244,44 +15265,61 @@ static IrInstruction *ir_analyze_instruction_reset_result(IrAnalyze *ira, IrInst return ir_const_void(ira, &instruction->base); } +static IrInstruction *get_async_call_result_loc(IrAnalyze *ira, IrInstructionCallSrc *call_instruction, + ZigType *fn_ret_type) +{ + ir_assert(call_instruction->is_async_call_builtin, &call_instruction->base); + IrInstruction *ret_ptr_uncasted = call_instruction->args[call_instruction->arg_count]->child; + if (type_is_invalid(ret_ptr_uncasted->value.type)) + return ira->codegen->invalid_instruction; + if (ret_ptr_uncasted->value.type->id == ZigTypeIdVoid) { + // Result location will be inside the async frame. + return nullptr; + } + return ir_implicit_cast(ira, ret_ptr_uncasted, get_pointer_to_type(ira->codegen, fn_ret_type, false)); +} + static IrInstruction *ir_analyze_async_call(IrAnalyze *ira, IrInstructionCallSrc *call_instruction, ZigFn *fn_entry, ZigType *fn_type, IrInstruction *fn_ref, IrInstruction **casted_args, size_t arg_count, IrInstruction *casted_new_stack) { - if (casted_new_stack != nullptr) { - // this is an @asyncCall - + if (fn_entry == nullptr) { if (fn_type->data.fn.fn_type_id.cc != CallingConventionAsync) { ir_add_error(ira, fn_ref, buf_sprintf("expected async function, found '%s'", buf_ptr(&fn_type->name))); return ira->codegen->invalid_instruction; } - - IrInstruction *ret_ptr = call_instruction->args[call_instruction->arg_count]->child; - if (type_is_invalid(ret_ptr->value.type)) + if (casted_new_stack == nullptr) { + ir_add_error(ira, fn_ref, buf_sprintf("function is not comptime-known; @asyncCall required")); + return ira->codegen->invalid_instruction; + } + } + if (casted_new_stack != nullptr) { + ZigType *fn_ret_type = fn_type->data.fn.fn_type_id.return_type; + IrInstruction *ret_ptr = get_async_call_result_loc(ira, call_instruction, fn_ret_type); + if (ret_ptr != nullptr && type_is_invalid(ret_ptr->value.type)) return ira->codegen->invalid_instruction; - ZigType *anyframe_type = get_any_frame_type(ira->codegen, fn_type->data.fn.fn_type_id.return_type); + ZigType *anyframe_type = get_any_frame_type(ira->codegen, fn_ret_type); - IrInstructionCallGen *call_gen = ir_build_call_gen(ira, &call_instruction->base, nullptr, fn_ref, - arg_count, casted_args, FnInlineAuto, true, casted_new_stack, ret_ptr, anyframe_type); + IrInstructionCallGen *call_gen = ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, + arg_count, casted_args, FnInlineAuto, true, casted_new_stack, + call_instruction->is_async_call_builtin, ret_ptr, anyframe_type); return &call_gen->base; - } else if (fn_entry == nullptr) { - ir_add_error(ira, fn_ref, buf_sprintf("function is not comptime-known; @asyncCall required")); - return ira->codegen->invalid_instruction; - } - - ZigType *frame_type = get_fn_frame_type(ira->codegen, fn_entry); - IrInstruction *result_loc = ir_resolve_result(ira, &call_instruction->base, call_instruction->result_loc, - frame_type, nullptr, true, true, false); - if (type_is_invalid(result_loc->value.type) || instr_is_unreachable(result_loc)) { - return result_loc; + } else { + ZigType *frame_type = get_fn_frame_type(ira->codegen, fn_entry); + IrInstruction *result_loc = ir_resolve_result(ira, &call_instruction->base, call_instruction->result_loc, + frame_type, nullptr, true, true, false); + if (type_is_invalid(result_loc->value.type) || instr_is_unreachable(result_loc)) { + return result_loc; + } + result_loc = ir_implicit_cast(ira, result_loc, get_pointer_to_type(ira->codegen, frame_type, false)); + if (type_is_invalid(result_loc->value.type)) + return ira->codegen->invalid_instruction; + return &ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, arg_count, + casted_args, FnInlineAuto, true, casted_new_stack, call_instruction->is_async_call_builtin, + result_loc, frame_type)->base; } - result_loc = ir_implicit_cast(ira, result_loc, get_pointer_to_type(ira->codegen, frame_type, false)); - if (type_is_invalid(result_loc->value.type)) - return ira->codegen->invalid_instruction; - return &ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, arg_count, - casted_args, FnInlineAuto, true, nullptr, result_loc, frame_type)->base; } static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node, IrInstruction *arg, Scope **exec_scope, size_t *next_proto_i) @@ -15790,16 +15828,27 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c IrInstruction *casted_new_stack = nullptr; if (call_instruction->new_stack != nullptr) { - ZigType *u8_ptr = get_pointer_to_type_extra(ira->codegen, ira->codegen->builtin_types.entry_u8, - false, false, PtrLenUnknown, target_fn_align(ira->codegen->zig_target), 0, 0, false); - ZigType *u8_slice = get_slice_type(ira->codegen, u8_ptr); IrInstruction *new_stack = call_instruction->new_stack->child; if (type_is_invalid(new_stack->value.type)) return ira->codegen->invalid_instruction; - casted_new_stack = ir_implicit_cast(ira, new_stack, u8_slice); - if (type_is_invalid(casted_new_stack->value.type)) - return ira->codegen->invalid_instruction; + if (call_instruction->is_async_call_builtin && + fn_entry != nullptr && new_stack->value.type->id == ZigTypeIdPointer && + new_stack->value.type->data.pointer.child_type->id == ZigTypeIdFnFrame) + { + ZigType *needed_frame_type = get_pointer_to_type(ira->codegen, + get_fn_frame_type(ira->codegen, fn_entry), false); + casted_new_stack = ir_implicit_cast(ira, new_stack, needed_frame_type); + if (type_is_invalid(casted_new_stack->value.type)) + return ira->codegen->invalid_instruction; + } else { + ZigType *u8_ptr = get_pointer_to_type_extra(ira->codegen, ira->codegen->builtin_types.entry_u8, + false, false, PtrLenUnknown, target_fn_align(ira->codegen->zig_target), 0, 0, false); + ZigType *u8_slice = get_slice_type(ira->codegen, u8_ptr); + casted_new_stack = ir_implicit_cast(ira, new_stack, u8_slice); + if (type_is_invalid(casted_new_stack->value.type)) + return ira->codegen->invalid_instruction; + } } if (fn_type->data.fn.is_generic) { @@ -16010,7 +16059,11 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c FnTypeId *impl_fn_type_id = &impl_fn->type_entry->data.fn.fn_type_id; IrInstruction *result_loc; - if (handle_is_ptr(impl_fn_type_id->return_type)) { + if (call_instruction->is_async_call_builtin) { + result_loc = get_async_call_result_loc(ira, call_instruction, impl_fn_type_id->return_type); + if (result_loc != nullptr && type_is_invalid(result_loc->value.type)) + return ira->codegen->invalid_instruction; + } else if (handle_is_ptr(impl_fn_type_id->return_type)) { result_loc = ir_resolve_result(ira, &call_instruction->base, call_instruction->result_loc, impl_fn_type_id->return_type, nullptr, true, true, false); if (result_loc != nullptr) { @@ -16044,7 +16097,7 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c IrInstructionCallGen *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base, impl_fn, nullptr, impl_param_count, casted_args, fn_inline, - false, casted_new_stack, result_loc, + false, casted_new_stack, call_instruction->is_async_call_builtin, result_loc, impl_fn_type_id->return_type); parent_fn_entry->call_list.append(new_call_instruction); @@ -16167,7 +16220,11 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c } IrInstruction *result_loc; - if (handle_is_ptr(return_type)) { + if (call_instruction->is_async_call_builtin) { + result_loc = get_async_call_result_loc(ira, call_instruction, return_type); + if (result_loc != nullptr && type_is_invalid(result_loc->value.type)) + return ira->codegen->invalid_instruction; + } else if (handle_is_ptr(return_type)) { result_loc = ir_resolve_result(ira, &call_instruction->base, call_instruction->result_loc, return_type, nullptr, true, true, false); if (result_loc != nullptr) { @@ -16185,7 +16242,7 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c IrInstructionCallGen *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, call_param_count, casted_args, fn_inline, false, casted_new_stack, - result_loc, return_type); + call_instruction->is_async_call_builtin, result_loc, return_type); parent_fn_entry->call_list.append(new_call_instruction); return ir_finish_anal(ira, &new_call_instruction->base); } diff --git a/test/compile_errors.zig b/test/compile_errors.zig index a9e99f4799..12f17ec790 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -2,6 +2,22 @@ const tests = @import("tests.zig"); const builtin = @import("builtin"); pub fn addCases(cases: *tests.CompileErrorContext) void { + cases.add( + "wrong type for result ptr to @asyncCall", + \\export fn entry() void { + \\ _ = async amain(); + \\} + \\fn amain() i32 { + \\ var frame: @Frame(foo) = undefined; + \\ return await @asyncCall(&frame, false, foo); + \\} + \\fn foo() i32 { + \\ return 1234; + \\} + , + "tmp.zig:6:37: error: expected type '*i32', found 'bool'", + ); + cases.add( "struct depends on itself via optional field", \\const LhsExpr = struct { diff --git a/test/stage1/behavior/async_fn.zig b/test/stage1/behavior/async_fn.zig index 76a2780737..28a9ade1b3 100644 --- a/test/stage1/behavior/async_fn.zig +++ b/test/stage1/behavior/async_fn.zig @@ -331,8 +331,9 @@ test "async fn with inferred error set" { fn doTheTest() void { var frame: [1]@Frame(middle) = undefined; - var result: anyerror!void = undefined; - _ = @asyncCall(@sliceToBytes(frame[0..]), &result, middle); + var fn_ptr = middle; + var result: @typeOf(fn_ptr).ReturnType.ErrorSet!void = undefined; + _ = @asyncCall(@sliceToBytes(frame[0..]), &result, fn_ptr); resume global_frame; std.testing.expectError(error.Fail, result); } @@ -819,6 +820,34 @@ test "struct parameter to async function is copied to the frame" { } test "cast fn to async fn when it is inferred to be async" { + const S = struct { + var frame: anyframe = undefined; + var ok = false; + + fn doTheTest() void { + var ptr: async fn () i32 = undefined; + ptr = func; + var buf: [100]u8 align(16) = undefined; + var result: i32 = undefined; + const f = @asyncCall(&buf, &result, ptr); + _ = await f; + expect(result == 1234); + ok = true; + } + + fn func() i32 { + suspend { + frame = @frame(); + } + return 1234; + } + }; + _ = async S.doTheTest(); + resume S.frame; + expect(S.ok); +} + +test "cast fn to async fn when it is inferred to be async, awaited directly" { const S = struct { var frame: anyframe = undefined; var ok = false; @@ -919,3 +948,75 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type { } }; } + +test "@asyncCall with comptime-known function, but not awaited directly" { + const S = struct { + var global_frame: anyframe = undefined; + + fn doTheTest() void { + var frame: [1]@Frame(middle) = undefined; + var result: @typeOf(middle).ReturnType.ErrorSet!void = undefined; + _ = @asyncCall(@sliceToBytes(frame[0..]), &result, middle); + resume global_frame; + std.testing.expectError(error.Fail, result); + } + + async fn middle() !void { + var f = async middle2(); + return await f; + } + + fn middle2() !void { + return failing(); + } + + fn failing() !void { + global_frame = @frame(); + suspend; + return error.Fail; + } + }; + S.doTheTest(); +} + +test "@asyncCall with actual frame instead of byte buffer" { + const S = struct { + fn func() i32 { + suspend; + return 1234; + } + }; + var frame: @Frame(S.func) = undefined; + var result: i32 = undefined; + const ptr = @asyncCall(&frame, &result, S.func); + resume ptr; + expect(result == 1234); +} + +test "@asyncCall using the result location inside the frame" { + const S = struct { + async fn simple2(y: *i32) i32 { + defer y.* += 2; + y.* += 1; + suspend; + return 1234; + } + fn getAnswer(f: anyframe->i32, out: *i32) void { + var res = await f; // TODO https://github.com/ziglang/zig/issues/3077 + out.* = res; + } + }; + var data: i32 = 1; + const Foo = struct { + bar: async fn (*i32) i32, + }; + var foo = Foo{ .bar = S.simple2 }; + var bytes: [64]u8 align(16) = undefined; + const f = @asyncCall(&bytes, {}, foo.bar, &data); + comptime expect(@typeOf(f) == anyframe->i32); + expect(data == 2); + resume f; + expect(data == 4); + _ = async S.getAnswer(f, &data); + expect(data == 1234); +} -- cgit v1.2.3