aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Kelley <andrew@ziglang.org>2019-09-07 00:12:15 -0400
committerAndrew Kelley <andrew@ziglang.org>2019-09-07 00:13:12 -0400
commitd1a98ccff481183d7fc53e45a902ef273c3d6aeb (patch)
treeb03efbb135bae39fcf6968b505ad67e4c6a33bda /src
parent9ca8d9e21ad657b023c23db5c440fb79a3303771 (diff)
downloadzig-d1a98ccff481183d7fc53e45a902ef273c3d6aeb.tar.gz
zig-d1a98ccff481183d7fc53e45a902ef273c3d6aeb.zip
implement spills when expressions used across suspend points
closes #3077
Diffstat (limited to 'src')
-rw-r--r--src/all_types.hpp23
-rw-r--r--src/analyze.cpp159
-rw-r--r--src/analyze.hpp2
-rw-r--r--src/codegen.cpp17
-rw-r--r--src/ir.cpp12
5 files changed, 197 insertions, 16 deletions
diff --git a/src/all_types.hpp b/src/all_types.hpp
index afe8bd0675..8ba3e4f484 100644
--- a/src/all_types.hpp
+++ b/src/all_types.hpp
@@ -2124,6 +2124,7 @@ enum ScopeId {
ScopeIdCompTime,
ScopeIdRuntime,
ScopeIdTypeOf,
+ ScopeIdExpr,
};
struct Scope {
@@ -2271,6 +2272,24 @@ struct ScopeTypeOf {
Scope base;
};
+enum MemoizedBool {
+ MemoizedBoolUnknown,
+ MemoizedBoolFalse,
+ MemoizedBoolTrue,
+};
+
+// This scope is created for each expression.
+// It's used to identify when an instruction needs to be spilled,
+// so that it can be accessed after a suspend point.
+struct ScopeExpr {
+ Scope base;
+
+ ScopeExpr **children_ptr;
+ size_t children_len;
+
+ MemoizedBool need_spill;
+};
+
// synchronized with code in define_builtin_compile_vars
enum AtomicOrder {
AtomicOrderUnordered,
@@ -2510,6 +2529,10 @@ struct IrInstruction {
// with this child field.
IrInstruction *child;
IrBasicBlock *owner_bb;
+ // Nearly any instruction can have to be stored as a local variable before suspending
+ // and then loaded after resuming, in case there is an expression with a suspend point
+ // in it, such as: x + await y
+ IrInstruction *spill;
IrInstructionId id;
// true if this instruction was generated by zig and not from user code
bool is_gen;
diff --git a/src/analyze.cpp b/src/analyze.cpp
index c7da620428..bbb5b7192b 100644
--- a/src/analyze.cpp
+++ b/src/analyze.cpp
@@ -96,6 +96,30 @@ static ScopeDecls **get_container_scope_ptr(ZigType *type_entry) {
zig_unreachable();
}
+static ScopeExpr *find_expr_scope(Scope *scope) {
+ for (;;) {
+ switch (scope->id) {
+ case ScopeIdExpr:
+ return reinterpret_cast<ScopeExpr *>(scope);
+ case ScopeIdDefer:
+ case ScopeIdDeferExpr:
+ case ScopeIdDecls:
+ case ScopeIdFnDef:
+ case ScopeIdCompTime:
+ case ScopeIdVarDecl:
+ case ScopeIdCImport:
+ case ScopeIdSuspend:
+ case ScopeIdTypeOf:
+ case ScopeIdBlock:
+ return nullptr;
+ case ScopeIdLoop:
+ case ScopeIdRuntime:
+ scope = scope->parent;
+ continue;
+ }
+ }
+}
+
ScopeDecls *get_container_scope(ZigType *type_entry) {
return *get_container_scope_ptr(type_entry);
}
@@ -203,6 +227,20 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) {
return &scope->base;
}
+Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
+ ScopeExpr *scope = allocate<ScopeExpr>(1);
+ init_scope(g, &scope->base, ScopeIdExpr, node, parent);
+ ScopeExpr *parent_expr = find_expr_scope(parent);
+ if (parent_expr != nullptr) {
+ size_t new_len = parent_expr->children_len + 1;
+ parent_expr->children_ptr = reallocate_nonzero<ScopeExpr *>(
+ parent_expr->children_ptr, parent_expr->children_len, new_len);
+ parent_expr->children_ptr[parent_expr->children_len] = scope;
+ parent_expr->children_len = new_len;
+ }
+ return &scope->base;
+}
+
ZigType *get_scope_import(Scope *scope) {
while (scope) {
if (scope->id == ScopeIdDecls) {
@@ -5654,6 +5692,69 @@ static ZigType *get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) {
return fn_type;
}
+// Traverse up to the very top ExprScope, which has children.
+// We have just arrived at the top from a child. That child,
+// and its next siblings, do not need to be marked. But the previous
+// siblings do.
+// x + (await y)
+// vs
+// (await y) + x
+static void mark_suspension_point(Scope *scope) {
+ ScopeExpr *child_expr_scope = (scope->id == ScopeIdExpr) ? reinterpret_cast<ScopeExpr *>(scope) : nullptr;
+ for (;;) {
+ scope = scope->parent;
+ switch (scope->id) {
+ case ScopeIdDefer:
+ case ScopeIdDeferExpr:
+ case ScopeIdDecls:
+ case ScopeIdFnDef:
+ case ScopeIdCompTime:
+ case ScopeIdVarDecl:
+ case ScopeIdCImport:
+ case ScopeIdSuspend:
+ case ScopeIdTypeOf:
+ case ScopeIdBlock:
+ return;
+ case ScopeIdLoop:
+ case ScopeIdRuntime:
+ continue;
+ case ScopeIdExpr: {
+ ScopeExpr *parent_expr_scope = reinterpret_cast<ScopeExpr *>(scope);
+ if (child_expr_scope != nullptr) {
+ for (size_t i = 0; parent_expr_scope->children_ptr[i] != child_expr_scope; i += 1) {
+ assert(i < parent_expr_scope->children_len);
+ parent_expr_scope->children_ptr[i]->need_spill = MemoizedBoolTrue;
+ }
+ }
+ parent_expr_scope->need_spill = MemoizedBoolTrue;
+ child_expr_scope = parent_expr_scope;
+ continue;
+ }
+ }
+ }
+}
+
+static bool scope_needs_spill(Scope *scope) {
+ ScopeExpr *scope_expr = find_expr_scope(scope);
+ if (scope_expr == nullptr) return false;
+
+ switch (scope_expr->need_spill) {
+ case MemoizedBoolUnknown:
+ if (scope_needs_spill(scope_expr->base.parent)) {
+ scope_expr->need_spill = MemoizedBoolTrue;
+ return true;
+ } else {
+ scope_expr->need_spill = MemoizedBoolFalse;
+ return false;
+ }
+ case MemoizedBoolFalse:
+ return false;
+ case MemoizedBoolTrue:
+ return true;
+ }
+ zig_unreachable();
+}
+
static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
Error err;
@@ -5786,21 +5887,17 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
callee_frame_type, "");
}
// Since this frame is async, an await might represent a suspend point, and
- // therefore need to spill.
+ // therefore need to spill. It also needs to mark expr scopes as having to spill.
+ // For example: foo() + await z
+ // The funtion call result of foo() must be spilled.
for (size_t i = 0; i < fn->await_list.length; i += 1) {
IrInstructionAwaitGen *await = fn->await_list.at(i);
- // TODO If this is a noasync await, it doesn't need to spill
+ // TODO If this is a noasync await, it doesn't suspend
// https://github.com/ziglang/zig/issues/3157
- if (await->result_loc != nullptr) {
- // If there's a result location, that is the spill
+ if (await->base.value.special != ConstValSpecialRuntime) {
+ // Known at comptime. No spill, no suspend.
continue;
}
- if (!type_has_bits(await->base.value.type))
- continue;
- if (await->base.value.special != ConstValSpecialRuntime)
- continue;
- if (await->base.ref_count == 0)
- continue;
if (await->target_fn != nullptr) {
// we might not need to suspend
analyze_fn_async(g, await->target_fn, false);
@@ -5809,13 +5906,53 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
return ErrorSemanticAnalyzeFail;
}
if (!fn_is_async(await->target_fn)) {
- // This await does not represent a suspend point. No spill needed.
+ // This await does not represent a suspend point. No spill needed,
+ // and no need to mark ExprScope.
continue;
}
}
+ // This await is a suspend point, but it might not need a spill.
+ // We do need to mark the ExprScope as having a suspend point in it.
+ mark_suspension_point(await->base.scope);
+
+ if (await->result_loc != nullptr) {
+ // If there's a result location, that is the spill
+ continue;
+ }
+ if (await->base.ref_count == 0)
+ continue;
+ if (!type_has_bits(await->base.value.type))
+ continue;
await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
await->base.value.type, "");
}
+ // Now that we've marked all the expr scopes that have to spill, we go over the instructions
+ // and spill the relevant ones.
+ for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
+ IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i);
+ for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) {
+ IrInstruction *instruction = block->instruction_list.at(instr_i);
+ if (instruction->id == IrInstructionIdAwaitGen ||
+ instruction->id == IrInstructionIdVarPtr ||
+ instruction->id == IrInstructionIdDeclRef ||
+ instruction->id == IrInstructionIdAllocaGen)
+ {
+ // This instruction does its own spilling specially, or otherwise doesn't need it.
+ continue;
+ }
+ if (instruction->value.special != ConstValSpecialRuntime)
+ continue;
+ if (instruction->ref_count == 0)
+ continue;
+ if (!type_has_bits(instruction->value.type))
+ continue;
+ if (scope_needs_spill(instruction->scope)) {
+ instruction->spill = ir_create_alloca(g, instruction->scope, instruction->source_node,
+ fn, instruction->value.type, "");
+ }
+ }
+ }
+
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
diff --git a/src/analyze.hpp b/src/analyze.hpp
index 2178327571..55bf9aba30 100644
--- a/src/analyze.hpp
+++ b/src/analyze.hpp
@@ -114,6 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn *
Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent);
Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime);
Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent);
+Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
@@ -261,5 +262,4 @@ void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
ZigType *var_type, const char *name_hint);
-
#endif
diff --git a/src/codegen.cpp b/src/codegen.cpp
index bbb1d9fc87..134569374e 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -649,6 +649,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
return get_di_scope(g, scope->parent);
}
zig_unreachable();
@@ -1644,7 +1645,6 @@ static void gen_assign_raw(CodeGen *g, LLVMValueRef ptr, ZigType *ptr_type,
LLVMValueRef ored_value = LLVMBuildOr(g->builder, shifted_value, anded_containing_int, "");
gen_store(g, ored_value, ptr, ptr_type);
- return;
}
static void gen_var_debug_decl(CodeGen *g, ZigVar *var) {
@@ -1664,11 +1664,16 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
if (instruction->id == IrInstructionIdAwaitGen) {
IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
if (await->result_loc != nullptr) {
- instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
+ return get_handle_value(g, ir_llvm_value(g, await->result_loc),
await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
- return instruction->llvm_value;
}
}
+ if (instruction->spill != nullptr) {
+ ZigType *ptr_type = instruction->spill->value.type;
+ src_assert(ptr_type->id == ZigTypeIdPointer, instruction->source_node);
+ return get_handle_value(g, ir_llvm_value(g, instruction->spill),
+ ptr_type->data.pointer.child_type, instruction->spill->value.type);
+ }
src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
assert(instruction->value.type);
render_const_val(g, &instruction->value, "");
@@ -3786,6 +3791,7 @@ static void render_async_var_decls(CodeGen *g, Scope *scope) {
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
scope = scope->parent;
continue;
}
@@ -6049,6 +6055,11 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) {
set_debug_location(g, instruction);
}
instruction->llvm_value = ir_render_instruction(g, executable, instruction);
+ if (instruction->spill != nullptr) {
+ LLVMValueRef spill_ptr = ir_llvm_value(g, instruction->spill);
+ gen_assign_raw(g, spill_ptr, instruction->spill->value.type, instruction->llvm_value);
+ instruction->llvm_value = nullptr;
+ }
}
current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder);
}
diff --git a/src/ir.cpp b/src/ir.cpp
index 53ce2d89e1..1a0aad36e9 100644
--- a/src/ir.cpp
+++ b/src/ir.cpp
@@ -3364,6 +3364,7 @@ static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_sco
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
scope = scope->parent;
continue;
case ScopeIdDeferExpr:
@@ -3420,6 +3421,7 @@ static bool ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *o
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
scope = scope->parent;
continue;
case ScopeIdDeferExpr:
@@ -8158,7 +8160,15 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc
result_loc = no_result_loc();
ir_build_reset_result(irb, scope, node, result_loc);
}
- IrInstruction *result = ir_gen_node_raw(irb, node, scope, lval, result_loc);
+ Scope *child_scope;
+ if (irb->exec->is_inline ||
+ (irb->exec->fn_entry != nullptr && irb->exec->fn_entry->child_scope == scope))
+ {
+ child_scope = scope;
+ } else {
+ child_scope = create_expr_scope(irb->codegen, node, scope);
+ }
+ IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc);
if (result == irb->codegen->invalid_instruction) {
if (irb->exec->first_err_trace_msg == nullptr) {
irb->exec->first_err_trace_msg = irb->codegen->trace_err;