aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLemonBoy <thatlemon@gmail.com>2020-03-14 20:01:28 +0100
committerAndrew Kelley <andrew@ziglang.org>2020-04-05 18:34:31 -0400
commitd2d97e55ccd2d7c992d01bd05ea52a52fe36776e (patch)
tree98da26db32bc99194799c677cc46fbc72704fd1c /src
parent2485f3004659723a1ccd2799a6e0bddb09e32d3b (diff)
downloadzig-d2d97e55ccd2d7c992d01bd05ea52a52fe36776e.tar.gz
zig-d2d97e55ccd2d7c992d01bd05ea52a52fe36776e.zip
ir: Support shift left/right on vectors
Diffstat (limited to 'src')
-rw-r--r--src/codegen.cpp50
-rw-r--r--src/ir.cpp114
2 files changed, 117 insertions, 47 deletions
diff --git a/src/codegen.cpp b/src/codegen.cpp
index 0fa181b32c..97d960b523 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -155,6 +155,7 @@ static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstGen *source_instr,
LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type,
LLVMValueRef result_loc, bool non_async);
static Error get_tmp_filename(CodeGen *g, Buf *out, Buf *suffix);
+static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val);
static void addLLVMAttr(LLVMValueRef val, LLVMAttributeIndex attr_index, const char *attr_name) {
unsigned kind_id = LLVMGetEnumAttributeKindForName(attr_name, strlen(attr_name));
@@ -2535,19 +2536,21 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
return nullptr;
}
-static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry,
- LLVMValueRef val1, LLVMValueRef val2)
+static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
+ LLVMValueRef val1, LLVMValueRef val2)
{
// for unsigned left shifting, we do the lossy shift, then logically shift
// right the same number of bits
// if the values don't match, we have an overflow
// for signed left shifting we do the same except arithmetic shift right
+ ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ?
+ operand_type->data.vector.elem_type : operand_type;
- assert(type_entry->id == ZigTypeIdInt);
+ assert(scalar_type->id == ZigTypeIdInt);
LLVMValueRef result = LLVMBuildShl(g->builder, val1, val2, "");
LLVMValueRef orig_val;
- if (type_entry->data.integral.is_signed) {
+ if (scalar_type->data.integral.is_signed) {
orig_val = LLVMBuildAShr(g->builder, result, val2, "");
} else {
orig_val = LLVMBuildLShr(g->builder, result, val2, "");
@@ -2556,6 +2559,9 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
+ if (operand_type->id == ZigTypeIdVector) {
+ ok_bit = scalarize_cmp_result(g, ok_bit);
+ }
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
@@ -2565,13 +2571,16 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *type_entry,
return result;
}
-static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry,
- LLVMValueRef val1, LLVMValueRef val2)
+static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
+ LLVMValueRef val1, LLVMValueRef val2)
{
- assert(type_entry->id == ZigTypeIdInt);
+ ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ?
+ operand_type->data.vector.elem_type : operand_type;
+
+ assert(scalar_type->id == ZigTypeIdInt);
LLVMValueRef result;
- if (type_entry->data.integral.is_signed) {
+ if (scalar_type->data.integral.is_signed) {
result = LLVMBuildAShr(g->builder, val1, val2, "");
} else {
result = LLVMBuildLShr(g->builder, val1, val2, "");
@@ -2581,6 +2590,9 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
+ if (operand_type->id == ZigTypeIdVector) {
+ ok_bit = scalarize_cmp_result(g, ok_bit);
+ }
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
@@ -2897,11 +2909,17 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
// otherwise the check is useful as the allowed values are limited by the
// operand type itself
if (!is_power_of_2(lhs_type->data.integral.bit_count)) {
- LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type),
- lhs_type->data.integral.bit_count, false);
- LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
+ BigInt bit_count_bi = {0};
+ bigint_init_unsigned(&bit_count_bi, lhs_type->data.integral.bit_count);
+ LLVMValueRef bit_count_value = bigint_to_llvm_const(get_llvm_type(g, rhs_type),
+ &bit_count_bi);
+
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail");
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
+ LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
+ if (rhs_type->id == ZigTypeIdVector) {
+ less_than_bit = scalarize_cmp_result(g, less_than_bit);
+ }
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
@@ -3018,7 +3036,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
case IrBinOpBitShiftLeftExact:
{
assert(scalar_type->id == ZigTypeIdInt);
- LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
+ LLVMValueRef op2_casted = LLVMBuildZExt(g->builder, op2_value,
+ LLVMTypeOf(op1_value), "");//gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
if (want_runtime_safety) {
gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
@@ -3028,7 +3047,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
if (is_sloppy) {
return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
} else if (want_runtime_safety) {
- return gen_overflow_shl_op(g, scalar_type, op1_value, op2_casted);
+ return gen_overflow_shl_op(g, operand_type, op1_value, op2_casted);
} else if (scalar_type->data.integral.is_signed) {
return ZigLLVMBuildNSWShl(g->builder, op1_value, op2_casted, "");
} else {
@@ -3039,7 +3058,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
case IrBinOpBitShiftRightExact:
{
assert(scalar_type->id == ZigTypeIdInt);
- LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
+ LLVMValueRef op2_casted = LLVMBuildZExt(g->builder, op2_value,
+ LLVMTypeOf(op1_value), "");//gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
if (want_runtime_safety) {
gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
@@ -3053,7 +3073,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
return LLVMBuildLShr(g->builder, op1_value, op2_casted, "");
}
} else if (want_runtime_safety) {
- return gen_overflow_shr_op(g, scalar_type, op1_value, op2_casted);
+ return gen_overflow_shr_op(g, operand_type, op1_value, op2_casted);
} else if (scalar_type->data.integral.is_signed) {
return ZigLLVMBuildAShrExact(g->builder, op1_value, op2_casted, "");
} else {
diff --git a/src/ir.cpp b/src/ir.cpp
index 436db592f2..6fed044c6c 100644
--- a/src/ir.cpp
+++ b/src/ir.cpp
@@ -283,6 +283,8 @@ static IrInstGen *ir_analyze_union_init(IrAnalyze *ira, IrInst* source_instructi
IrInstGen *result_loc);
static IrInstGen *ir_analyze_struct_value_field_value(IrAnalyze *ira, IrInst* source_instr,
IrInstGen *struct_operand, TypeStructField *field);
+static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right);
+static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right);
static void destroy_instruction_src(IrInstSrc *inst) {
switch (inst->id) {
@@ -16803,7 +16805,6 @@ static IrInstGen *ir_analyze_math_op(IrAnalyze *ira, IrInst* source_instr,
ZigValue *scalar_op2_val = &op2_val->data.x_array.data.s_none.elements[i];
ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i];
assert(scalar_op1_val->type == scalar_type);
- assert(scalar_op2_val->type == scalar_type);
assert(scalar_out_val->type == scalar_type);
ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
scalar_op1_val, op_id, scalar_op2_val, scalar_out_val);
@@ -16828,27 +16829,49 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
if (type_is_invalid(op1->value->type))
return ira->codegen->invalid_inst_gen;
- if (op1->value->type->id != ZigTypeIdInt && op1->value->type->id != ZigTypeIdComptimeInt) {
+ IrInstGen *op2 = bin_op_instruction->op2->child;
+ if (type_is_invalid(op2->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ ZigType *op1_type = op1->value->type;
+ ZigType *op2_type = op2->value->type;
+
+ if (op1_type->id == ZigTypeIdVector && op2_type->id != ZigTypeIdVector) {
ir_add_error(ira, &bin_op_instruction->op1->base,
- buf_sprintf("bit shifting operation expected integer type, found '%s'",
- buf_ptr(&op1->value->type->name)));
+ buf_sprintf("bit shifting operation expected vector type, found '%s'",
+ buf_ptr(&op2_type->name)));
return ira->codegen->invalid_inst_gen;
}
- IrInstGen *op2 = bin_op_instruction->op2->child;
- if (type_is_invalid(op2->value->type))
+ if (op1_type->id != ZigTypeIdVector && op2_type->id == ZigTypeIdVector) {
+ ir_add_error(ira, &bin_op_instruction->op1->base,
+ buf_sprintf("bit shifting operation expected vector type, found '%s'",
+ buf_ptr(&op1_type->name)));
return ira->codegen->invalid_inst_gen;
+ }
+
+ ZigType *op1_scalar_type = (op1_type->id == ZigTypeIdVector) ?
+ op1_type->data.vector.elem_type : op1_type;
+ ZigType *op2_scalar_type = (op2_type->id == ZigTypeIdVector) ?
+ op2_type->data.vector.elem_type : op2_type;
+
+ if (op1_scalar_type->id != ZigTypeIdInt && op1_scalar_type->id != ZigTypeIdComptimeInt) {
+ ir_add_error(ira, &bin_op_instruction->op1->base,
+ buf_sprintf("bit shifting operation expected integer type, found '%s'",
+ buf_ptr(&op1_scalar_type->name)));
+ return ira->codegen->invalid_inst_gen;
+ }
- if (op2->value->type->id != ZigTypeIdInt && op2->value->type->id != ZigTypeIdComptimeInt) {
+ if (op2_scalar_type->id != ZigTypeIdInt && op2_scalar_type->id != ZigTypeIdComptimeInt) {
ir_add_error(ira, &bin_op_instruction->op2->base,
buf_sprintf("shift amount has to be an integer type, but found '%s'",
- buf_ptr(&op2->value->type->name)));
+ buf_ptr(&op2_scalar_type->name)));
return ira->codegen->invalid_inst_gen;
}
IrInstGen *casted_op2;
IrBinOp op_id = bin_op_instruction->op_id;
- if (op1->value->type->id == ZigTypeIdComptimeInt) {
+ if (op1_scalar_type->id == ZigTypeIdComptimeInt) {
// comptime_int has no finite bit width
casted_op2 = op2;
@@ -16874,10 +16897,15 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
return ira->codegen->invalid_inst_gen;
}
} else {
- const unsigned bit_count = op1->value->type->data.integral.bit_count;
+ const unsigned bit_count = op1_scalar_type->data.integral.bit_count;
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
bit_count > 0 ? bit_count - 1 : 0);
+ if (op1_type->id == ZigTypeIdVector) {
+ shift_amt_type = get_vector_type(ira->codegen, op1_type->data.vector.len,
+ shift_amt_type);
+ }
+
casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
if (type_is_invalid(casted_op2->value->type))
return ira->codegen->invalid_inst_gen;
@@ -16888,10 +16916,10 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
- BigInt bit_count_value = {0};
- bigint_init_unsigned(&bit_count_value, bit_count);
+ ZigValue bit_count_value;
+ init_const_usize(ira->codegen, &bit_count_value, bit_count);
- if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
+ if (!value_cmp_numeric_val_all(op2_val, CmpLT, &bit_count_value)) {
ErrorMsg* msg = ir_add_error(ira,
&bin_op_instruction->base.base,
buf_sprintf("RHS of shift is too large for LHS type"));
@@ -16910,7 +16938,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
- if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ)
+ if (value_cmp_numeric_val_all(op2_val, CmpEQ, nullptr))
return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1);
}
@@ -16923,7 +16951,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
- return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val);
+ return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1_type, op1_val, op_id, op2_val);
}
return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type,
@@ -16991,31 +17019,53 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) {
zig_unreachable();
}
-static bool value_cmp_zero_any(ZigValue *value, Cmp predicate) {
- assert(value->special == ConstValSpecialStatic);
+static bool value_cmp_numeric_val(ZigValue *left, Cmp predicate, ZigValue *right, bool any) {
+ assert(left->special == ConstValSpecialStatic);
+ assert(right == nullptr || right->special == ConstValSpecialStatic);
- switch (value->type->id) {
+ switch (left->type->id) {
case ZigTypeIdComptimeInt:
- case ZigTypeIdInt:
- return bigint_cmp_zero(&value->data.x_bigint) == predicate;
+ case ZigTypeIdInt: {
+ const Cmp result = right ?
+ bigint_cmp(&left->data.x_bigint, &right->data.x_bigint) :
+ bigint_cmp_zero(&left->data.x_bigint);
+ return result == predicate;
+ }
case ZigTypeIdComptimeFloat:
- case ZigTypeIdFloat:
- if (float_is_nan(value))
+ case ZigTypeIdFloat: {
+ if (float_is_nan(left))
return false;
- return float_cmp_zero(value) == predicate;
+ if (right != nullptr && float_is_nan(right))
+ return false;
+
+ const Cmp result = right ? float_cmp(left, right) : float_cmp_zero(left);
+ return result == predicate;
+ }
case ZigTypeIdVector: {
- for (size_t i = 0; i < value->type->data.vector.len; i++) {
- ZigValue *scalar_val = &value->data.x_array.data.s_none.elements[i];
- if (!value_cmp_zero_any(scalar_val, predicate))
- return true;
+ for (size_t i = 0; i < left->type->data.vector.len; i++) {
+ ZigValue *scalar_val = &left->data.x_array.data.s_none.elements[i];
+ const bool result = value_cmp_numeric_val(scalar_val, predicate, right, any);
+
+ if (any && result)
+ return true; // This element satisfies the predicate
+ else if (!any && !result)
+ return false; // This element doesn't satisfy the predicate
}
- return false;
+ return any ? false : true;
}
default:
zig_unreachable();
}
}
+static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right) {
+ return value_cmp_numeric_val(left, predicate, right, true);
+}
+
+static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right) {
+ return value_cmp_numeric_val(left, predicate, right, false);
+}
+
static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruction) {
Error err;
@@ -17165,8 +17215,8 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
return ira->codegen->invalid_inst_gen;
// Promote division with negative numbers to signed
- bool is_signed_div = value_cmp_zero_any(op1_val, CmpLT) ||
- value_cmp_zero_any(op2_val, CmpLT);
+ bool is_signed_div = value_cmp_numeric_val_any(op1_val, CmpLT, nullptr) ||
+ value_cmp_numeric_val_any(op2_val, CmpLT, nullptr);
if (op_id == IrBinOpDivUnspecified && is_int) {
// Default to truncating division and check if it's valid for the
@@ -17176,7 +17226,7 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
if (is_signed_div) {
bool ok = false;
- if (value_cmp_zero_any(op2_val, CmpEQ)) {
+ if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) {
// the division by zero error will be caught later, but we don't have a
// division function ambiguity problem.
ok = true;
@@ -17215,7 +17265,7 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
if (is_signed_div) {
bool ok = false;
- if (value_cmp_zero_any(op2_val, CmpEQ)) {
+ if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) {
// the division by zero error will be caught later, but we don't have a
// division function ambiguity problem.
ok = true;