aboutsummaryrefslogtreecommitdiff
path: root/src/stage1/ir.cpp
diff options
context:
space:
mode:
author: LemonBoy <thatlemon@gmail.com> 2020-10-04 18:23:52 +0200
committer: Andrew Kelley <andrew@ziglang.org> 2020-10-05 04:51:45 -0400
commit: 22b5e47839cf34c1e4a7c5e6dc256e041b4bf8fc (patch)
tree: a940511e3d881231d60276f824a6a9055877dc83 /src/stage1/ir.cpp
parent: 7c5a24e08cd0bffd2a5cce6d1fd592a7d2bee678 (diff)
download: zig-22b5e47839cf34c1e4a7c5e6dc256e041b4bf8fc.tar.gz
zig-22b5e47839cf34c1e4a7c5e6dc256e041b4bf8fc.zip
stage1: Implement @reduce builtin for vector types
The builtin folds a Vector(N,T) into a scalar T using a specified operator. Closes #2698
Diffstat (limited to 'src/stage1/ir.cpp')
-rw-r--r-- src/stage1/ir.cpp 227
1 files changed, 227 insertions, 0 deletions
diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp
index 045f1ad784..095ed301c9 100644
--- a/src/stage1/ir.cpp
+++ b/src/stage1/ir.cpp
@@ -402,6 +402,8 @@ static void destroy_instruction_src(IrInstSrc *inst) {
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCmpxchg *>(inst));
case IrInstSrcIdFence:
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcFence *>(inst));
+ case IrInstSrcIdReduce:
+ return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcReduce *>(inst));
case IrInstSrcIdTruncate:
return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcTruncate *>(inst));
case IrInstSrcIdIntCast:
@@ -636,6 +638,8 @@ void destroy_instruction_gen(IrInstGen *inst) {
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenCmpxchg *>(inst));
case IrInstGenIdFence:
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenFence *>(inst));
+ case IrInstGenIdReduce:
+ return heap::c_allocator.destroy(reinterpret_cast<IrInstGenReduce *>(inst));
case IrInstGenIdTruncate:
return heap::c_allocator.destroy(reinterpret_cast<IrInstGenTruncate *>(inst));
case IrInstGenIdShuffleVector:
@@ -1311,6 +1315,10 @@ static constexpr IrInstSrcId ir_inst_id(IrInstSrcFence *) {
return IrInstSrcIdFence;
}
+static constexpr IrInstSrcId ir_inst_id(IrInstSrcReduce *) { // overload chosen by pointer type; the argument itself is unused
+ return IrInstSrcIdReduce; // id tag paired with IrInstSrcReduce
+}
+
static constexpr IrInstSrcId ir_inst_id(IrInstSrcTruncate *) {
return IrInstSrcIdTruncate;
}
@@ -1775,6 +1783,10 @@ static constexpr IrInstGenId ir_inst_id(IrInstGenFence *) {
return IrInstGenIdFence;
}
+static constexpr IrInstGenId ir_inst_id(IrInstGenReduce *) { // overload chosen by pointer type; the argument itself is unused
+ return IrInstGenIdReduce; // id tag paired with IrInstGenReduce
+}
+
static constexpr IrInstGenId ir_inst_id(IrInstGenTruncate *) {
return IrInstGenIdTruncate;
}
@@ -3502,6 +3514,29 @@ static IrInstGen *ir_build_fence_gen(IrAnalyze *ira, IrInst *source_instr, Atomi
return &instruction->base;
}
+// Creates the IrInstSrc-level reduce instruction for `source_node` in `scope`.
+// `op` and `value` are stored as the instruction's operands and each one is
+// referenced from the builder's current basic block. Returns the base of the
+// newly built instruction.
+static IrInstSrc *ir_build_reduce(IrBuilderSrc *irb, Scope *scope, AstNode *source_node, IrInstSrc *op, IrInstSrc *value) {
+    IrInstSrcReduce *reduce_inst = ir_build_instruction<IrInstSrcReduce>(irb, scope, source_node);
+
+    // Operator operand.
+    reduce_inst->op = op;
+    ir_ref_instruction(op, irb->current_basic_block);
+
+    // Operand being reduced.
+    reduce_inst->value = value;
+    ir_ref_instruction(value, irb->current_basic_block);
+
+    return &reduce_inst->base;
+}
+
+// Creates the IrInstGen-level reduce instruction at the scope and source node
+// of `source_instruction`. The operator `op` and the operand `value` are
+// stored on the instruction, the instruction's value type is set to
+// `result_type`, and `value` gains a reference. Returns the instruction base.
+static IrInstGen *ir_build_reduce_gen(IrAnalyze *ira, IrInst *source_instruction, ReduceOp op, IrInstGen *value, ZigType *result_type) {
+    IrInstGenReduce *reduce_inst = ir_build_inst_gen<IrInstGenReduce>(
+        &ira->new_irb, source_instruction->scope, source_instruction->source_node);
+
+    reduce_inst->op = op;
+    reduce_inst->value = value;
+    reduce_inst->base.value->type = result_type;
+
+    ir_ref_inst_gen(value);
+
+    return &reduce_inst->base;
+}
+
static IrInstSrc *ir_build_truncate(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
IrInstSrc *dest_type, IrInstSrc *target)
{
@@ -6580,6 +6615,21 @@ static IrInstSrc *ir_gen_builtin_fn_call(IrBuilderSrc *irb, Scope *scope, AstNod
IrInstSrc *fence = ir_build_fence(irb, scope, node, arg0_value);
return ir_lval_wrap(irb, scope, fence, lval, result_loc);
}
+ case BuiltinFnIdReduce: // lowers a two-argument @reduce call into a reduce source instruction
+ {
+ AstNode *arg0_node = node->data.fn_call_expr.params.at(0); // first call argument; becomes the `op` operand below
+ IrInstSrc *arg0_value = ir_gen_node(irb, arg0_node, scope);
+ if (arg0_value == irb->codegen->invalid_inst_src)
+ return arg0_value; // propagate the invalid instruction without generating the second argument
+
+ AstNode *arg1_node = node->data.fn_call_expr.params.at(1); // second call argument; becomes the `value` operand below
+ IrInstSrc *arg1_value = ir_gen_node(irb, arg1_node, scope);
+ if (arg1_value == irb->codegen->invalid_inst_src)
+ return arg1_value; // propagate the invalid instruction
+
+ IrInstSrc *reduce = ir_build_reduce(irb, scope, node, arg0_value, arg1_value);
+ return ir_lval_wrap(irb, scope, reduce, lval, result_loc); // wrap the result the same way as the neighbouring builtin cases
+ }
case BuiltinFnIdDivExact:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -15932,6 +15982,24 @@ static bool ir_resolve_comptime(IrAnalyze *ira, IrInstGen *value, bool *out) {
return ir_resolve_bool(ira, value, out);
}
+// Resolves `value` to a constant ReduceOp and stores it in `*out`.
+//
+// Returns false, leaving `*out` untouched, when `value` already has an invalid
+// type, when it cannot be implicitly cast to the builtin "ReduceOp" type, or
+// when ir_resolve_const (called with UndefBad) yields no constant.
+// Returns true once `*out` has been written.
+static bool ir_resolve_reduce_op(IrAnalyze *ira, IrInstGen *value, ReduceOp *out) {
+    if (type_is_invalid(value->value->type)) {
+        return false;
+    }
+
+    IrInstGen *op_inst = ir_implicit_cast(ira, value, get_builtin_type(ira->codegen, "ReduceOp"));
+    if (type_is_invalid(op_inst->value->type)) {
+        return false;
+    }
+
+    ZigValue *op_val = ir_resolve_const(ira, op_inst, UndefBad);
+    if (op_val == nullptr) {
+        return false;
+    }
+
+    // The enum tag is held as a big integer; narrow it and reinterpret as ReduceOp.
+    *out = static_cast<ReduceOp>(bigint_as_u32(&op_val->data.x_enum_tag));
+    return true;
+}
+
static bool ir_resolve_atomic_order(IrAnalyze *ira, IrInstGen *value, AtomicOrder *out) {
if (type_is_invalid(value->value->type))
return false;
@@ -26802,6 +26870,161 @@ static IrInstGen *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstSrcCmpxch
success_order, failure_order, instruction->is_weak, result_loc);
}
+// Evaluates a reduce at compile time: folds the elements of the constant
+// vector `value` into one scalar using `op` and writes it to `out_value`.
+//
+// `value` must have a vector type with at least one element (both asserted).
+// `out_value` is marked ConstValSpecialStatic and typed as the vector's
+// element type. Returns nullptr on success, or the ErrorMsg produced by the
+// scalar evaluation helper that failed.
+//
+// Three strategies are used:
+//   - bool elements: and/or/xor are computed directly on the x_bool fields;
+//   - other element types with an operator that is neither min nor max: the
+//     accumulator in `out_value` is folded through ir_eval_math_op_scalar
+//     using the matching bitwise IrBinOp;
+//   - min/max: every element is compared against the current best candidate
+//     with ir_eval_bin_op_cmp_scalar and the winner is copied out at the end.
+static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp op, ZigValue *value, ZigValue *out_value) {
+    assert(value->type->id == ZigTypeIdVector);
+    ZigType *scalar_type = value->type->data.vector.elem_type;
+    const size_t len = value->type->data.vector.len;
+    assert(len > 0);
+
+    out_value->type = scalar_type;
+    out_value->special = ConstValSpecialStatic;
+
+    if (scalar_type->id == ZigTypeIdBool) {
+        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+        bool result = first_elem_val->data.x_bool;
+        // Set once the result can no longer change (false for `and`, true for
+        // `or`) so the loop really stops early. A `break` placed inside the
+        // switch only leaves the switch, never the enclosing loop.
+        bool saturated = false;
+        for (size_t i = 1; i < len && !saturated; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            switch (op) {
+                case ReduceOp_and:
+                    result = result && elem_val->data.x_bool;
+                    saturated = !result;
+                    break;
+                case ReduceOp_or:
+                    result = result || elem_val->data.x_bool;
+                    saturated = result;
+                    break;
+                case ReduceOp_xor:
+                    result = result != elem_val->data.x_bool;
+                    break;
+                default:
+                    zig_unreachable();
+            }
+        }
+
+        out_value->data.x_bool = result;
+        return nullptr;
+    }
+
+    if (op != ReduceOp_min && op != ReduceOp_max) {
+        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+        // Seed the accumulator with the first element, then fold in the rest.
+        copy_const_val(ira->codegen, out_value, first_elem_val);
+
+        for (size_t i = 1; i < len; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            IrBinOp bin_op;
+            switch (op) {
+                case ReduceOp_and: bin_op = IrBinOpBinAnd; break;
+                case ReduceOp_or: bin_op = IrBinOpBinOr; break;
+                case ReduceOp_xor: bin_op = IrBinOpBinXor; break;
+                default: zig_unreachable();
+            }
+
+            // `out_value` is both the left operand and the destination.
+            ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
+                    out_value, bin_op, elem_val, out_value);
+            if (msg != nullptr)
+                return msg;
+        }
+
+        return nullptr;
+    }
+
+    ZigValue *candidate_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+    // Scratch value that receives the boolean outcome of each comparison.
+    ZigValue *dummy_cmp_value = ira->codegen->pass1_arena->create<ZigValue>();
+    for (size_t i = 1; i < len; i++) {
+        ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+        IrBinOp bin_op;
+        switch (op) {
+            case ReduceOp_min: bin_op = IrBinOpCmpLessThan; break;
+            case ReduceOp_max: bin_op = IrBinOpCmpGreaterThan; break;
+            default: zig_unreachable();
+        }
+
+        ErrorMsg *msg = ir_eval_bin_op_cmp_scalar(ira, source_instr,
+                elem_val, bin_op, candidate_elem_val, dummy_cmp_value);
+        if (msg != nullptr) {
+            // Release the scratch value on the error path as well, mirroring
+            // the destroy performed after a successful loop.
+            ira->codegen->pass1_arena->destroy(dummy_cmp_value);
+            return msg;
+        }
+
+        // The element beats the current candidate: it becomes the new one.
+        if (dummy_cmp_value->data.x_bool)
+            candidate_elem_val = elem_val;
+    }
+
+    ira->codegen->pass1_arena->destroy(dummy_cmp_value);
+    copy_const_val(ira->codegen, out_value, candidate_elem_val);
+
+    return nullptr;
+}
+
+// Semantic analysis for the reduce source instruction.
+//
+// Checks that the operand is a vector, resolves the operator to a constant
+// ReduceOp, and verifies that the operator is permitted for the vector's
+// element type:
+//   - integer elements accept every operator;
+//   - bool elements reject operators ordered after ReduceOp_xor;
+//   - float elements reject operators ordered before ReduceOp_min;
+//   - any other element type is rejected with the same diagnostic instead of
+//     aborting the compiler.
+// NOTE(review): the range checks presumably rely on ReduceOp declaring
+// and/or/xor before min/max -- confirm against the enum declaration.
+//
+// Returns the invalid instruction on any error; otherwise the single possible
+// value when the element type has exactly one, a constant computed by
+// ir_eval_reduce when the operand is comptime-known, or a runtime reduce
+// instruction typed as the element type.
+static IrInstGen *ir_analyze_instruction_reduce(IrAnalyze *ira, IrInstSrcReduce *instruction) {
+    IrInstGen *op_inst = instruction->op->child;
+    if (type_is_invalid(op_inst->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    IrInstGen *value_inst = instruction->value->child;
+    if (type_is_invalid(value_inst->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigType *value_type = value_inst->value->type;
+    if (value_type->id != ZigTypeIdVector) {
+        ir_add_error(ira, &value_inst->base,
+            buf_sprintf("expected vector type, found '%s'",
+                buf_ptr(&value_type->name)));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    ReduceOp op;
+    if (!ir_resolve_reduce_op(ira, op_inst, &op))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigType *elem_type = value_type->data.vector.elem_type;
+    bool op_is_valid;
+    switch (elem_type->id) {
+        case ZigTypeIdInt:
+            op_is_valid = true;
+            break;
+        case ZigTypeIdBool:
+            op_is_valid = op <= ReduceOp_xor;
+            break;
+        case ZigTypeIdFloat:
+            op_is_valid = op >= ReduceOp_min;
+            break;
+        default:
+            // Report unexpected element types to the user rather than hitting
+            // an unreachable assertion inside the compiler.
+            op_is_valid = false;
+            break;
+    }
+    if (!op_is_valid) {
+        ir_add_error(ira, &op_inst->base,
+            buf_sprintf("invalid operation for '%s' type",
+                buf_ptr(&elem_type->name)));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    // special case zero bit types
+    switch (type_has_one_possible_value(ira->codegen, elem_type)) {
+        case OnePossibleValueInvalid:
+            return ira->codegen->invalid_inst_gen;
+        case OnePossibleValueYes:
+            return ir_const_move(ira, &instruction->base.base,
+                get_the_one_possible_value(ira->codegen, elem_type));
+        case OnePossibleValueNo:
+            break;
+    }
+
+    // Comptime-known operand: fold it now and return the constant.
+    if (instr_is_comptime(value_inst)) {
+        IrInstGen *result = ir_const(ira, &instruction->base.base, elem_type);
+        if (ir_eval_reduce(ira, &instruction->base.base, op, value_inst->value, result->value))
+            return ira->codegen->invalid_inst_gen;
+        return result;
+    }
+
+    return ir_build_reduce_gen(ira, &instruction->base.base, op, value_inst, elem_type);
+}
+
static IrInstGen *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstSrcFence *instruction) {
IrInstGen *order_inst = instruction->order->child;
if (type_is_invalid(order_inst->value->type))
@@ -31550,6 +31773,8 @@ static IrInstGen *ir_analyze_instruction_base(IrAnalyze *ira, IrInstSrc *instruc
return ir_analyze_instruction_cmpxchg(ira, (IrInstSrcCmpxchg *)instruction);
case IrInstSrcIdFence:
return ir_analyze_instruction_fence(ira, (IrInstSrcFence *)instruction);
+ case IrInstSrcIdReduce:
+ return ir_analyze_instruction_reduce(ira, (IrInstSrcReduce *)instruction);
case IrInstSrcIdTruncate:
return ir_analyze_instruction_truncate(ira, (IrInstSrcTruncate *)instruction);
case IrInstSrcIdIntCast:
@@ -31937,6 +32162,7 @@ bool ir_inst_gen_has_side_effects(IrInstGen *instruction) {
case IrInstGenIdNegation:
case IrInstGenIdNegationWrapping:
case IrInstGenIdWasmMemorySize:
+ case IrInstGenIdReduce:
return false;
case IrInstGenIdAsm:
@@ -32106,6 +32332,7 @@ bool ir_inst_src_has_side_effects(IrInstSrc *instruction) {
case IrInstSrcIdSpillEnd:
case IrInstSrcIdWasmMemorySize:
case IrInstSrcIdSrc:
+ case IrInstSrcIdReduce:
return false;
case IrInstSrcIdAsm: