diff options
| author | LemonBoy <thatlemon@gmail.com> | 2020-04-05 10:40:41 +0200 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2020-04-05 18:34:31 -0400 |
| commit | f6cdc94a50235eaf145f6c2c2ec257008d592494 (patch) | |
| tree | 5225b4d0fe332afabb6bedd3f8e24aa900ad8e73 /src/codegen.cpp | |
| parent | 0f964e19109d44d039fbefa691657ca82c7bbe52 (diff) | |
| download | zig-f6cdc94a50235eaf145f6c2c2ec257008d592494.tar.gz zig-f6cdc94a50235eaf145f6c2c2ec257008d592494.zip | |
ir: Fix error checking for vector ops
The extra logic that's needed was lost during a refactoring, now it
should be fine.
Diffstat (limited to 'src/codegen.cpp')
| -rw-r--r-- | src/codegen.cpp | 59 |
1 files changed, 39 insertions, 20 deletions
diff --git a/src/codegen.cpp b/src/codegen.cpp index e7fff882c3..a2cd5fafc0 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -155,7 +155,6 @@ static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstGen *source_instr, LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type, LLVMValueRef result_loc, bool non_async); static Error get_tmp_filename(CodeGen *g, Buf *out, Buf *suffix); -static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val); static void addLLVMAttr(LLVMValueRef val, LLVMAttributeIndex attr_index, const char *attr_name) { unsigned kind_id = LLVMGetEnumAttributeKindForName(attr_name, strlen(attr_name)); @@ -2536,6 +2535,36 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir return nullptr; } +enum class ScalarizePredicate { + // Returns true iff all the elements in the vector are 1. + // Equivalent to folding all the bits with `and`. + All, + // Returns true iff there's at least one element in the vector that is 1. + // Equivalent to folding all the bits with `or`. + Any, +}; + +// Collapses a <N x i1> vector into a single i1 according to the given predicate +static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) { + assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind); + LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val))); + LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, ""); + + switch (predicate) { + case ScalarizePredicate::Any: { + LLVMValueRef all_zeros = LLVMConstNull(scalar_type); + return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, ""); + } + case ScalarizePredicate::All: { + LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type); + return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, ""); + } + } + + zig_unreachable(); +} + + static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type, LLVMValueRef val1, LLVMValueRef val2) { @@ -2560,7 +2589,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type, LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); if (operand_type->id == ZigTypeIdVector) { - ok_bit = scalarize_cmp_result(g, ok_bit); + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); @@ -2591,7 +2620,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type, LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); if (operand_type->id == ZigTypeIdVector) { - ok_bit = scalarize_cmp_result(g, ok_bit); + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); @@ -2647,16 +2676,6 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) { } } -// Collapses a <N x i1> vector into a single i1 whose value is 1 iff all the -// vector elements are 1 -static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val) { - assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind); - LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val))); - LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type); - LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, ""); - return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, ""); -} - static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast_math, LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, DivKind div_kind) { @@ -2678,7 +2697,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast } if (operand_type->id == ZigTypeIdVector) { - is_zero_bit = scalarize_cmp_result(g, is_zero_bit); + is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any); } LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail"); @@ -2703,7 +2722,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, ""); LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, ""); if (operand_type->id == ZigTypeIdVector) { - overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit); + overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any); } LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block); @@ -2728,7 +2747,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, ""); if (operand_type->id == ZigTypeIdVector) { - ok_bit = scalarize_cmp_result(g, ok_bit); + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); @@ -2745,7 +2764,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd"); LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, ""); if (operand_type->id == ZigTypeIdVector) { - ltz = scalarize_cmp_result(g, ltz); + ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any); } LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block); @@ -2797,7 +2816,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, ""); if (operand_type->id == ZigTypeIdVector) { - ok_bit = scalarize_cmp_result(g, ok_bit); + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); @@ -2861,7 +2880,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast } if (operand_type->id == ZigTypeIdVector) { - is_zero_bit = scalarize_cmp_result(g, is_zero_bit); + is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any); } LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk"); @@ -2918,7 +2937,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk"); LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, ""); if (rhs_type->id == ZigTypeIdVector) { - less_than_bit = scalarize_cmp_result(g, less_than_bit); + less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any); } LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block); |
