From 38b2d6209239f0dad7cb38e656d9d38506f126ca Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Wed, 8 Dec 2021 15:19:13 -0700 Subject: stage1: saturating shl operates using LHS type Saturating shift left (`<<|`) previously used the `ir_analyze_bin_op_math` codepath rather than the `ir_analyze_bit_shift` codepath, leading to it doing peer type resolution (incorrect) instead of using the LHS type as the number of bits to do the saturating against. This required implementing SIMD vector support for `@truncate`. Additionall, this commit adds a compile error for saturating shift left on a comptime_int. stage2 does not pass these new behavior tests yet. closes #10298 --- src/stage1/ir.cpp | 170 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 116 insertions(+), 54 deletions(-) (limited to 'src') diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index 0684b7bce1..5cec62bc80 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -9900,6 +9900,100 @@ static Stage1AirInst *ir_analyze_math_op(IrAnalyze *ira, Scope *scope, AstNode * return ir_implicit_cast(ira, result_instruction, type_entry); } +static Stage1AirInst *ir_analyze_truncate(IrAnalyze *ira, Scope *scope, AstNode *source_node, + ZigType *dest_scalar_type, AstNode *dest_type_node, + Stage1AirInst *operand, AstNode *operand_node) +{ + if (dest_scalar_type->id != ZigTypeIdInt && + dest_scalar_type->id != ZigTypeIdComptimeInt) + { + ir_add_error_node(ira, dest_type_node, + buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_scalar_type->name))); + return ira->codegen->invalid_inst_gen; + } + + ZigType *src_type = operand->value->type; + bool is_vector = (src_type->id == ZigTypeIdVector); + ZigType *src_scalar_type = is_vector ? + src_type->data.vector.elem_type : src_type; + + ZigType *dest_type = is_vector ? + get_vector_type(ira->codegen, src_type->data.vector.len, dest_scalar_type) : + dest_scalar_type; + + if (src_scalar_type->id != ZigTypeIdInt && src_scalar_type->id != ZigTypeIdComptimeInt) { + ir_add_error_node(ira, operand_node, + buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_scalar_type->name))); + return ira->codegen->invalid_inst_gen; + } + + if (dest_scalar_type->id == ZigTypeIdComptimeInt) { + return ir_implicit_cast2(ira, scope, operand_node, operand, dest_type); + } + + if (src_scalar_type->id != ZigTypeIdComptimeInt) { + if (src_scalar_type->data.integral.is_signed != dest_scalar_type->data.integral.is_signed) { + const char *sign_str = dest_scalar_type->data.integral.is_signed ? "signed" : "unsigned"; + ir_add_error_node(ira, operand_node, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_scalar_type->name))); + return ira->codegen->invalid_inst_gen; + } else if (src_scalar_type->data.integral.bit_count > 0 && src_scalar_type->data.integral.bit_count < dest_scalar_type->data.integral.bit_count) { + ir_add_error_node(ira, operand_node, buf_sprintf("type '%s' has fewer bits than destination type '%s'", + buf_ptr(&src_scalar_type->name), buf_ptr(&dest_scalar_type->name))); + return ira->codegen->invalid_inst_gen; + } + } + + if (instr_is_comptime(operand)) { + ZigValue *val = ir_resolve_const(ira, operand, UndefBad); + if (val == nullptr) + return ira->codegen->invalid_inst_gen; + + if (!is_vector) { + Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type); + bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint, + dest_scalar_type->data.integral.bit_count, + dest_scalar_type->data.integral.is_signed); + return result; + } + + Stage1AirInst *result_instruction = ir_const(ira, scope, source_node, dest_type); + ZigValue *out_val = result_instruction->value; + expand_undef_array(ira->codegen, operand->value); + out_val->special = ConstValSpecialUndef; + expand_undef_array(ira->codegen, out_val); + size_t len = dest_type->data.vector.len; + for (size_t i = 0; i < len; i += 1) { + ZigValue *scalar_operand_val = &operand->value->data.x_array.data.s_none.elements[i]; + ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i]; + assert(scalar_operand_val->type == dest_scalar_type); + assert(scalar_out_val->type == dest_scalar_type); + + bigint_truncate(&scalar_out_val->data.x_bigint, + &scalar_operand_val->data.x_bigint, + dest_scalar_type->data.integral.bit_count, + dest_scalar_type->data.integral.is_signed); + + scalar_out_val->type = dest_scalar_type; + scalar_out_val->special = ConstValSpecialStatic; + } + out_val->type = dest_type; + out_val->special = ConstValSpecialStatic; + return result_instruction; + } + + if (src_scalar_type->data.integral.bit_count == 0 || + dest_scalar_type->data.integral.bit_count == 0) + { + Stage1AirInst *result = ir_const(ira, scope, source_node, dest_type); + if (!is_vector) { + bigint_init_unsigned(&result->value->data.x_bigint, 0); + } + return result; + } + + return ir_build_truncate_gen(ira, scope, source_node, dest_type, operand); +} + static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *bin_op_instruction) { Stage1AirInst *op1 = bin_op_instruction->op1->child; if (type_is_invalid(op1->value->type)) @@ -9951,6 +10045,12 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b // comptime_int has no finite bit width casted_op2 = op2; + if (op_id == IrBinOpShlSat) { + ir_add_error_node(ira, bin_op_instruction->base.source_node, + buf_sprintf("saturating shift on a comptime_int which has unlimited bits")); + return ira->codegen->invalid_inst_gen; + } + if (op_id == IrBinOpBitShiftLeftLossy) { op_id = IrBinOpBitShiftLeftExact; } @@ -9972,6 +10072,13 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b buf_sprintf("shift by negative value %s", buf_ptr(val_buf))); return ira->codegen->invalid_inst_gen; } + } else if (op_id == IrBinOpShlSat) { + casted_op2 = ir_analyze_truncate(ira, + bin_op_instruction->base.scope, bin_op_instruction->base.source_node, + op1_scalar_type, bin_op_instruction->op1->source_node, + op2, bin_op_instruction->op2->source_node); + if (type_is_invalid(casted_op2->value->type)) + return ira->codegen->invalid_inst_gen; } else { const unsigned bit_count = op1_scalar_type->data.integral.bit_count; ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen, @@ -10030,8 +10137,9 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b return ir_analyze_math_op(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1_type, op1_val, op_id, op2_val); } - return ir_build_bin_op_gen(ira, bin_op_instruction->base.scope, bin_op_instruction->base.source_node, op1->value->type, - op_id, op1, casted_op2, bin_op_instruction->safety_check_on); + return ir_build_bin_op_gen(ira, + bin_op_instruction->base.scope, bin_op_instruction->base.source_node, + op1->value->type, op_id, op1, casted_op2, bin_op_instruction->safety_check_on); } static bool ok_float_op(IrBinOp op) { @@ -11035,6 +11143,7 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns case IrBinOpBitShiftLeftExact: case IrBinOpBitShiftRightLossy: case IrBinOpBitShiftRightExact: + case IrBinOpShlSat: return ir_analyze_bit_shift(ira, bin_op_instruction); case IrBinOpBinOr: case IrBinOpBinXor: @@ -11057,7 +11166,6 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns case IrBinOpAddSat: case IrBinOpSubSat: case IrBinOpMultSat: - case IrBinOpShlSat: return ir_analyze_bin_op_math(ira, bin_op_instruction); case IrBinOpArrayCat: return ir_analyze_array_cat(ira, bin_op_instruction); @@ -20017,59 +20125,13 @@ static Stage1AirInst *ir_analyze_instruction_truncate(IrAnalyze *ira, Stage1ZirI if (type_is_invalid(dest_type)) return ira->codegen->invalid_inst_gen; - if (dest_type->id != ZigTypeIdInt && - dest_type->id != ZigTypeIdComptimeInt) - { - ir_add_error(ira, dest_type_value, buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_type->name))); - return ira->codegen->invalid_inst_gen; - } - - Stage1AirInst *target = instruction->target->child; - ZigType *src_type = target->value->type; - if (type_is_invalid(src_type)) - return ira->codegen->invalid_inst_gen; - - if (src_type->id != ZigTypeIdInt && - src_type->id != ZigTypeIdComptimeInt) - { - ir_add_error(ira, target, buf_sprintf("expected integer type, found '%s'", buf_ptr(&src_type->name))); + Stage1AirInst *operand = instruction->target->child; + if (type_is_invalid(operand->value->type)) return ira->codegen->invalid_inst_gen; - } - - if (dest_type->id == ZigTypeIdComptimeInt) { - return ir_implicit_cast2(ira, instruction->target->scope, instruction->target->source_node, target, dest_type); - } - if (src_type->id != ZigTypeIdComptimeInt) { - if (src_type->data.integral.is_signed != dest_type->data.integral.is_signed) { - const char *sign_str = dest_type->data.integral.is_signed ? "signed" : "unsigned"; - ir_add_error(ira, target, buf_sprintf("expected %s integer type, found '%s'", sign_str, buf_ptr(&src_type->name))); - return ira->codegen->invalid_inst_gen; - } else if (src_type->data.integral.bit_count > 0 && src_type->data.integral.bit_count < dest_type->data.integral.bit_count) { - ir_add_error(ira, target, buf_sprintf("type '%s' has fewer bits than destination type '%s'", - buf_ptr(&src_type->name), buf_ptr(&dest_type->name))); - return ira->codegen->invalid_inst_gen; - } - } - - if (instr_is_comptime(target)) { - ZigValue *val = ir_resolve_const(ira, target, UndefBad); - if (val == nullptr) - return ira->codegen->invalid_inst_gen; - - Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type); - bigint_truncate(&result->value->data.x_bigint, &val->data.x_bigint, - dest_type->data.integral.bit_count, dest_type->data.integral.is_signed); - return result; - } - - if (src_type->data.integral.bit_count == 0 || dest_type->data.integral.bit_count == 0) { - Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, dest_type); - bigint_init_unsigned(&result->value->data.x_bigint, 0); - return result; - } - - return ir_build_truncate_gen(ira, instruction->base.scope, instruction->base.source_node, dest_type, target); + return ir_analyze_truncate(ira, instruction->base.scope, instruction->base.source_node, + dest_type, instruction->dest_type->source_node, + operand, instruction->target->source_node); } static Stage1AirInst *ir_analyze_int_cast(IrAnalyze *ira, Scope *scope, AstNode *source_node, -- cgit v1.2.3