diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2020-04-05 18:34:47 -0400 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2020-04-05 18:34:47 -0400 |
| commit | 05b587fcdee91a7c9f170da4a186a512b51b39a8 (patch) | |
| tree | 638c7444a2f73fc304b662bf50329ccc18d2c86c /src | |
| parent | e2dc63644ab3d8e5cdaec2d58dc57c587295081f (diff) | |
| parent | e84b9b70ff2814d6e50a851dc9f094b15399d2fe (diff) | |
| download | zig-05b587fcdee91a7c9f170da4a186a512b51b39a8.tar.gz zig-05b587fcdee91a7c9f170da4a186a512b51b39a8.zip | |
Merge branch 'LemonBoy-vec-div'
closes #4737
Diffstat (limited to 'src')
| -rw-r--r-- | src/codegen.cpp | 209 | ||||
| -rw-r--r-- | src/ir.cpp | 411 |
2 files changed, 404 insertions, 216 deletions
diff --git a/src/codegen.cpp b/src/codegen.cpp index 84168f509f..a2cd5fafc0 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2535,19 +2535,51 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir return nullptr; } -static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry, - LLVMValueRef val1, LLVMValueRef val2) +enum class ScalarizePredicate { + // Returns true iff all the elements in the vector are 1. + // Equivalent to folding all the bits with `and`. + All, + // Returns true iff there's at least one element in the vector that is 1. + // Equivalent to folding all the bits with `or`. + Any, +}; + +// Collapses a <N x i1> vector into a single i1 according to the given predicate +static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) { + assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind); + LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val))); + LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, ""); + + switch (predicate) { + case ScalarizePredicate::Any: { + LLVMValueRef all_zeros = LLVMConstNull(scalar_type); + return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, ""); + } + case ScalarizePredicate::All: { + LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type); + return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, ""); + } + } + + zig_unreachable(); +} + + +static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type, + LLVMValueRef val1, LLVMValueRef val2) { // for unsigned left shifting, we do the lossy shift, then logically shift // right the same number of bits // if the values don't match, we have an overflow // for signed left shifting we do the same except arithmetic shift right + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.elem_type : operand_type; - assert(type_entry->id == ZigTypeIdInt); + assert(scalar_type->id == ZigTypeIdInt); LLVMValueRef result = LLVMBuildShl(g->builder, val1, val2, ""); LLVMValueRef orig_val; - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { orig_val = LLVMBuildAShr(g->builder, result, val2, ""); } else { orig_val = LLVMBuildLShr(g->builder, result, val2, ""); @@ -2556,6 +2588,9 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry, LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); + if (operand_type->id == ZigTypeIdVector) { + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2565,13 +2600,16 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry, return result; } -static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, - LLVMValueRef val1, LLVMValueRef val2) +static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type, + LLVMValueRef val1, LLVMValueRef val2) { - assert(type_entry->id == ZigTypeIdInt); + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.elem_type : operand_type; + + assert(scalar_type->id == ZigTypeIdInt); LLVMValueRef result; - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { result = LLVMBuildAShr(g->builder, val1, val2, ""); } else { result = LLVMBuildLShr(g->builder, val1, val2, ""); @@ -2581,6 +2619,9 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail"); + if (operand_type->id == ZigTypeIdVector) { + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2591,12 +2632,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, } static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) { - if ((op == BuiltinFnIdCeil || - op == BuiltinFnIdFloor) && - type_entry->id == ZigTypeIdInt) - return val; - assert(type_entry->id == ZigTypeIdFloat); - + assert(type_entry->id == ZigTypeIdFloat || type_entry->id == ZigTypeIdVector); LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op); return LLVMBuildCall(g->builder, floor_fn, &val, 1, ""); } @@ -2612,6 +2648,21 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) { if (bigint->digit_count == 0) { return LLVMConstNull(type_ref); } + + if (LLVMGetTypeKind(type_ref) == LLVMVectorTypeKind) { + const unsigned vector_len = LLVMGetVectorSize(type_ref); + LLVMTypeRef elem_type = LLVMGetElementType(type_ref); + + LLVMValueRef *values = heap::c_allocator.allocate_nonzero<LLVMValueRef>(vector_len); + // Create a vector with all the elements having the same value + for (unsigned i = 0; i < vector_len; i++) { + values[i] = bigint_to_llvm_const(elem_type, bigint); + } + LLVMValueRef result = LLVMConstVector(values, vector_len); + heap::c_allocator.deallocate(values, vector_len); + return result; + } + LLVMValueRef unsigned_val; if (bigint->digit_count == 1) { unsigned_val = LLVMConstInt(type_ref, bigint_ptr(bigint)[0], false); @@ -2626,21 +2677,29 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) { } static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast_math, - LLVMValueRef val1, LLVMValueRef val2, - ZigType *type_entry, DivKind div_kind) + LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, DivKind div_kind) { + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.elem_type : operand_type; + ZigLLVMSetFastMath(g->builder, want_fast_math); - LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, type_entry)); - if (want_runtime_safety && (want_fast_math || type_entry->id != ZigTypeIdFloat)) { + LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, operand_type)); + if (want_runtime_safety && (want_fast_math || scalar_type->id != ZigTypeIdFloat)) { + // Safety check: divisor != 0 LLVMValueRef is_zero_bit; - if (type_entry->id == ZigTypeIdInt) { + if (scalar_type->id == ZigTypeIdInt) { is_zero_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); - } else if (type_entry->id == ZigTypeIdFloat) { + } else if (scalar_type->id == ZigTypeIdFloat) { is_zero_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, val2, zero, ""); } else { zig_unreachable(); } + + if (operand_type->id == ZigTypeIdVector) { + is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any); + } + LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail"); LLVMBasicBlockRef div_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk"); LLVMBuildCondBr(g->builder, is_zero_bit, div_zero_fail_block, div_zero_ok_block); @@ -2650,16 +2709,21 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMPositionBuilderAtEnd(g->builder, div_zero_ok_block); - if (type_entry->id == ZigTypeIdInt && type_entry->data.integral.is_signed) { - LLVMValueRef neg_1_value = LLVMConstInt(get_llvm_type(g, type_entry), -1, true); + // Safety check: check for overflow (dividend = minInt and divisor = -1) + if (scalar_type->id == ZigTypeIdInt && scalar_type->data.integral.is_signed) { + LLVMValueRef neg_1_value = LLVMConstAllOnes(get_llvm_type(g, operand_type)); BigInt int_min_bi = {0}; - eval_min_max_value_int(g, type_entry, &int_min_bi, false); - LLVMValueRef int_min_value = bigint_to_llvm_const(get_llvm_type(g, type_entry), &int_min_bi); + eval_min_max_value_int(g, scalar_type, &int_min_bi, false); + LLVMValueRef int_min_value = bigint_to_llvm_const(get_llvm_type(g, operand_type), &int_min_bi); + LLVMBasicBlockRef overflow_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowFail"); LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowOk"); LLVMValueRef num_is_int_min = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, int_min_value, ""); LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, ""); LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, ""); + if (operand_type->id == ZigTypeIdVector) { + overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any); + } LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block); LLVMPositionBuilderAtEnd(g->builder, overflow_fail_block); @@ -2669,18 +2733,22 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast } } - if (type_entry->id == ZigTypeIdFloat) { + if (scalar_type->id == ZigTypeIdFloat) { LLVMValueRef result = LLVMBuildFDiv(g->builder, val1, val2, ""); switch (div_kind) { case DivKindFloat: return result; case DivKindExact: if (want_runtime_safety) { - LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); + // Safety check: a / b == floor(a / b) + LLVMValueRef floored = gen_float_op(g, result, operand_type, BuiltinFnIdFloor); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, ""); - + if (operand_type->id == ZigTypeIdVector) { + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2695,54 +2763,61 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBasicBlockRef gez_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncGEZero"); LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd"); LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, ""); + if (operand_type->id == ZigTypeIdVector) { + ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any); + } LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block); LLVMPositionBuilderAtEnd(g->builder, ltz_block); - LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil); + LLVMValueRef ceiled = gen_float_op(g, result, operand_type, BuiltinFnIdCeil); LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); LLVMPositionBuilderAtEnd(g->builder, gez_block); - LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); + LLVMValueRef floored = gen_float_op(g, result, operand_type, BuiltinFnIdFloor); LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); LLVMPositionBuilderAtEnd(g->builder, end_block); - LLVMValueRef phi = LLVMBuildPhi(g->builder, get_llvm_type(g, type_entry), ""); + LLVMValueRef phi = LLVMBuildPhi(g->builder, get_llvm_type(g, operand_type), ""); LLVMValueRef incoming_values[] = { ceiled, floored }; LLVMBasicBlockRef incoming_blocks[] = { ceiled_end_block, floored_end_block }; LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2); return phi; } case DivKindFloor: - return gen_float_op(g, result, type_entry, BuiltinFnIdFloor); + return gen_float_op(g, result, operand_type, BuiltinFnIdFloor); } zig_unreachable(); } - assert(type_entry->id == ZigTypeIdInt); + assert(scalar_type->id == ZigTypeIdInt); switch (div_kind) { case DivKindFloat: zig_unreachable(); case DivKindTrunc: - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { return LLVMBuildSDiv(g->builder, val1, val2, ""); } else { return LLVMBuildUDiv(g->builder, val1, val2, ""); } case DivKindExact: if (want_runtime_safety) { + // Safety check: a % b == 0 LLVMValueRef remainder_val; - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { remainder_val = LLVMBuildSRem(g->builder, val1, val2, ""); } else { remainder_val = LLVMBuildURem(g->builder, val1, val2, ""); } - LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, ""); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); + LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, ""); + if (operand_type->id == ZigTypeIdVector) { + ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2750,14 +2825,14 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMPositionBuilderAtEnd(g->builder, ok_block); } - if (type_entry->data.integral.is_signed) { + if (scalar_type->data.integral.is_signed) { return LLVMBuildExactSDiv(g->builder, val1, val2, ""); } else { return LLVMBuildExactUDiv(g->builder, val1, val2, ""); } case DivKindFloor: { - if (!type_entry->data.integral.is_signed) { + if (!scalar_type->data.integral.is_signed) { return LLVMBuildUDiv(g->builder, val1, val2, ""); } // const d = @divTrunc(a, b); @@ -2784,22 +2859,30 @@ enum RemKind { }; static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast_math, - LLVMValueRef val1, LLVMValueRef val2, - ZigType *type_entry, RemKind rem_kind) + LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, RemKind rem_kind) { + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? + operand_type->data.vector.elem_type : operand_type; + ZigLLVMSetFastMath(g->builder, want_fast_math); - LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, type_entry)); + LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, operand_type)); if (want_runtime_safety) { + // Safety check: divisor != 0 LLVMValueRef is_zero_bit; - if (type_entry->id == ZigTypeIdInt) { - LLVMIntPredicate pred = type_entry->data.integral.is_signed ? LLVMIntSLE : LLVMIntEQ; + if (scalar_type->id == ZigTypeIdInt) { + LLVMIntPredicate pred = scalar_type->data.integral.is_signed ? LLVMIntSLE : LLVMIntEQ; is_zero_bit = LLVMBuildICmp(g->builder, pred, val2, zero, ""); - } else if (type_entry->id == ZigTypeIdFloat) { + } else if (scalar_type->id == ZigTypeIdFloat) { is_zero_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, val2, zero, ""); } else { zig_unreachable(); } + + if (operand_type->id == ZigTypeIdVector) { + is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any); + } + LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk"); LLVMBasicBlockRef rem_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroFail"); LLVMBuildCondBr(g->builder, is_zero_bit, rem_zero_fail_block, rem_zero_ok_block); @@ -2810,7 +2893,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMPositionBuilderAtEnd(g->builder, rem_zero_ok_block); } - if (type_entry->id == ZigTypeIdFloat) { + if (scalar_type->id == ZigTypeIdFloat) { if (rem_kind == RemKindRem) { return LLVMBuildFRem(g->builder, val1, val2, ""); } else { @@ -2821,8 +2904,8 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast return LLVMBuildSelect(g->builder, ltz, c, a, ""); } } else { - assert(type_entry->id == ZigTypeIdInt); - if (type_entry->data.integral.is_signed) { + assert(scalar_type->id == ZigTypeIdInt); + if (scalar_type->data.integral.is_signed) { if (rem_kind == RemKindRem) { return LLVMBuildSRem(g->builder, val1, val2, ""); } else { @@ -2845,11 +2928,17 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type // otherwise the check is useful as the allowed values are limited by the // operand type itself if (!is_power_of_2(lhs_type->data.integral.bit_count)) { - LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type), - lhs_type->data.integral.bit_count, false); - LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, ""); + BigInt bit_count_bi = {0}; + bigint_init_unsigned(&bit_count_bi, lhs_type->data.integral.bit_count); + LLVMValueRef bit_count_value = bigint_to_llvm_const(get_llvm_type(g, rhs_type), + &bit_count_bi); + LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail"); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk"); + LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, ""); + if (rhs_type->id == ZigTypeIdVector) { + less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any); + } LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -2966,7 +3055,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, case IrBinOpBitShiftLeftExact: { assert(scalar_type->id == ZigTypeIdInt); - LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); + LLVMValueRef op2_casted = LLVMBuildZExt(g->builder, op2_value, + LLVMTypeOf(op1_value), ""); if (want_runtime_safety) { gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value); @@ -2976,7 +3066,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, if (is_sloppy) { return LLVMBuildShl(g->builder, op1_value, op2_casted, ""); } else if (want_runtime_safety) { - return gen_overflow_shl_op(g, scalar_type, op1_value, op2_casted); + return gen_overflow_shl_op(g, operand_type, op1_value, op2_casted); } else if (scalar_type->data.integral.is_signed) { return ZigLLVMBuildNSWShl(g->builder, op1_value, op2_casted, ""); } else { @@ -2987,7 +3077,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, case IrBinOpBitShiftRightExact: { assert(scalar_type->id == ZigTypeIdInt); - LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value); + LLVMValueRef op2_casted = LLVMBuildZExt(g->builder, op2_value, + LLVMTypeOf(op1_value), ""); if (want_runtime_safety) { gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value); @@ -3001,7 +3092,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, return LLVMBuildLShr(g->builder, op1_value, op2_casted, ""); } } else if (want_runtime_safety) { - return gen_overflow_shr_op(g, scalar_type, op1_value, op2_casted); + return gen_overflow_shr_op(g, operand_type, op1_value, op2_casted); } else if (scalar_type->data.integral.is_signed) { return ZigLLVMBuildAShrExact(g->builder, op1_value, op2_casted, ""); } else { @@ -3010,22 +3101,22 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable, } case IrBinOpDivUnspecified: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, scalar_type, DivKindFloat); + op1_value, op2_value, operand_type, DivKindFloat); case IrBinOpDivExact: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, scalar_type, DivKindExact); + op1_value, op2_value, operand_type, DivKindExact); case IrBinOpDivTrunc: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, scalar_type, DivKindTrunc); + op1_value, op2_value, operand_type, DivKindTrunc); case IrBinOpDivFloor: return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, scalar_type, DivKindFloor); + op1_value, op2_value, operand_type, DivKindFloor); case IrBinOpRemRem: return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, scalar_type, RemKindRem); + op1_value, op2_value, operand_type, RemKindRem); case IrBinOpRemMod: return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base), - op1_value, op2_value, scalar_type, RemKindMod); + op1_value, op2_value, operand_type, RemKindMod); } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index bc222a311b..6fed044c6c 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -283,6 +283,8 @@ static IrInstGen *ir_analyze_union_init(IrAnalyze *ira, IrInst* source_instructi IrInstGen *result_loc); static IrInstGen *ir_analyze_struct_value_field_value(IrAnalyze *ira, IrInst* source_instr, IrInstGen *struct_operand, TypeStructField *field); +static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right); +static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right); static void destroy_instruction_src(IrInstSrc *inst) { switch (inst->id) { @@ -16803,7 +16805,6 @@ static IrInstGen *ir_analyze_math_op(IrAnalyze *ira, IrInst* source_instr, ZigValue *scalar_op2_val = &op2_val->data.x_array.data.s_none.elements[i]; ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i]; assert(scalar_op1_val->type == scalar_type); - assert(scalar_op2_val->type == scalar_type); assert(scalar_out_val->type == scalar_type); ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type, scalar_op1_val, op_id, scalar_op2_val, scalar_out_val); @@ -16828,27 +16829,49 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (type_is_invalid(op1->value->type)) return ira->codegen->invalid_inst_gen; - if (op1->value->type->id != ZigTypeIdInt && op1->value->type->id != ZigTypeIdComptimeInt) { + IrInstGen *op2 = bin_op_instruction->op2->child; + if (type_is_invalid(op2->value->type)) + return ira->codegen->invalid_inst_gen; + + ZigType *op1_type = op1->value->type; + ZigType *op2_type = op2->value->type; + + if (op1_type->id == ZigTypeIdVector && op2_type->id != ZigTypeIdVector) { ir_add_error(ira, &bin_op_instruction->op1->base, - buf_sprintf("bit shifting operation expected integer type, found '%s'", - buf_ptr(&op1->value->type->name))); + buf_sprintf("bit shifting operation expected vector type, found '%s'", + buf_ptr(&op2_type->name))); return ira->codegen->invalid_inst_gen; } - IrInstGen *op2 = bin_op_instruction->op2->child; - if (type_is_invalid(op2->value->type)) + if (op1_type->id != ZigTypeIdVector && op2_type->id == ZigTypeIdVector) { + ir_add_error(ira, &bin_op_instruction->op1->base, + buf_sprintf("bit shifting operation expected vector type, found '%s'", + buf_ptr(&op1_type->name))); + return ira->codegen->invalid_inst_gen; + } + + ZigType *op1_scalar_type = (op1_type->id == ZigTypeIdVector) ? + op1_type->data.vector.elem_type : op1_type; + ZigType *op2_scalar_type = (op2_type->id == ZigTypeIdVector) ? + op2_type->data.vector.elem_type : op2_type; + + if (op1_scalar_type->id != ZigTypeIdInt && op1_scalar_type->id != ZigTypeIdComptimeInt) { + ir_add_error(ira, &bin_op_instruction->op1->base, + buf_sprintf("bit shifting operation expected integer type, found '%s'", + buf_ptr(&op1_scalar_type->name))); return ira->codegen->invalid_inst_gen; + } - if (op2->value->type->id != ZigTypeIdInt && op2->value->type->id != ZigTypeIdComptimeInt) { + if (op2_scalar_type->id != ZigTypeIdInt && op2_scalar_type->id != ZigTypeIdComptimeInt) { ir_add_error(ira, &bin_op_instruction->op2->base, buf_sprintf("shift amount has to be an integer type, but found '%s'", - buf_ptr(&op2->value->type->name))); + buf_ptr(&op2_scalar_type->name))); return ira->codegen->invalid_inst_gen; } IrInstGen *casted_op2; IrBinOp op_id = bin_op_instruction->op_id; - if (op1->value->type->id == ZigTypeIdComptimeInt) { + if (op1_scalar_type->id == ZigTypeIdComptimeInt) { // comptime_int has no finite bit width casted_op2 = op2; @@ -16874,10 +16897,15 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in return ira->codegen->invalid_inst_gen; } } else { - const unsigned bit_count = op1->value->type->data.integral.bit_count; + const unsigned bit_count = op1_scalar_type->data.integral.bit_count; ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen, bit_count > 0 ? bit_count - 1 : 0); + if (op1_type->id == ZigTypeIdVector) { + shift_amt_type = get_vector_type(ira->codegen, op1_type->data.vector.len, + shift_amt_type); + } + casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type); if (type_is_invalid(casted_op2->value->type)) return ira->codegen->invalid_inst_gen; @@ -16888,10 +16916,10 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; - BigInt bit_count_value = {0}; - bigint_init_unsigned(&bit_count_value, bit_count); + ZigValue bit_count_value; + init_const_usize(ira->codegen, &bit_count_value, bit_count); - if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) { + if (!value_cmp_numeric_val_all(op2_val, CmpLT, &bit_count_value)) { ErrorMsg* msg = ir_add_error(ira, &bin_op_instruction->base.base, buf_sprintf("RHS of shift is too large for LHS type")); @@ -16910,7 +16938,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; - if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ) + if (value_cmp_numeric_val_all(op2_val, CmpEQ, nullptr)) return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1); } @@ -16923,7 +16951,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; - return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val); + return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1_type, op1_val, op_id, op2_val); } return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type, @@ -16943,6 +16971,7 @@ static bool ok_float_op(IrBinOp op) { case IrBinOpDivExact: case IrBinOpRemRem: case IrBinOpRemMod: + case IrBinOpRemUnspecified: return true; case IrBinOpBoolOr: @@ -16963,7 +16992,6 @@ static bool ok_float_op(IrBinOp op) { case IrBinOpAddWrap: case IrBinOpSubWrap: case IrBinOpMultWrap: - case IrBinOpRemUnspecified: case IrBinOpArrayCat: case IrBinOpArrayMult: return false; @@ -16991,6 +17019,53 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) { zig_unreachable(); } +static bool value_cmp_numeric_val(ZigValue *left, Cmp predicate, ZigValue *right, bool any) { + assert(left->special == ConstValSpecialStatic); + assert(right == nullptr || right->special == ConstValSpecialStatic); + + switch (left->type->id) { + case ZigTypeIdComptimeInt: + case ZigTypeIdInt: { + const Cmp result = right ? + bigint_cmp(&left->data.x_bigint, &right->data.x_bigint) : + bigint_cmp_zero(&left->data.x_bigint); + return result == predicate; + } + case ZigTypeIdComptimeFloat: + case ZigTypeIdFloat: { + if (float_is_nan(left)) + return false; + if (right != nullptr && float_is_nan(right)) + return false; + + const Cmp result = right ? float_cmp(left, right) : float_cmp_zero(left); + return result == predicate; + } + case ZigTypeIdVector: { + for (size_t i = 0; i < left->type->data.vector.len; i++) { + ZigValue *scalar_val = &left->data.x_array.data.s_none.elements[i]; + const bool result = value_cmp_numeric_val(scalar_val, predicate, right, any); + + if (any && result) + return true; // This element satisfies the predicate + else if (!any && !result) + return false; // This element doesn't satisfy the predicate + } + return any ? false : true; + } + default: + zig_unreachable(); + } +} + +static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right) { + return value_cmp_numeric_val(left, predicate, right, true); +} + +static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right) { + return value_cmp_numeric_val(left, predicate, right, false); +} + static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruction) { Error err; @@ -17096,127 +17171,13 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc if (type_is_invalid(resolved_type)) return ira->codegen->invalid_inst_gen; - bool is_int = resolved_type->id == ZigTypeIdInt || resolved_type->id == ZigTypeIdComptimeInt; - bool is_float = resolved_type->id == ZigTypeIdFloat || resolved_type->id == ZigTypeIdComptimeFloat; - bool is_signed_div = ( - (resolved_type->id == ZigTypeIdInt && resolved_type->data.integral.is_signed) || - resolved_type->id == ZigTypeIdFloat || - (resolved_type->id == ZigTypeIdComptimeFloat && - ((bigfloat_cmp_zero(&op1->value->data.x_bigfloat) != CmpGT) != - (bigfloat_cmp_zero(&op2->value->data.x_bigfloat) != CmpGT))) || - (resolved_type->id == ZigTypeIdComptimeInt && - ((bigint_cmp_zero(&op1->value->data.x_bigint) != CmpGT) != - (bigint_cmp_zero(&op2->value->data.x_bigint) != CmpGT))) - ); - if (op_id == IrBinOpDivUnspecified && is_int) { - if (is_signed_div) { - bool ok = false; - if (instr_is_comptime(op1) && instr_is_comptime(op2)) { - ZigValue *op1_val = ir_resolve_const(ira, op1, UndefBad); - if (op1_val == nullptr) - return ira->codegen->invalid_inst_gen; - - ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad); - if (op2_val == nullptr) - return ira->codegen->invalid_inst_gen; - - if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ) { - // the division by zero error will be caught later, but we don't have a - // division function ambiguity problem. - op_id = IrBinOpDivTrunc; - ok = true; - } else { - BigInt trunc_result; - BigInt floor_result; - bigint_div_trunc(&trunc_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint); - bigint_div_floor(&floor_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint); - if (bigint_cmp(&trunc_result, &floor_result) == CmpEQ) { - ok = true; - op_id = IrBinOpDivTrunc; - } - } - } - if (!ok) { - ir_add_error(ira, &instruction->base.base, - buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact", - buf_ptr(&op1->value->type->name), - buf_ptr(&op2->value->type->name))); - return ira->codegen->invalid_inst_gen; - } - } else { - op_id = IrBinOpDivTrunc; - } - } else if (op_id == IrBinOpRemUnspecified) { - if (is_signed_div && (is_int || is_float)) { - bool ok = false; - if (instr_is_comptime(op1) && instr_is_comptime(op2)) { - ZigValue *op1_val = ir_resolve_const(ira, op1, UndefBad); - if (op1_val == nullptr) - return ira->codegen->invalid_inst_gen; + ZigType *scalar_type = (resolved_type->id == ZigTypeIdVector) ? + resolved_type->data.vector.elem_type : resolved_type; - if (is_int) { - ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad); - if (op2_val == nullptr) - return ira->codegen->invalid_inst_gen; + bool is_int = scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdComptimeInt; + bool is_float = scalar_type->id == ZigTypeIdFloat || scalar_type->id == ZigTypeIdComptimeFloat; - if (bigint_cmp_zero(&op2->value->data.x_bigint) == CmpEQ) { - // the division by zero error will be caught later, but we don't - // have a remainder function ambiguity problem - ok = true; - } else { - BigInt rem_result; - BigInt mod_result; - bigint_rem(&rem_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint); - bigint_mod(&mod_result, &op1_val->data.x_bigint, &op2_val->data.x_bigint); - ok = bigint_cmp(&rem_result, &mod_result) == CmpEQ; - } - } else { - IrInstGen *casted_op2 = ir_implicit_cast(ira, op2, resolved_type); - if (type_is_invalid(casted_op2->value->type)) - return ira->codegen->invalid_inst_gen; - - ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad); - if (op2_val == nullptr) - return ira->codegen->invalid_inst_gen; - - if (float_cmp_zero(casted_op2->value) == CmpEQ) { - // the division by zero error will be caught later, but we don't - // have a remainder function ambiguity problem - ok = true; - } else { - ZigValue rem_result = {}; - ZigValue mod_result = {}; - float_rem(&rem_result, op1_val, op2_val); - float_mod(&mod_result, op1_val, op2_val); - ok = float_cmp(&rem_result, &mod_result) == CmpEQ; - } - } - } - if (!ok) { - ir_add_error(ira, &instruction->base.base, - buf_sprintf("remainder division with '%s' and '%s': signed integers and floats must use @rem or @mod", - buf_ptr(&op1->value->type->name), - buf_ptr(&op2->value->type->name))); - return ira->codegen->invalid_inst_gen; - } - } - op_id = IrBinOpRemRem; - } - - bool ok = false; - if (is_int) { - ok = true; - } else if (is_float && ok_float_op(op_id)) { - ok = true; - } else if (resolved_type->id == ZigTypeIdVector) { - ZigType *elem_type = resolved_type->data.vector.elem_type; - if (elem_type->id == ZigTypeIdInt || elem_type->id == ZigTypeIdComptimeInt) { - ok = true; - } else if ((elem_type->id == ZigTypeIdFloat || elem_type->id == ZigTypeIdComptimeFloat) && ok_float_op(op_id)) { - ok = true; - } - } - if (!ok) { + if (!is_int && !(is_float && ok_float_op(op_id))) { AstNode *source_node = instruction->base.base.source_node; ir_add_error_node(ira, source_node, buf_sprintf("invalid operands to binary expression: '%s' and '%s'", @@ -17225,7 +17186,16 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc return ira->codegen->invalid_inst_gen; } - if (resolved_type->id == ZigTypeIdComptimeInt) { + IrInstGen *casted_op1 = ir_implicit_cast(ira, op1, resolved_type); + if (type_is_invalid(casted_op1->value->type)) + return ira->codegen->invalid_inst_gen; + + IrInstGen *casted_op2 = ir_implicit_cast(ira, op2, resolved_type); + if (type_is_invalid(casted_op2->value->type)) + return ira->codegen->invalid_inst_gen; + + // Comptime integers have no fixed size + if (scalar_type->id == ZigTypeIdComptimeInt) { if (op_id == IrBinOpAddWrap) { op_id = IrBinOpAdd; } else if (op_id == IrBinOpSubWrap) { @@ -17235,25 +17205,131 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc } } - IrInstGen *casted_op1 = ir_implicit_cast(ira, op1, resolved_type); - if (type_is_invalid(casted_op1->value->type)) - return ira->codegen->invalid_inst_gen; - - IrInstGen *casted_op2 = ir_implicit_cast(ira, op2, resolved_type); - if (type_is_invalid(casted_op2->value->type)) - return ira->codegen->invalid_inst_gen; - if (instr_is_comptime(casted_op1) && instr_is_comptime(casted_op2)) { ZigValue *op1_val = ir_resolve_const(ira, casted_op1, UndefBad); if (op1_val == nullptr) return ira->codegen->invalid_inst_gen; + ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad); if (op2_val == nullptr) return ira->codegen->invalid_inst_gen; + // Promote division with negative numbers to signed + bool is_signed_div = value_cmp_numeric_val_any(op1_val, CmpLT, nullptr) || + value_cmp_numeric_val_any(op2_val, CmpLT, nullptr); + + if (op_id == IrBinOpDivUnspecified && is_int) { + // Default to truncating division and check if it's valid for the + // given operands if signed + op_id = IrBinOpDivTrunc; + + if (is_signed_div) { + bool ok = false; + + if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) { + // the division by zero error will be caught later, but we don't have a + // division function ambiguity problem. + ok = true; + } else { + IrInstGen *trunc_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type, + op1_val, IrBinOpDivTrunc, op2_val); + if (type_is_invalid(trunc_val->value->type)) + return ira->codegen->invalid_inst_gen; + + IrInstGen *floor_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type, + op1_val, IrBinOpDivFloor, op2_val); + if (type_is_invalid(floor_val->value->type)) + return ira->codegen->invalid_inst_gen; + + IrInstGen *cmp_val = ir_analyze_bin_op_cmp_numeric(ira, &instruction->base.base, + trunc_val, floor_val, IrBinOpCmpEq); + if (type_is_invalid(cmp_val->value->type)) + return ira->codegen->invalid_inst_gen; + + // We can "upgrade" the operator only if trunc(a/b) == floor(a/b) + if (!ir_resolve_bool(ira, cmp_val, &ok)) + return ira->codegen->invalid_inst_gen; + } + + if (!ok) { + ir_add_error(ira, &instruction->base.base, + buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact", + buf_ptr(&op1->value->type->name), + buf_ptr(&op2->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + } + } else if (op_id == IrBinOpRemUnspecified) { + op_id = IrBinOpRemRem; + + if (is_signed_div) { + bool ok = false; + + if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) { + // the division by zero error will be caught later, but we don't have a + // division function ambiguity problem. + ok = true; + } else { + IrInstGen *rem_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type, + op1_val, IrBinOpRemRem, op2_val); + if (type_is_invalid(rem_val->value->type)) + return ira->codegen->invalid_inst_gen; + + IrInstGen *mod_val = ir_analyze_math_op(ira, &instruction->base.base, resolved_type, + op1_val, IrBinOpRemMod, op2_val); + if (type_is_invalid(mod_val->value->type)) + return ira->codegen->invalid_inst_gen; + + IrInstGen *cmp_val = ir_analyze_bin_op_cmp_numeric(ira, &instruction->base.base, + rem_val, mod_val, IrBinOpCmpEq); + if (type_is_invalid(cmp_val->value->type)) + return ira->codegen->invalid_inst_gen; + + // We can "upgrade" the operator only if mod(a,b) == rem(a,b) + if (!ir_resolve_bool(ira, cmp_val, &ok)) + return ira->codegen->invalid_inst_gen; + } + + if (!ok) { + ir_add_error(ira, &instruction->base.base, + buf_sprintf("remainder division with '%s' and '%s': signed integers and floats must use @rem or @mod", + buf_ptr(&op1->value->type->name), + buf_ptr(&op2->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + } + } + return ir_analyze_math_op(ira, &instruction->base.base, resolved_type, op1_val, op_id, op2_val); } + const bool is_signed_div = + (scalar_type->id == ZigTypeIdInt && scalar_type->data.integral.is_signed) || + scalar_type->id == ZigTypeIdFloat; + + // Warn the user to use the proper operators here + if (op_id == IrBinOpDivUnspecified && is_int) { + op_id = IrBinOpDivTrunc; + + if (is_signed_div) { + ir_add_error(ira, &instruction->base.base, + buf_sprintf("division with '%s' and '%s': signed integers must use @divTrunc, @divFloor, or @divExact", + buf_ptr(&op1->value->type->name), + buf_ptr(&op2->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + } else if (op_id == IrBinOpRemUnspecified) { + op_id = IrBinOpRemRem; + + if (is_signed_div) { + ir_add_error(ira, &instruction->base.base, + buf_sprintf("remainder division with '%s' and '%s': signed integers and floats must use @rem or @mod", + buf_ptr(&op1->value->type->name), + buf_ptr(&op2->value->type->name))); + return ira->codegen->invalid_inst_gen; + } + } + return ir_build_bin_op_gen(ira, &instruction->base.base, resolved_type, op_id, casted_op1, casted_op2, instruction->safety_check_on); } @@ -20337,24 +20413,45 @@ static IrInstGen *ir_analyze_bin_not(IrAnalyze *ira, IrInstSrcUnOp *instruction) if (type_is_invalid(expr_type)) return ira->codegen->invalid_inst_gen; - if (expr_type->id == ZigTypeIdInt) { - if (instr_is_comptime(value)) { - ZigValue *target_const_val = ir_resolve_const(ira, value, UndefBad); - if (target_const_val == nullptr) - return ira->codegen->invalid_inst_gen; + ZigType *scalar_type = (expr_type->id == ZigTypeIdVector) ? + expr_type->data.vector.elem_type : expr_type; - IrInstGen *result = ir_const(ira, &instruction->base.base, expr_type); - bigint_not(&result->value->data.x_bigint, &target_const_val->data.x_bigint, - expr_type->data.integral.bit_count, expr_type->data.integral.is_signed); - return result; + if (scalar_type->id != ZigTypeIdInt) { + ir_add_error(ira, &instruction->base.base, + buf_sprintf("unable to perform binary not operation on type '%s'", buf_ptr(&expr_type->name))); + return ira->codegen->invalid_inst_gen; + } + + if (instr_is_comptime(value)) { + ZigValue *expr_val = ir_resolve_const(ira, value, UndefBad); + if (expr_val == nullptr) + return ira->codegen->invalid_inst_gen; + + IrInstGen *result = ir_const(ira, &instruction->base.base, expr_type); + + if (expr_type->id == ZigTypeIdVector) { + expand_undef_array(ira->codegen, expr_val); + result->value->special = ConstValSpecialUndef; + expand_undef_array(ira->codegen, result->value); + + for (size_t i = 0; i < expr_type->data.vector.len; i++) { + ZigValue *src_val = &expr_val->data.x_array.data.s_none.elements[i]; + ZigValue *dst_val = &result->value->data.x_array.data.s_none.elements[i]; + + dst_val->type = scalar_type; + dst_val->special = ConstValSpecialStatic; + bigint_not(&dst_val->data.x_bigint, &src_val->data.x_bigint, + scalar_type->data.integral.bit_count, scalar_type->data.integral.is_signed); + } + } else { + bigint_not(&result->value->data.x_bigint, &expr_val->data.x_bigint, + scalar_type->data.integral.bit_count, scalar_type->data.integral.is_signed); } - return ir_build_binary_not(ira, &instruction->base.base, value, expr_type); + return result; } - ir_add_error(ira, &instruction->base.base, - buf_sprintf("unable to perform binary not operation on type '%s'", buf_ptr(&expr_type->name))); - return ira->codegen->invalid_inst_gen; + return ir_build_binary_not(ira, &instruction->base.base, value, expr_type); } static IrInstGen *ir_analyze_instruction_un_op(IrAnalyze *ira, IrInstSrcUnOp *instruction) { |
