diff options
Diffstat (limited to 'src/ir.cpp')
| -rw-r--r-- | src/ir.cpp | 390 |
1 files changed, 338 insertions, 52 deletions
diff --git a/src/ir.cpp b/src/ir.cpp index ea9039a1b6..cbc00f0cfe 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -717,6 +717,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionVectorType *) { return IrInstructionIdVectorType; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionShuffleVector *) { + return IrInstructionIdShuffleVector; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionBoolNot *) { return IrInstructionIdBoolNot; } @@ -2277,6 +2281,25 @@ static IrInstruction *ir_build_vector_type(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } +static IrInstruction *ir_build_shuffle_vector(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *scalar_type, IrInstruction *a, IrInstruction *b, IrInstruction *mask) +{ + IrInstructionShuffleVector *instruction = ir_build_instruction<IrInstructionShuffleVector>(irb, scope, source_node); + instruction->scalar_type = scalar_type; + instruction->a = a; + instruction->b = b; + instruction->mask = mask; + + if (scalar_type != nullptr) { + ir_ref_instruction(scalar_type, irb->current_basic_block); + } + ir_ref_instruction(a, irb->current_basic_block); + ir_ref_instruction(b, irb->current_basic_block); + ir_ref_instruction(mask, irb->current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_bool_not(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) { IrInstructionBoolNot *instruction = ir_build_instruction<IrInstructionBoolNot>(irb, scope, source_node); instruction->value = value; @@ -4936,6 +4959,32 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo IrInstruction *vector_type = ir_build_vector_type(irb, scope, node, arg0_value, arg1_value); return ir_lval_wrap(irb, scope, vector_type, lval, result_loc); } + case BuiltinFnIdShuffle: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); + 
if (arg0_value == irb->codegen->invalid_instruction) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope); + if (arg1_value == irb->codegen->invalid_instruction) + return arg1_value; + + AstNode *arg2_node = node->data.fn_call_expr.params.at(2); + IrInstruction *arg2_value = ir_gen_node(irb, arg2_node, scope); + if (arg2_value == irb->codegen->invalid_instruction) + return arg2_value; + + AstNode *arg3_node = node->data.fn_call_expr.params.at(3); + IrInstruction *arg3_value = ir_gen_node(irb, arg3_node, scope); + if (arg3_value == irb->codegen->invalid_instruction) + return arg3_value; + + IrInstruction *shuffle_vector = ir_build_shuffle_vector(irb, scope, node, + arg0_value, arg1_value, arg2_value, arg3_value); + return ir_lval_wrap(irb, scope, shuffle_vector, lval, result_loc); + } case BuiltinFnIdMemcpy: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); @@ -11000,6 +11049,19 @@ static ZigType *ir_resolve_type(IrAnalyze *ira, IrInstruction *type_value) { return ir_resolve_const_type(ira->codegen, ira->new_irb.exec, type_value->source_node, val); } +static ZigType *ir_resolve_vector_elem_type(IrAnalyze *ira, IrInstruction *elem_type_value) { + ZigType *elem_type = ir_resolve_type(ira, elem_type_value); + if (type_is_invalid(elem_type)) + return ira->codegen->builtin_types.entry_invalid; + if (!is_valid_vector_elem_type(elem_type)) { + ir_add_error(ira, elem_type_value, + buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid", + buf_ptr(&elem_type->name))); + return ira->codegen->builtin_types.entry_invalid; + } + return elem_type; +} + static ZigType *ir_resolve_int_type(IrAnalyze *ira, IrInstruction *type_value) { ZigType *ty = ir_resolve_type(ira, type_value); if (type_is_invalid(ty)) @@ -13092,6 +13154,59 @@ static bool optional_value_is_null(ConstExprValue *val) { } } +static IrInstruction 
*ir_evaluate_bin_op_cmp(IrAnalyze *ira, ZigType *resolved_type, + ConstExprValue *op1_val, ConstExprValue *op2_val, IrInstructionBinOp *bin_op_instruction, IrBinOp op_id, + bool one_possible_value) { + if (op1_val->special == ConstValSpecialUndef || + op2_val->special == ConstValSpecialUndef) + return ir_const_undef(ira, &bin_op_instruction->base, resolved_type); + if (resolved_type->id == ZigTypeIdComptimeFloat || resolved_type->id == ZigTypeIdFloat) { + if (float_is_nan(op1_val) || float_is_nan(op2_val)) { + return ir_const_bool(ira, &bin_op_instruction->base, op_id == IrBinOpCmpNotEq); + } + Cmp cmp_result = float_cmp(op1_val, op2_val); + bool answer = resolve_cmp_op_id(op_id, cmp_result); + return ir_const_bool(ira, &bin_op_instruction->base, answer); + } else if (resolved_type->id == ZigTypeIdComptimeInt || resolved_type->id == ZigTypeIdInt) { + Cmp cmp_result = bigint_cmp(&op1_val->data.x_bigint, &op2_val->data.x_bigint); + bool answer = resolve_cmp_op_id(op_id, cmp_result); + return ir_const_bool(ira, &bin_op_instruction->base, answer); + } else if (resolved_type->id == ZigTypeIdPointer && op_id != IrBinOpCmpEq && op_id != IrBinOpCmpNotEq) { + if ((op1_val->data.x_ptr.special == ConstPtrSpecialHardCodedAddr || + op1_val->data.x_ptr.special == ConstPtrSpecialNull) && + (op2_val->data.x_ptr.special == ConstPtrSpecialHardCodedAddr || + op2_val->data.x_ptr.special == ConstPtrSpecialNull)) + { + uint64_t op1_addr = op1_val->data.x_ptr.special == ConstPtrSpecialNull ? + 0 : op1_val->data.x_ptr.data.hard_coded_addr.addr; + uint64_t op2_addr = op2_val->data.x_ptr.special == ConstPtrSpecialNull ? 
+ 0 : op2_val->data.x_ptr.data.hard_coded_addr.addr; + Cmp cmp_result; + if (op1_addr > op2_addr) { + cmp_result = CmpGT; + } else if (op1_addr < op2_addr) { + cmp_result = CmpLT; + } else { + cmp_result = CmpEQ; + } + bool answer = resolve_cmp_op_id(op_id, cmp_result); + return ir_const_bool(ira, &bin_op_instruction->base, answer); + } + } else { + bool are_equal = one_possible_value || const_values_equal(ira->codegen, op1_val, op2_val); + bool answer; + if (op_id == IrBinOpCmpEq) { + answer = are_equal; + } else if (op_id == IrBinOpCmpNotEq) { + answer = !are_equal; + } else { + zig_unreachable(); + } + return ir_const_bool(ira, &bin_op_instruction->base, answer); + } + zig_unreachable(); +} + // Returns ErrorNotLazy when the value cannot be determined static Error lazy_cmp_zero(AstNode *source_node, ConstExprValue *val, Cmp *result) { Error err; @@ -13477,51 +13592,22 @@ never_mind_just_calculate_it_normally: ConstExprValue *op2_val = one_possible_value ? &casted_op2->value : ir_resolve_const(ira, casted_op2, UndefBad); if (op2_val == nullptr) return ira->codegen->invalid_instruction; - - if (resolved_type->id == ZigTypeIdComptimeFloat || resolved_type->id == ZigTypeIdFloat) { - if (float_is_nan(op1_val) || float_is_nan(op2_val)) { - return ir_const_bool(ira, &bin_op_instruction->base, op_id == IrBinOpCmpNotEq); - } - Cmp cmp_result = float_cmp(op1_val, op2_val); - bool answer = resolve_cmp_op_id(op_id, cmp_result); - return ir_const_bool(ira, &bin_op_instruction->base, answer); - } else if (resolved_type->id == ZigTypeIdComptimeInt || resolved_type->id == ZigTypeIdInt) { - Cmp cmp_result = bigint_cmp(&op1_val->data.x_bigint, &op2_val->data.x_bigint); - bool answer = resolve_cmp_op_id(op_id, cmp_result); - return ir_const_bool(ira, &bin_op_instruction->base, answer); - } else if (resolved_type->id == ZigTypeIdPointer && op_id != IrBinOpCmpEq && op_id != IrBinOpCmpNotEq) { - if ((op1_val->data.x_ptr.special == ConstPtrSpecialHardCodedAddr || - 
op1_val->data.x_ptr.special == ConstPtrSpecialNull) && - (op2_val->data.x_ptr.special == ConstPtrSpecialHardCodedAddr || - op2_val->data.x_ptr.special == ConstPtrSpecialNull)) - { - uint64_t op1_addr = op1_val->data.x_ptr.special == ConstPtrSpecialNull ? - 0 : op1_val->data.x_ptr.data.hard_coded_addr.addr; - uint64_t op2_addr = op2_val->data.x_ptr.special == ConstPtrSpecialNull ? - 0 : op2_val->data.x_ptr.data.hard_coded_addr.addr; - Cmp cmp_result; - if (op1_addr > op2_addr) { - cmp_result = CmpGT; - } else if (op1_addr < op2_addr) { - cmp_result = CmpLT; - } else { - cmp_result = CmpEQ; - } - bool answer = resolve_cmp_op_id(op_id, cmp_result); - return ir_const_bool(ira, &bin_op_instruction->base, answer); - } - } else { - bool are_equal = one_possible_value || const_values_equal(ira->codegen, op1_val, op2_val); - bool answer; - if (op_id == IrBinOpCmpEq) { - answer = are_equal; - } else if (op_id == IrBinOpCmpNotEq) { - answer = !are_equal; - } else { - zig_unreachable(); - } - return ir_const_bool(ira, &bin_op_instruction->base, answer); + if (resolved_type->id != ZigTypeIdVector) + return ir_evaluate_bin_op_cmp(ira, resolved_type, op1_val, op2_val, bin_op_instruction, op_id, one_possible_value); + IrInstruction *result = ir_const(ira, &bin_op_instruction->base, + get_vector_type(ira->codegen, resolved_type->data.vector.len, ira->codegen->builtin_types.entry_bool)); + result->value.data.x_array.data.s_none.elements = + create_const_vals(resolved_type->data.vector.len); + + expand_undef_array(ira->codegen, &result->value); + for (size_t i = 0;i < resolved_type->data.vector.len;i++) { + IrInstruction *cur_res = ir_evaluate_bin_op_cmp(ira, resolved_type->data.vector.elem_type, + &op1_val->data.x_array.data.s_none.elements[i], + &op2_val->data.x_array.data.s_none.elements[i], + bin_op_instruction, op_id, one_possible_value); + copy_const_val(&result->value.data.x_array.data.s_none.elements[i], &cur_res->value, false); } + return result; } // some comparisons with 
unsigned numbers can be evaluated @@ -13564,7 +13650,12 @@ never_mind_just_calculate_it_normally: IrInstruction *result = ir_build_bin_op(&ira->new_irb, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op_id, casted_op1, casted_op2, bin_op_instruction->safety_check_on); - result->value.type = ira->codegen->builtin_types.entry_bool; + if (resolved_type->id == ZigTypeIdVector) { + result->value.type = get_vector_type(ira->codegen, resolved_type->data.vector.len, + ira->codegen->builtin_types.entry_bool); + } else { + result->value.type = ira->codegen->builtin_types.entry_bool; + } return result; } @@ -22018,20 +22109,212 @@ static IrInstruction *ir_analyze_instruction_vector_type(IrAnalyze *ira, IrInstr if (!ir_resolve_unsigned(ira, instruction->len->child, ira->codegen->builtin_types.entry_u32, &len)) return ira->codegen->invalid_instruction; - ZigType *elem_type = ir_resolve_type(ira, instruction->elem_type->child); + ZigType *elem_type = ir_resolve_vector_elem_type(ira, instruction->elem_type->child); if (type_is_invalid(elem_type)) return ira->codegen->invalid_instruction; - if (!is_valid_vector_elem_type(elem_type)) { - ir_add_error(ira, instruction->elem_type, - buf_sprintf("vector element type must be integer, float, or pointer; '%s' is invalid", - buf_ptr(&elem_type->name))); + ZigType *vector_type = get_vector_type(ira->codegen, len, elem_type); + + return ir_const_type(ira, &instruction->base, vector_type); +} + +static IrInstruction *ir_analyze_shuffle_vector(IrAnalyze *ira, IrInstruction *source_instr, + ZigType *scalar_type, IrInstruction *a, IrInstruction *b, IrInstruction *mask) +{ + ir_assert(source_instr && scalar_type && a && b && mask, source_instr); + ir_assert(is_valid_vector_elem_type(scalar_type), source_instr); + + uint32_t len_mask; + if (mask->value.type->id == ZigTypeIdVector) { + len_mask = mask->value.type->data.vector.len; + } else if (mask->value.type->id == ZigTypeIdArray) { + len_mask = 
mask->value.type->data.array.len; + } else { + ir_add_error(ira, mask, + buf_sprintf("expected vector or array, found '%s'", + buf_ptr(&mask->value.type->name))); return ira->codegen->invalid_instruction; } + mask = ir_implicit_cast(ira, mask, get_vector_type(ira->codegen, len_mask, + ira->codegen->builtin_types.entry_i32)); + if (type_is_invalid(mask->value.type)) + return ira->codegen->invalid_instruction; - ZigType *vector_type = get_vector_type(ira->codegen, len, elem_type); + uint32_t len_a; + if (a->value.type->id == ZigTypeIdVector) { + len_a = a->value.type->data.vector.len; + } else if (a->value.type->id == ZigTypeIdArray) { + len_a = a->value.type->data.array.len; + } else if (a->value.type->id == ZigTypeIdUndefined) { + len_a = UINT32_MAX; + } else { + ir_add_error(ira, a, + buf_sprintf("expected vector or array with element type '%s', found '%s'", + buf_ptr(&scalar_type->name), + buf_ptr(&a->value.type->name))); + return ira->codegen->invalid_instruction; + } - return ir_const_type(ira, &instruction->base, vector_type); + uint32_t len_b; + if (b->value.type->id == ZigTypeIdVector) { + len_b = b->value.type->data.vector.len; + } else if (b->value.type->id == ZigTypeIdArray) { + len_b = b->value.type->data.array.len; + } else if (b->value.type->id == ZigTypeIdUndefined) { + len_b = UINT32_MAX; + } else { + ir_add_error(ira, b, + buf_sprintf("expected vector or array with element type '%s', found '%s'", + buf_ptr(&scalar_type->name), + buf_ptr(&b->value.type->name))); + return ira->codegen->invalid_instruction; + } + + if (len_a == UINT32_MAX && len_b == UINT32_MAX) { + return ir_const_undef(ira, a, get_vector_type(ira->codegen, len_mask, scalar_type)); + } + + if (len_a == UINT32_MAX) { + len_a = len_b; + a = ir_const_undef(ira, a, get_vector_type(ira->codegen, len_a, scalar_type)); + } else { + a = ir_implicit_cast(ira, a, get_vector_type(ira->codegen, len_a, scalar_type)); + if (type_is_invalid(a->value.type)) + return ira->codegen->invalid_instruction; 
+ } + + if (len_b == UINT32_MAX) { + len_b = len_a; + b = ir_const_undef(ira, b, get_vector_type(ira->codegen, len_b, scalar_type)); + } else { + b = ir_implicit_cast(ira, b, get_vector_type(ira->codegen, len_b, scalar_type)); + if (type_is_invalid(b->value.type)) + return ira->codegen->invalid_instruction; + } + + ConstExprValue *mask_val = ir_resolve_const(ira, mask, UndefOk); + if (mask_val == nullptr) + return ira->codegen->invalid_instruction; + + expand_undef_array(ira->codegen, mask_val); + + for (uint32_t i = 0; i < len_mask; i += 1) { + ConstExprValue *mask_elem_val = &mask_val->data.x_array.data.s_none.elements[i]; + if (mask_elem_val->special == ConstValSpecialUndef) + continue; + int32_t v_i32 = bigint_as_signed(&mask_elem_val->data.x_bigint); + uint32_t v; + IrInstruction *chosen_operand; + if (v_i32 >= 0) { + v = (uint32_t)v_i32; + chosen_operand = a; + } else { + v = (uint32_t)~v_i32; + chosen_operand = b; + } + if (v >= chosen_operand->value.type->data.vector.len) { + ErrorMsg *msg = ir_add_error(ira, mask, + buf_sprintf("mask index '%u' has out-of-bounds selection", i)); + add_error_note(ira->codegen, msg, chosen_operand->source_node, + buf_sprintf("selected index '%u' out of bounds of %s", v, + buf_ptr(&chosen_operand->value.type->name))); + if (chosen_operand == a && v < len_a + len_b) { + add_error_note(ira->codegen, msg, b->source_node, + buf_create_from_str("selections from the second vector are specified with negative numbers")); + } + return ira->codegen->invalid_instruction; + } + } + + ZigType *result_type = get_vector_type(ira->codegen, len_mask, scalar_type); + if (instr_is_comptime(a) && instr_is_comptime(b)) { + ConstExprValue *a_val = ir_resolve_const(ira, a, UndefOk); + if (a_val == nullptr) + return ira->codegen->invalid_instruction; + + ConstExprValue *b_val = ir_resolve_const(ira, b, UndefOk); + if (b_val == nullptr) + return ira->codegen->invalid_instruction; + + expand_undef_array(ira->codegen, a_val); + 
expand_undef_array(ira->codegen, b_val); + + IrInstruction *result = ir_const(ira, source_instr, result_type); + result->value.data.x_array.data.s_none.elements = create_const_vals(len_mask); + for (uint32_t i = 0; i < mask_val->type->data.vector.len; i += 1) { + ConstExprValue *mask_elem_val = &mask_val->data.x_array.data.s_none.elements[i]; + ConstExprValue *result_elem_val = &result->value.data.x_array.data.s_none.elements[i]; + if (mask_elem_val->special == ConstValSpecialUndef) { + result_elem_val->special = ConstValSpecialUndef; + continue; + } + int32_t v = bigint_as_signed(&mask_elem_val->data.x_bigint); + // We've already checked for and emitted compile errors for index out of bounds here. + ConstExprValue *src_elem_val = (v >= 0) ? + &a->value.data.x_array.data.s_none.elements[v] : + &b->value.data.x_array.data.s_none.elements[~v]; + copy_const_val(result_elem_val, src_elem_val, false); + + ir_assert(result_elem_val->special == ConstValSpecialStatic, source_instr); + } + result->value.special = ConstValSpecialStatic; + return result; + } + + // All static analysis passed, and not comptime. + // For runtime codegen, vectors a and b must be the same length. Here we + // recursively @shuffle the smaller vector to append undefined elements + // to it up to the length of the longer vector. This recursion terminates + // in 1 call because these calls to ir_analyze_shuffle_vector guarantee + // len_a == len_b. 
+ if (len_a != len_b) { + uint32_t len_min = min(len_a, len_b); + uint32_t len_max = max(len_a, len_b); + + IrInstruction *expand_mask = ir_const(ira, mask, + get_vector_type(ira->codegen, len_max, ira->codegen->builtin_types.entry_i32)); + expand_mask->value.data.x_array.data.s_none.elements = create_const_vals(len_max); + uint32_t i = 0; + for (; i < len_min; i += 1) + bigint_init_unsigned(&expand_mask->value.data.x_array.data.s_none.elements[i].data.x_bigint, i); + for (; i < len_max; i += 1) + bigint_init_signed(&expand_mask->value.data.x_array.data.s_none.elements[i].data.x_bigint, -1); + + IrInstruction *undef = ir_const_undef(ira, source_instr, + get_vector_type(ira->codegen, len_min, scalar_type)); + + if (len_b < len_a) { + b = ir_analyze_shuffle_vector(ira, source_instr, scalar_type, b, undef, expand_mask); + } else { + a = ir_analyze_shuffle_vector(ira, source_instr, scalar_type, a, undef, expand_mask); + } + } + + IrInstruction *result = ir_build_shuffle_vector(&ira->new_irb, + source_instr->scope, source_instr->source_node, + nullptr, a, b, mask); + result->value.type = result_type; + return result; +} + +/* Analyzes a shuffle-vector instruction (built for the @shuffle builtin): resolves the scalar element type, +   fetches the a, b and mask operands, returns invalid_instruction as soon as any of them is invalid, +   and otherwise hands everything to ir_analyze_shuffle_vector. */ +static IrInstruction *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, IrInstructionShuffleVector *instruction) { + /* Resolve the type through ->child, exactly like a, b and mask below and like elem_type in ir_analyze_instruction_vector_type. */ + ZigType *scalar_type = ir_resolve_vector_elem_type(ira, instruction->scalar_type->child); + if (type_is_invalid(scalar_type)) + return ira->codegen->invalid_instruction; + + IrInstruction *a = instruction->a->child; + if (type_is_invalid(a->value.type)) + return ira->codegen->invalid_instruction; + + IrInstruction *b = instruction->b->child; + if (type_is_invalid(b->value.type)) + return ira->codegen->invalid_instruction; + + IrInstruction *mask = instruction->mask->child; + if (type_is_invalid(mask->value.type)) + return ira->codegen->invalid_instruction; + + return ir_analyze_shuffle_vector(ira, &instruction->base, scalar_type, a, b, mask); } static IrInstruction *ir_analyze_instruction_bool_not(IrAnalyze *ira, IrInstructionBoolNot 
*instruction) { @@ -25578,6 +25861,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_int_type(ira, (IrInstructionIntType *)instruction); case IrInstructionIdVectorType: return ir_analyze_instruction_vector_type(ira, (IrInstructionVectorType *)instruction); + case IrInstructionIdShuffleVector: + return ir_analyze_instruction_shuffle_vector(ira, (IrInstructionShuffleVector *)instruction); case IrInstructionIdBoolNot: return ir_analyze_instruction_bool_not(ira, (IrInstructionBoolNot *)instruction); case IrInstructionIdMemset: @@ -25913,6 +26198,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdTruncate: case IrInstructionIdIntType: case IrInstructionIdVectorType: + case IrInstructionIdShuffleVector: case IrInstructionIdBoolNot: case IrInstructionIdSliceSrc: case IrInstructionIdMemberCount: |
