| | | |
|---|---|---|
| author | LemonBoy <thatlemon@gmail.com> | 2020-10-04 18:23:52 +0200 |
| committer | Andrew Kelley <andrew@ziglang.org> | 2020-10-05 04:51:45 -0400 |
| commit | 22b5e47839cf34c1e4a7c5e6dc256e041b4bf8fc (patch) | |
| tree | a940511e3d881231d60276f824a6a9055877dc83 /src/stage1/ir.cpp | |
| parent | 7c5a24e08cd0bffd2a5cce6d1fd592a7d2bee678 (diff) | |
| download | zig-22b5e47839cf34c1e4a7c5e6dc256e041b4bf8fc.tar.gz zig-22b5e47839cf34c1e4a7c5e6dc256e041b4bf8fc.zip | |
stage1: Implement @reduce builtin for vector types
The builtin folds a Vector(N,T) into a scalar T using a specified
operator.
Closes #2698
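
For reference, a minimal usage sketch of the builtin this commit adds. This is not part of the diff; it assumes the `std.meta.Vector` helper and the `And`/`Or`/`Xor`/`Min`/`Max` members of `std.builtin.ReduceOp` as they existed around the time of this commit.

```zig
const std = @import("std");
const assert = std.debug.assert;

test "@reduce folds a vector into a scalar" {
    // Integer element types accept every ReduceOp member.
    const v: std.meta.Vector(4, i32) = [4]i32{ 3, -1, 7, 2 };
    assert(@reduce(.Min, v) == -1);
    assert(@reduce(.Max, v) == 7);

    // Boolean vectors are limited to .And, .Or and .Xor,
    // matching the checks in ir_analyze_instruction_reduce below.
    const b: std.meta.Vector(2, bool) = [2]bool{ true, false };
    assert(@reduce(.Or, b));
    assert(!@reduce(.And, b));
}
```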
Diffstat (limited to 'src/stage1/ir.cpp')
| | | |
|---|---|---|
| -rw-r--r-- | src/stage1/ir.cpp | 227 |

1 file changed, 227 insertions, 0 deletions
```diff
diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp
index 045f1ad784..095ed301c9 100644
--- a/src/stage1/ir.cpp
+++ b/src/stage1/ir.cpp
@@ -402,6 +402,8 @@ static void destroy_instruction_src(IrInstSrc *inst) {
             return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCmpxchg *>(inst));
         case IrInstSrcIdFence:
             return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcFence *>(inst));
+        case IrInstSrcIdReduce:
+            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcReduce *>(inst));
         case IrInstSrcIdTruncate:
             return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcTruncate *>(inst));
         case IrInstSrcIdIntCast:
@@ -636,6 +638,8 @@ void destroy_instruction_gen(IrInstGen *inst) {
             return heap::c_allocator.destroy(reinterpret_cast<IrInstGenCmpxchg *>(inst));
         case IrInstGenIdFence:
             return heap::c_allocator.destroy(reinterpret_cast<IrInstGenFence *>(inst));
+        case IrInstGenIdReduce:
+            return heap::c_allocator.destroy(reinterpret_cast<IrInstGenReduce *>(inst));
         case IrInstGenIdTruncate:
             return heap::c_allocator.destroy(reinterpret_cast<IrInstGenTruncate *>(inst));
         case IrInstGenIdShuffleVector:
@@ -1311,6 +1315,10 @@ static constexpr IrInstSrcId ir_inst_id(IrInstSrcFence *) {
     return IrInstSrcIdFence;
 }
 
+static constexpr IrInstSrcId ir_inst_id(IrInstSrcReduce *) {
+    return IrInstSrcIdReduce;
+}
+
 static constexpr IrInstSrcId ir_inst_id(IrInstSrcTruncate *) {
     return IrInstSrcIdTruncate;
 }
@@ -1775,6 +1783,10 @@ static constexpr IrInstGenId ir_inst_id(IrInstGenFence *) {
     return IrInstGenIdFence;
 }
 
+static constexpr IrInstGenId ir_inst_id(IrInstGenReduce *) {
+    return IrInstGenIdReduce;
+}
+
 static constexpr IrInstGenId ir_inst_id(IrInstGenTruncate *) {
     return IrInstGenIdTruncate;
 }
@@ -3502,6 +3514,29 @@ static IrInstGen *ir_build_fence_gen(IrAnalyze *ira, IrInst *source_instr, Atomi
     return &instruction->base;
 }
 
+static IrInstSrc *ir_build_reduce(IrBuilderSrc *irb, Scope *scope, AstNode *source_node, IrInstSrc *op, IrInstSrc *value) {
+    IrInstSrcReduce *instruction = ir_build_instruction<IrInstSrcReduce>(irb, scope, source_node);
+    instruction->op = op;
+    instruction->value = value;
+
+    ir_ref_instruction(op, irb->current_basic_block);
+    ir_ref_instruction(value, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
+static IrInstGen *ir_build_reduce_gen(IrAnalyze *ira, IrInst *source_instruction, ReduceOp op, IrInstGen *value, ZigType *result_type) {
+    IrInstGenReduce *instruction = ir_build_inst_gen<IrInstGenReduce>(&ira->new_irb,
+            source_instruction->scope, source_instruction->source_node);
+    instruction->base.value->type = result_type;
+    instruction->op = op;
+    instruction->value = value;
+
+    ir_ref_inst_gen(value);
+
+    return &instruction->base;
+}
+
 static IrInstSrc *ir_build_truncate(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
         IrInstSrc *dest_type, IrInstSrc *target)
 {
@@ -6580,6 +6615,21 @@ static IrInstSrc *ir_gen_builtin_fn_call(IrBuilderSrc *irb, Scope *scope, AstNod
                 IrInstSrc *fence = ir_build_fence(irb, scope, node, arg0_value);
                 return ir_lval_wrap(irb, scope, fence, lval, result_loc);
             }
+        case BuiltinFnIdReduce:
+            {
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstSrc *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_inst_src)
+                    return arg0_value;
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstSrc *arg1_value = ir_gen_node(irb, arg1_node, scope);
+                if (arg1_value == irb->codegen->invalid_inst_src)
+                    return arg1_value;
+
+                IrInstSrc *reduce = ir_build_reduce(irb, scope, node, arg0_value, arg1_value);
+                return ir_lval_wrap(irb, scope, reduce, lval, result_loc);
+            }
         case BuiltinFnIdDivExact:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -15932,6 +15982,24 @@ static bool ir_resolve_comptime(IrAnalyze *ira, IrInstGen *value, bool *out) {
     return ir_resolve_bool(ira, value, out);
 }
 
+static bool ir_resolve_reduce_op(IrAnalyze *ira, IrInstGen *value, ReduceOp *out) {
+    if (type_is_invalid(value->value->type))
+        return false;
+
+    ZigType *reduce_op_type = get_builtin_type(ira->codegen, "ReduceOp");
+
+    IrInstGen *casted_value = ir_implicit_cast(ira, value, reduce_op_type);
+    if (type_is_invalid(casted_value->value->type))
+        return false;
+
+    ZigValue *const_val = ir_resolve_const(ira, casted_value, UndefBad);
+    if (!const_val)
+        return false;
+
+    *out = (ReduceOp)bigint_as_u32(&const_val->data.x_enum_tag);
+    return true;
+}
+
 static bool ir_resolve_atomic_order(IrAnalyze *ira, IrInstGen *value, AtomicOrder *out) {
     if (type_is_invalid(value->value->type))
         return false;
@@ -26802,6 +26870,161 @@ static IrInstGen *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstSrcCmpxch
             success_order, failure_order, instruction->is_weak, result_loc);
 }
 
+static ErrorMsg *ir_eval_reduce(IrAnalyze *ira, IrInst *source_instr, ReduceOp op, ZigValue *value, ZigValue *out_value) {
+    assert(value->type->id == ZigTypeIdVector);
+    ZigType *scalar_type = value->type->data.vector.elem_type;
+    const size_t len = value->type->data.vector.len;
+    assert(len > 0);
+
+    out_value->type = scalar_type;
+    out_value->special = ConstValSpecialStatic;
+
+    if (scalar_type->id == ZigTypeIdBool) {
+        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+        bool result = first_elem_val->data.x_bool;
+        for (size_t i = 1; i < len; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            switch (op) {
+                case ReduceOp_and:
+                    result = result && elem_val->data.x_bool;
+                    if (!result) break; // Short circuit
+                    break;
+                case ReduceOp_or:
+                    result = result || elem_val->data.x_bool;
+                    if (result) break; // Short circuit
+                    break;
+                case ReduceOp_xor:
+                    result = result != elem_val->data.x_bool;
+                    break;
+                default:
+                    zig_unreachable();
+            }
+        }
+
+        out_value->data.x_bool = result;
+        return nullptr;
+    }
+
+    if (op != ReduceOp_min && op != ReduceOp_max) {
+        ZigValue *first_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+        copy_const_val(ira->codegen, out_value, first_elem_val);
+
+        for (size_t i = 1; i < len; i++) {
+            ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+            IrBinOp bin_op;
+            switch (op) {
+                case ReduceOp_and: bin_op = IrBinOpBinAnd; break;
+                case ReduceOp_or: bin_op = IrBinOpBinOr; break;
+                case ReduceOp_xor: bin_op = IrBinOpBinXor; break;
+                default: zig_unreachable();
+            }
+
+            ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
+                    out_value, bin_op, elem_val, out_value);
+            if (msg != nullptr)
+                return msg;
+        }
+
+        return nullptr;
+    }
+
+    ZigValue *candidate_elem_val = &value->data.x_array.data.s_none.elements[0];
+
+    ZigValue *dummy_cmp_value = ira->codegen->pass1_arena->create<ZigValue>();
+    for (size_t i = 1; i < len; i++) {
+        ZigValue *elem_val = &value->data.x_array.data.s_none.elements[i];
+
+        IrBinOp bin_op;
+        switch (op) {
+            case ReduceOp_min: bin_op = IrBinOpCmpLessThan; break;
+            case ReduceOp_max: bin_op = IrBinOpCmpGreaterThan; break;
+            default: zig_unreachable();
+        }
+
+        ErrorMsg *msg = ir_eval_bin_op_cmp_scalar(ira, source_instr,
+                elem_val, bin_op, candidate_elem_val, dummy_cmp_value);
+        if (msg != nullptr)
+            return msg;
+
+        if (dummy_cmp_value->data.x_bool)
+            candidate_elem_val = elem_val;
+    }
+
+    ira->codegen->pass1_arena->destroy(dummy_cmp_value);
+    copy_const_val(ira->codegen, out_value, candidate_elem_val);
+
+    return nullptr;
+}
+
+static IrInstGen *ir_analyze_instruction_reduce(IrAnalyze *ira, IrInstSrcReduce *instruction) {
+    IrInstGen *op_inst = instruction->op->child;
+    if (type_is_invalid(op_inst->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    IrInstGen *value_inst = instruction->value->child;
+    if (type_is_invalid(value_inst->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigType *value_type = value_inst->value->type;
+    if (value_type->id != ZigTypeIdVector) {
+        ir_add_error(ira, &value_inst->base,
+                buf_sprintf("expected vector type, found '%s'",
+                    buf_ptr(&value_type->name)));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    ReduceOp op;
+    if (!ir_resolve_reduce_op(ira, op_inst, &op))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigType *elem_type = value_type->data.vector.elem_type;
+    switch (elem_type->id) {
+        case ZigTypeIdInt:
+            break;
+        case ZigTypeIdBool:
+            if (op > ReduceOp_xor) {
+                ir_add_error(ira, &op_inst->base,
+                        buf_sprintf("invalid operation for '%s' type",
+                            buf_ptr(&elem_type->name)));
+                return ira->codegen->invalid_inst_gen;
+            } break;
+        case ZigTypeIdFloat:
+            if (op < ReduceOp_min) {
+                ir_add_error(ira, &op_inst->base,
+                        buf_sprintf("invalid operation for '%s' type",
+                            buf_ptr(&elem_type->name)));
+                return ira->codegen->invalid_inst_gen;
+            } break;
+        default:
+            // Vectors cannot have child types other than those listed above
+            zig_unreachable();
+    }
+
+    // special case zero bit types
+    switch (type_has_one_possible_value(ira->codegen, elem_type)) {
+        case OnePossibleValueInvalid:
+            return ira->codegen->invalid_inst_gen;
+        case OnePossibleValueYes:
+            return ir_const_move(ira, &instruction->base.base,
+                    get_the_one_possible_value(ira->codegen, elem_type));
+        case OnePossibleValueNo:
+            break;
+    }
+
+    if (instr_is_comptime(value_inst)) {
+        IrInstGen *result = ir_const(ira, &instruction->base.base, elem_type);
+        if (ir_eval_reduce(ira, &instruction->base.base, op, value_inst->value, result->value))
+            return ira->codegen->invalid_inst_gen;
+        return result;
+    }
+
+    return ir_build_reduce_gen(ira, &instruction->base.base, op, value_inst, elem_type);
+}
+
 static IrInstGen *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstSrcFence *instruction) {
     IrInstGen *order_inst = instruction->order->child;
     if (type_is_invalid(order_inst->value->type))
@@ -31550,6 +31773,8 @@ static IrInstGen *ir_analyze_instruction_base(IrAnalyze *ira, IrInstSrc *instruc
             return ir_analyze_instruction_cmpxchg(ira, (IrInstSrcCmpxchg *)instruction);
         case IrInstSrcIdFence:
             return ir_analyze_instruction_fence(ira, (IrInstSrcFence *)instruction);
+        case IrInstSrcIdReduce:
+            return ir_analyze_instruction_reduce(ira, (IrInstSrcReduce *)instruction);
         case IrInstSrcIdTruncate:
             return ir_analyze_instruction_truncate(ira, (IrInstSrcTruncate *)instruction);
         case IrInstSrcIdIntCast:
@@ -31937,6 +32162,7 @@ bool ir_inst_gen_has_side_effects(IrInstGen *instruction) {
         case IrInstGenIdNegation:
         case IrInstGenIdNegationWrapping:
         case IrInstGenIdWasmMemorySize:
+        case IrInstGenIdReduce:
             return false;
 
         case IrInstGenIdAsm:
@@ -32106,6 +32332,7 @@ bool ir_inst_src_has_side_effects(IrInstSrc *instruction) {
        case IrInstSrcIdSpillEnd:
         case IrInstSrcIdWasmMemorySize:
         case IrInstSrcIdSrc:
+        case IrInstSrcIdReduce:
            return false;
 
         case IrInstSrcIdAsm:
```
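
Because `ir_eval_reduce` runs during semantic analysis, a reduction over a comptime-known vector folds to a constant rather than emitting a runtime instruction. A hedged sketch of what that enables, again assuming the era's `std.meta.Vector` helper:

```zig
const std = @import("std");

// The operand is comptime-known, so the analysis pass above computes the
// reduction itself and the result can serve as a compile-time constant.
const lanes: std.meta.Vector(4, u16) = [4]u16{ 1, 2, 3, 4 };
const widest = @reduce(.Max, lanes);

comptime {
    std.debug.assert(widest == 4);
}
```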
