From d1a98ccff481183d7fc53e45a902ef273c3d6aeb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 7 Sep 2019 00:12:15 -0400 Subject: implement spills when expressions used across suspend points closes #3077 --- src/all_types.hpp | 23 ++++++++ src/analyze.cpp | 159 ++++++++++++++++++++++++++++++++++++++++++++++++++---- src/analyze.hpp | 2 +- src/codegen.cpp | 17 ++++-- src/ir.cpp | 12 ++++- 5 files changed, 197 insertions(+), 16 deletions(-) (limited to 'src') diff --git a/src/all_types.hpp b/src/all_types.hpp index afe8bd0675..8ba3e4f484 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2124,6 +2124,7 @@ enum ScopeId { ScopeIdCompTime, ScopeIdRuntime, ScopeIdTypeOf, + ScopeIdExpr, }; struct Scope { @@ -2271,6 +2272,24 @@ struct ScopeTypeOf { Scope base; }; +enum MemoizedBool { + MemoizedBoolUnknown, + MemoizedBoolFalse, + MemoizedBoolTrue, +}; + +// This scope is created for each expression. +// It's used to identify when an instruction needs to be spilled, +// so that it can be accessed after a suspend point. +struct ScopeExpr { + Scope base; + + ScopeExpr **children_ptr; + size_t children_len; + + MemoizedBool need_spill; +}; + // synchronized with code in define_builtin_compile_vars enum AtomicOrder { AtomicOrderUnordered, @@ -2510,6 +2529,10 @@ struct IrInstruction { // with this child field. 
IrInstruction *child; IrBasicBlock *owner_bb; + // Nearly any instruction may need to be stored as a local variable before suspending + // and then loaded after resuming, in case there is an expression with a suspend point + // in it, such as: x + await y + IrInstruction *spill; IrInstructionId id; // true if this instruction was generated by zig and not from user code bool is_gen; diff --git a/src/analyze.cpp b/src/analyze.cpp index c7da620428..bbb5b7192b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -96,6 +96,30 @@ static ScopeDecls **get_container_scope_ptr(ZigType *type_entry) { zig_unreachable(); } +static ScopeExpr *find_expr_scope(Scope *scope) { + for (;;) { + switch (scope->id) { + case ScopeIdExpr: + return reinterpret_cast<ScopeExpr *>(scope); + case ScopeIdDefer: + case ScopeIdDeferExpr: + case ScopeIdDecls: + case ScopeIdFnDef: + case ScopeIdCompTime: + case ScopeIdVarDecl: + case ScopeIdCImport: + case ScopeIdSuspend: + case ScopeIdTypeOf: + case ScopeIdBlock: + return nullptr; + case ScopeIdLoop: + case ScopeIdRuntime: + scope = scope->parent; + continue; + } + } +} + ScopeDecls *get_container_scope(ZigType *type_entry) { return *get_container_scope_ptr(type_entry); } @@ -203,6 +227,20 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) { return &scope->base; } +Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) { + ScopeExpr *scope = allocate<ScopeExpr>(1); + init_scope(g, &scope->base, ScopeIdExpr, node, parent); + ScopeExpr *parent_expr = find_expr_scope(parent); + if (parent_expr != nullptr) { + size_t new_len = parent_expr->children_len + 1; + parent_expr->children_ptr = reallocate_nonzero( + parent_expr->children_ptr, parent_expr->children_len, new_len); + parent_expr->children_ptr[parent_expr->children_len] = scope; + parent_expr->children_len = new_len; + } + return &scope->base; +} + ZigType *get_scope_import(Scope *scope) { while (scope) { if (scope->id == ScopeIdDecls) { @@ -5654,6 +5692,69 @@ static ZigType 
*get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) { return fn_type; } +// Traverse up to the very top ExprScope, which has children. +// We have just arrived at the top from a child. That child, +// and its next siblings, do not need to be marked. But the previous +// siblings do. +// x + (await y) +// vs +// (await y) + x +static void mark_suspension_point(Scope *scope) { + ScopeExpr *child_expr_scope = (scope->id == ScopeIdExpr) ? reinterpret_cast<ScopeExpr *>(scope) : nullptr; + for (;;) { + scope = scope->parent; + switch (scope->id) { + case ScopeIdDefer: + case ScopeIdDeferExpr: + case ScopeIdDecls: + case ScopeIdFnDef: + case ScopeIdCompTime: + case ScopeIdVarDecl: + case ScopeIdCImport: + case ScopeIdSuspend: + case ScopeIdTypeOf: + case ScopeIdBlock: + return; + case ScopeIdLoop: + case ScopeIdRuntime: + continue; + case ScopeIdExpr: { + ScopeExpr *parent_expr_scope = reinterpret_cast<ScopeExpr *>(scope); + if (child_expr_scope != nullptr) { + for (size_t i = 0; parent_expr_scope->children_ptr[i] != child_expr_scope; i += 1) { + assert(i < parent_expr_scope->children_len); + parent_expr_scope->children_ptr[i]->need_spill = MemoizedBoolTrue; + } + } + parent_expr_scope->need_spill = MemoizedBoolTrue; + child_expr_scope = parent_expr_scope; + continue; + } + } + } +} + +static bool scope_needs_spill(Scope *scope) { + ScopeExpr *scope_expr = find_expr_scope(scope); + if (scope_expr == nullptr) return false; + + switch (scope_expr->need_spill) { + case MemoizedBoolUnknown: + if (scope_needs_spill(scope_expr->base.parent)) { + scope_expr->need_spill = MemoizedBoolTrue; + return true; + } else { + scope_expr->need_spill = MemoizedBoolFalse; + return false; + } + case MemoizedBoolFalse: + return false; + case MemoizedBoolTrue: + return true; + } + zig_unreachable(); +} + static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) { Error err; @@ -5786,21 +5887,17 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) { callee_frame_type, ""); } // Since this frame 
is async, an await might represent a suspend point, and - // therefore need to spill. + // therefore need to spill. It also needs to mark expr scopes as having to spill. + // For example: foo() + await z + // The function call result of foo() must be spilled. for (size_t i = 0; i < fn->await_list.length; i += 1) { IrInstructionAwaitGen *await = fn->await_list.at(i); - // TODO If this is a noasync await, it doesn't need to spill + // TODO If this is a noasync await, it doesn't suspend // https://github.com/ziglang/zig/issues/3157 - if (await->result_loc != nullptr) { - // If there's a result location, that is the spill + if (await->base.value.special != ConstValSpecialRuntime) { + // Known at comptime. No spill, no suspend. continue; } - if (!type_has_bits(await->base.value.type)) - continue; - if (await->base.value.special != ConstValSpecialRuntime) - continue; - if (await->base.ref_count == 0) - continue; if (await->target_fn != nullptr) { // we might not need to suspend analyze_fn_async(g, await->target_fn, false); @@ -5809,13 +5906,53 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) { return ErrorSemanticAnalyzeFail; } if (!fn_is_async(await->target_fn)) { - // This await does not represent a suspend point. No spill needed. + // This await does not represent a suspend point. No spill needed, + // and no need to mark ExprScope. continue; } } + // This await is a suspend point, but it might not need a spill. + // We do need to mark the ExprScope as having a suspend point in it. 
+ mark_suspension_point(await->base.scope); + + if (await->result_loc != nullptr) { + // If there's a result location, that is the spill + continue; + } + if (await->base.ref_count == 0) + continue; + if (!type_has_bits(await->base.value.type)) + continue; await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn, await->base.value.type, ""); } + // Now that we've marked all the expr scopes that have to spill, we go over the instructions + // and spill the relevant ones. + for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) { + IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i); + for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) { + IrInstruction *instruction = block->instruction_list.at(instr_i); + if (instruction->id == IrInstructionIdAwaitGen || + instruction->id == IrInstructionIdVarPtr || + instruction->id == IrInstructionIdDeclRef || + instruction->id == IrInstructionIdAllocaGen) + { + // This instruction does its own spilling specially, or otherwise doesn't need it. 
+ continue; + } + if (instruction->value.special != ConstValSpecialRuntime) + continue; + if (instruction->ref_count == 0) + continue; + if (!type_has_bits(instruction->value.type)) + continue; + if (scope_needs_spill(instruction->scope)) { + instruction->spill = ir_create_alloca(g, instruction->scope, instruction->source_node, + fn, instruction->value.type, ""); + } + } + } + FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false); diff --git a/src/analyze.hpp b/src/analyze.hpp index 2178327571..55bf9aba30 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -114,6 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn * Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent); Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime); Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent); +Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent); void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str); ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str); @@ -261,5 +262,4 @@ void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn); IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn, ZigType *var_type, const char *name_hint); - #endif diff --git a/src/codegen.cpp b/src/codegen.cpp index bbb1d9fc87..134569374e 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -649,6 +649,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) { case ScopeIdCompTime: case ScopeIdRuntime: case ScopeIdTypeOf: + case ScopeIdExpr: return get_di_scope(g, scope->parent); } zig_unreachable(); @@ -1644,7 +1645,6 @@ static void gen_assign_raw(CodeGen *g, LLVMValueRef ptr, ZigType *ptr_type, LLVMValueRef ored_value = LLVMBuildOr(g->builder, shifted_value, anded_containing_int, ""); gen_store(g, ored_value, ptr, 
ptr_type); - return; } static void gen_var_debug_decl(CodeGen *g, ZigVar *var) { @@ -1664,11 +1664,16 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) { if (instruction->id == IrInstructionIdAwaitGen) { IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen *>(instruction); if (await->result_loc != nullptr) { - instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc), + return get_handle_value(g, ir_llvm_value(g, await->result_loc), await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type); - return instruction->llvm_value; } } + if (instruction->spill != nullptr) { + ZigType *ptr_type = instruction->spill->value.type; + src_assert(ptr_type->id == ZigTypeIdPointer, instruction->source_node); + return get_handle_value(g, ir_llvm_value(g, instruction->spill), + ptr_type->data.pointer.child_type, instruction->spill->value.type); + } src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node); assert(instruction->value.type); render_const_val(g, &instruction->value, ""); @@ -3786,6 +3791,7 @@ static void render_async_var_decls(CodeGen *g, Scope *scope) { case ScopeIdCompTime: case ScopeIdRuntime: case ScopeIdTypeOf: + case ScopeIdExpr: scope = scope->parent; continue; } @@ -6049,6 +6055,11 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) { set_debug_location(g, instruction); } instruction->llvm_value = ir_render_instruction(g, executable, instruction); + if (instruction->spill != nullptr) { + LLVMValueRef spill_ptr = ir_llvm_value(g, instruction->spill); + gen_assign_raw(g, spill_ptr, instruction->spill->value.type, instruction->llvm_value); + instruction->llvm_value = nullptr; + } } current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder); } diff --git a/src/ir.cpp b/src/ir.cpp index 53ce2d89e1..1a0aad36e9 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -3364,6 +3364,7 @@ static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_sco case 
ScopeIdCompTime: case ScopeIdRuntime: case ScopeIdTypeOf: + case ScopeIdExpr: scope = scope->parent; continue; case ScopeIdDeferExpr: @@ -3420,6 +3421,7 @@ static bool ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *o case ScopeIdCompTime: case ScopeIdRuntime: case ScopeIdTypeOf: + case ScopeIdExpr: scope = scope->parent; continue; case ScopeIdDeferExpr: @@ -8158,7 +8160,15 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc result_loc = no_result_loc(); ir_build_reset_result(irb, scope, node, result_loc); } - IrInstruction *result = ir_gen_node_raw(irb, node, scope, lval, result_loc); + Scope *child_scope; + if (irb->exec->is_inline || + (irb->exec->fn_entry != nullptr && irb->exec->fn_entry->child_scope == scope)) + { + child_scope = scope; + } else { + child_scope = create_expr_scope(irb->codegen, node, scope); + } + IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc); if (result == irb->codegen->invalid_instruction) { if (irb->exec->first_err_trace_msg == nullptr) { irb->exec->first_err_trace_msg = irb->codegen->trace_err; -- cgit v1.2.3