diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/all_types.hpp | 3 | ||||
| -rw-r--r-- | src/analyze.cpp | 14 | ||||
| -rw-r--r-- | src/codegen.cpp | 130 |
3 files changed, 115 insertions, 32 deletions
diff --git a/src/all_types.hpp b/src/all_types.hpp index 3bd51002c5..cc944bc35a 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1051,6 +1051,7 @@ struct FnTableEntry { bool is_extern; bool is_test; bool is_pure; + bool safety_off; BlockContext *parent_block_context; FnAnalState anal_state; @@ -1315,6 +1316,8 @@ struct BlockContext { // if this is true, then this code will not be generated bool codegen_excluded; + + bool safety_off; }; diff --git a/src/analyze.cpp b/src/analyze.cpp index 737fff0a7b..5be9e10392 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -993,6 +993,18 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t add_node_error(g, directive_node, buf_sprintf("invalid function attribute: '%s'", buf_ptr(name))); } + } else if (buf_eql_str(name, "debug_safety")) { + if (fn_table_entry->is_extern) { + add_node_error(g, directive_node, + buf_sprintf("#debug_safety invalid on extern functions")); + } else { + bool enable; + bool ok = resolve_const_expr_bool(g, import, import->block_context, + &directive_node->data.directive.expr, &enable); + if (ok && !enable) { + fn_table_entry->safety_off = true; + } + } } else if (buf_eql_str(name, "condition")) { if (fn_proto->top_level_decl.visib_mod == VisibModExport) { bool include; @@ -2102,11 +2114,13 @@ BlockContext *new_block_context(AstNode *node, BlockContext *parent) { context->parent_loop_node = parent->parent_loop_node; context->c_import_buf = parent->c_import_buf; context->codegen_excluded = parent->codegen_excluded; + context->safety_off = parent->safety_off; } if (node && node->type == NodeTypeFnDef) { AstNode *fn_proto_node = node->data.fn_def.fn_proto; context->fn_entry = fn_proto_node->data.fn_proto.fn_table_entry; + context->safety_off = context->fn_entry->safety_off; } else if (parent) { context->fn_entry = parent->fn_entry; } diff --git a/src/codegen.cpp b/src/codegen.cpp index aa3b4d1529..dda78a0653 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -330,6 +330,46 @@ static LLVMValueRef get_handle_value(CodeGen *g, AstNode *source_node, LLVMValue } } +static bool want_debug_safety(CodeGen *g, AstNode *node) { + return !g->is_release_build && !node->block_context->safety_off; +} + +static void add_bounds_check(CodeGen *g, AstNode *source_node, LLVMValueRef target_val, + LLVMIntPredicate lower_pred, LLVMValueRef lower_value, + LLVMIntPredicate upper_pred, LLVMValueRef upper_value) +{ + if (!lower_value && !upper_value) { + return; + } + if (upper_value && !lower_value) { + lower_value = upper_value; + lower_pred = upper_pred; + upper_value = nullptr; + } + + add_debug_source_node(g, source_node); + + LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckFail"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckOk"); + LLVMBasicBlockRef lower_ok_block = upper_value ? + LLVMAppendBasicBlock(g->cur_fn->fn_value, "FirstBoundsCheckOk") : ok_block; + + LLVMValueRef lower_ok_val = LLVMBuildICmp(g->builder, lower_pred, target_val, lower_value, ""); + LLVMBuildCondBr(g->builder, lower_ok_val, lower_ok_block, bounds_check_fail_block); + + LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block); + LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); + LLVMBuildUnreachable(g->builder); + + if (upper_value) { + LLVMPositionBuilderAtEnd(g->builder, lower_ok_block); + LLVMValueRef upper_ok_val = LLVMBuildICmp(g->builder, upper_pred, target_val, upper_value, ""); + LLVMBuildCondBr(g->builder, upper_ok_val, ok_block, bounds_check_fail_block); + } + + LLVMPositionBuilderAtEnd(g->builder, ok_block); +} + static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFnCallExpr); assert(g->generate_error_name_table); @@ -344,25 +384,10 @@ static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) { LLVMValueRef err_val = gen_expr(g, err_val_node); add_debug_source_node(g, node); - if (!g->is_release_build) { - LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckFail"); - LLVMBasicBlockRef lower_ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "LowerBoundsCheckOk"); - LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckOk"); - + if (want_debug_safety(g, node)) { LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(err_val)); - LLVMValueRef is_zero_val = LLVMBuildICmp(g->builder, LLVMIntEQ, err_val, zero, ""); - LLVMBuildCondBr(g->builder, is_zero_val, bounds_check_fail_block, lower_ok_block); - - LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block); - LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); - LLVMBuildUnreachable(g->builder); - - LLVMPositionBuilderAtEnd(g->builder, lower_ok_block); LLVMValueRef end_val = LLVMConstInt(LLVMTypeOf(err_val), g->error_decls.length, false); - LLVMValueRef is_too_big_val = LLVMBuildICmp(g->builder, LLVMIntUGE, err_val, end_val, ""); - LLVMBuildCondBr(g->builder, is_too_big_val, bounds_check_fail_block, ok_block); - - LLVMPositionBuilderAtEnd(g->builder, ok_block); + add_bounds_check(g, node, err_val, LLVMIntNE, zero, LLVMIntULT, end_val); } LLVMValueRef indices[] = { @@ -869,6 +894,11 @@ static LLVMValueRef gen_array_elem_ptr(CodeGen *g, AstNode *source_node, LLVMVal } if (array_type->id == TypeTableEntryIdArray) { + if (want_debug_safety(g, source_node)) { + LLVMValueRef end = LLVMConstInt(g->builtin_types.entry_isize->type_ref, + array_type->data.array.len, false); + add_bounds_check(g, source_node, subscript_value, LLVMIntEQ, nullptr, LLVMIntULT, end); + } LLVMValueRef indices[] = { LLVMConstNull(g->builtin_types.entry_isize->type_ref), subscript_value @@ -887,6 +917,15 @@ static LLVMValueRef gen_array_elem_ptr(CodeGen *g, AstNode *source_node, LLVMVal assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind); assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind); + if (want_debug_safety(g, source_node)) { + add_debug_source_node(g, source_node); + int len_index = array_type->data.structure.fields[1].gen_index; + assert(len_index >= 0); + LLVMValueRef len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, ""); + LLVMValueRef len = LLVMBuildLoad(g->builder, len_ptr, ""); + add_bounds_check(g, source_node, subscript_value, LLVMIntEQ, nullptr, LLVMIntULT, len); + } + add_debug_source_node(g, source_node); int ptr_index = array_type->data.structure.fields[0].gen_index; assert(ptr_index >= 0); @@ -907,7 +946,6 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) { LLVMValueRef array_ptr = gen_array_base_ptr(g, array_expr_node); LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript); - return gen_array_elem_ptr(g, node, array_ptr, array_type, subscript_value); } @@ -969,6 +1007,15 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { end_val = LLVMConstInt(g->builtin_types.entry_isize->type_ref, array_type->data.array.len, false); } + if (want_debug_safety(g, node)) { + add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + if (node->data.slice_expr.end) { + LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_isize->type_ref, + array_type->data.array.len, false); + add_bounds_check(g, node, end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end); + } + } + add_debug_source_node(g, node); LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, ""); LLVMValueRef indices[] = { @@ -987,6 +1034,10 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start); LLVMValueRef end_val = gen_expr(g, node->data.slice_expr.end); + if (want_debug_safety(g, node)) { + add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + } + add_debug_source_node(g, node); LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, ""); LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &start_val, 1, ""); @@ -1002,22 +1053,33 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) { assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind); assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind); + int ptr_index = array_type->data.structure.fields[0].gen_index; + assert(ptr_index >= 0); + int len_index = array_type->data.structure.fields[1].gen_index; + assert(len_index >= 0); + + LLVMValueRef prev_end = nullptr; + if (!node->data.slice_expr.end || want_debug_safety(g, node)) { + add_debug_source_node(g, node); + LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, ""); + prev_end = LLVMBuildLoad(g->builder, src_len_ptr, ""); + } + LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start); LLVMValueRef end_val; if (node->data.slice_expr.end) { end_val = gen_expr(g, node->data.slice_expr.end); } else { - add_debug_source_node(g, node); - int len_index = array_type->data.structure.fields[1].gen_index; - assert(len_index >= 0); - LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, ""); - end_val = LLVMBuildLoad(g->builder, src_len_ptr, ""); + end_val = prev_end; } - int ptr_index = array_type->data.structure.fields[0].gen_index; - assert(ptr_index >= 0); - int len_index = array_type->data.structure.fields[1].gen_index; - assert(len_index >= 0); + if (want_debug_safety(g, node)) { + assert(prev_end); + add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val); + if (node->data.slice_expr.end) { + add_bounds_check(g, node, end_val, LLVMIntEQ, nullptr, LLVMIntULE, prev_end); + } + } add_debug_source_node(g, node); LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, ptr_index, ""); @@ -1225,7 +1287,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { assert(expr_type->id == TypeTableEntryIdErrorUnion); TypeTableEntry *child_type = expr_type->data.error.child_type; - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { LLVMValueRef err_val; if (type_has_bits(child_type)) { add_debug_source_node(g, node); @@ -1263,7 +1325,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) { assert(expr_type->id == TypeTableEntryIdMaybe); TypeTableEntry *child_type = expr_type->data.maybe.child_type; - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { add_debug_source_node(g, node); LLVMValueRef cond_val; if (child_type->id == TypeTableEntryIdPointer || @@ -2261,7 +2323,7 @@ static LLVMValueRef gen_container_init_expr(CodeGen *g, AstNode *node) { } else if (type_entry->id == TypeTableEntryIdUnreachable) { assert(node->data.container_init_expr.entries.length == 0); add_debug_source_node(g, node); - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); } LLVMBuildUnreachable(g->builder); @@ -2575,7 +2637,7 @@ static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVa } } } - if (!ignore_uninit && !g->is_release_build) { + if (!ignore_uninit && want_debug_safety(g, source_node)) { TypeTableEntry *isize = g->builtin_types.entry_isize; uint64_t size_bytes = LLVMStoreSizeOfType(g->target_data_ref, variable->type->type_ref); uint64_t align_bytes = get_memcpy_align(g, variable->type); @@ -2790,7 +2852,7 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) { if (!else_prong) { LLVMPositionBuilderAtEnd(g->builder, else_block); add_debug_source_node(g, node); - if (!g->is_release_build) { + if (want_debug_safety(g, node)) { LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, ""); } LLVMBuildUnreachable(g->builder); @@ -3383,6 +3445,10 @@ static void do_code_gen(CodeGen *g) { // Generate the list of test function pointers. if (g->is_test_build) { + if (g->test_fn_count == 0) { + fprintf(stderr, "No tests to run.\n"); + exit(0); + } assert(g->test_fn_count > 0); assert(next_test_index == g->test_fn_count); |
