From 8c6fa982cd0a02775264b616c37da9907cc603bb Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Mon, 4 Feb 2019 20:30:00 -0500 Subject: SIMD: array to vector, vector to array, wrapping int add also vectors and arrays now use the same ConstExprVal representation See #903 --- src/ir.cpp | 291 ++++++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 180 insertions(+), 111 deletions(-) (limited to 'src/ir.cpp') diff --git a/src/ir.cpp b/src/ir.cpp index 3d3c501df3..3cbbdc8103 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -168,6 +168,7 @@ static IrInstruction *ir_analyze_ptr_cast(IrAnalyze *ira, IrInstruction *source_ static ConstExprValue *ir_resolve_const(IrAnalyze *ira, IrInstruction *value, UndefAllowed undef_allowed); static void copy_const_val(ConstExprValue *dest, ConstExprValue *src, bool same_global_refs); static Error resolve_ptr_align(IrAnalyze *ira, ZigType *ty, uint32_t *result_align); +static void ir_add_alloca(IrAnalyze *ira, IrInstruction *instruction, ZigType *type_entry); static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *const_val) { assert(get_src_ptr_type(const_val->type) != nullptr); @@ -899,6 +900,14 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionCheckRuntimeScop return IrInstructionIdCheckRuntimeScope; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionVectorToArray *) { + return IrInstructionIdVectorToArray; +} + +static constexpr IrInstructionId ir_instruction_id(IrInstructionArrayToVector *) { + return IrInstructionIdArrayToVector; +} + template static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate(1); @@ -2821,6 +2830,34 @@ static IrInstruction *ir_build_check_runtime_scope(IrBuilder *irb, Scope *scope, return &instruction->base; } +static IrInstruction *ir_build_vector_to_array(IrAnalyze *ira, IrInstruction *source_instruction, + IrInstruction *vector, ZigType *result_type) +{ + IrInstructionVectorToArray *instruction = ir_build_instruction(&ira->new_irb, + source_instruction->scope, source_instruction->source_node); + instruction->base.value.type = result_type; + instruction->vector = vector; + + ir_ref_instruction(vector, ira->new_irb.current_basic_block); + + ir_add_alloca(ira, &instruction->base, result_type); + + return &instruction->base; +} + +static IrInstruction *ir_build_array_to_vector(IrAnalyze *ira, IrInstruction *source_instruction, + IrInstruction *array, ZigType *result_type) +{ + IrInstructionArrayToVector *instruction = ir_build_instruction(&ira->new_irb, + source_instruction->scope, source_instruction->source_node); + instruction->base.value.type = result_type; + instruction->array = array; + + ir_ref_instruction(array, ira->new_irb.current_basic_block); + + return &instruction->base; +} + static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) { results[ReturnKindUnconditional] = 0; results[ReturnKindError] = 0; @@ -8270,6 +8307,7 @@ static bool ir_num_lit_fits_in_other_type(IrAnalyze *ira, IrInstruction *instruc bool const_val_is_int = (const_val->type->id == ZigTypeIdInt || const_val->type->id == ZigTypeIdComptimeInt); bool const_val_is_float = (const_val->type->id == ZigTypeIdFloat || const_val->type->id == ZigTypeIdComptimeFloat); + assert(const_val_is_int || const_val_is_float); if (other_type->id == ZigTypeIdFloat) { if (const_val->type->id == ZigTypeIdComptimeInt || const_val->type->id == ZigTypeIdComptimeFloat) { @@ -10714,6 +10752,32 @@ static void report_recursive_error(IrAnalyze *ira, AstNode *source_node, ConstCa } } +static IrInstruction *ir_analyze_array_to_vector(IrAnalyze *ira, IrInstruction *source_instr, + IrInstruction *array, ZigType *vector_type) +{ + if (instr_is_comptime(array)) { + // arrays and vectors have the same ConstExprValue representation + IrInstruction *result = ir_const(ira, source_instr, vector_type); + copy_const_val(&result->value, &array->value, false); + result->value.type = vector_type; + return result; + } + return ir_build_array_to_vector(ira, source_instr, array, vector_type); +} + +static IrInstruction *ir_analyze_vector_to_array(IrAnalyze *ira, IrInstruction *source_instr, + IrInstruction *vector, ZigType *array_type) +{ + if (instr_is_comptime(vector)) { + // arrays and vectors have the same ConstExprValue representation + IrInstruction *result = ir_const(ira, source_instr, array_type); + copy_const_val(&result->value, &vector->value, false); + result->value.type = array_type; + return result; + } + return ir_build_vector_to_array(ira, source_instr, vector, array_type); +} + static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_instr, ZigType *wanted_type, IrInstruction *value) { @@ -11102,6 +11166,23 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst } } + // cast from @Vector(N, T) to [N]T + if (wanted_type->id == ZigTypeIdArray && actual_type->id == ZigTypeIdVector && + wanted_type->data.array.len == actual_type->data.vector.len && + types_match_const_cast_only(ira, wanted_type->data.array.child_type, + actual_type->data.vector.elem_type, source_node, false).id == ConstCastResultIdOk) + { + return ir_analyze_vector_to_array(ira, source_instr, value, wanted_type); + } + + // cast from [N]T to @Vector(N, T) + if (actual_type->id == ZigTypeIdArray && wanted_type->id == ZigTypeIdVector && + actual_type->data.array.len == wanted_type->data.vector.len && + types_match_const_cast_only(ira, actual_type->data.array.child_type, + wanted_type->data.vector.elem_type, source_node, false).id == ConstCastResultIdOk) + { + return ir_analyze_array_to_vector(ira, source_instr, value, wanted_type); + } // cast from undefined to anything if (actual_type->id == ZigTypeIdUndefined) { @@ -11780,8 +11861,8 @@ static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp * return result; } -static int ir_eval_math_op(ZigType *type_entry, ConstExprValue *op1_val, - IrBinOp op_id, ConstExprValue *op2_val, ConstExprValue *out_val) +static ErrorMsg *ir_eval_math_op_scalar(IrAnalyze *ira, IrInstruction *source_instr, ZigType *type_entry, + ConstExprValue *op1_val, IrBinOp op_id, ConstExprValue *op2_val, ConstExprValue *out_val) { bool is_int; bool is_float; @@ -11803,10 +11884,10 @@ static int ir_eval_math_op(ZigType *type_entry, ConstExprValue *op1_val, if ((op_id == IrBinOpDivUnspecified || op_id == IrBinOpRemRem || op_id == IrBinOpRemMod || op_id == IrBinOpDivTrunc || op_id == IrBinOpDivFloor) && op2_zcmp == CmpEQ) { - return ErrorDivByZero; + return ir_add_error(ira, source_instr, buf_sprintf("division by zero")); } if ((op_id == IrBinOpRemRem || op_id == IrBinOpRemMod) && op2_zcmp == CmpLT) { - return ErrorNegativeDenominator; + return ir_add_error(ira, source_instr, buf_sprintf("negative denominator")); } switch (op_id) { @@ -11852,7 +11933,7 @@ static int ir_eval_math_op(ZigType *type_entry, ConstExprValue *op1_val, BigInt orig_bigint; bigint_shl(&orig_bigint, &out_val->data.x_bigint, &op2_val->data.x_bigint); if (bigint_cmp(&op1_val->data.x_bigint, &orig_bigint) != CmpEQ) { - return ErrorShiftedOutOneBits; + return ir_add_error(ira, source_instr, buf_sprintf("exact shift shifted out 1 bits")); } break; } @@ -11920,14 +12001,14 @@ static int ir_eval_math_op(ZigType *type_entry, ConstExprValue *op1_val, BigInt remainder; bigint_rem(&remainder, &op1_val->data.x_bigint, &op2_val->data.x_bigint); if (bigint_cmp_zero(&remainder) != CmpEQ) { - return ErrorExactDivRemainder; + return ir_add_error(ira, source_instr, buf_sprintf("exact division had a remainder")); } } else { float_div_trunc(out_val, op1_val, op2_val); ConstExprValue remainder; float_rem(&remainder, op1_val, op2_val); if (float_cmp_zero(&remainder) != CmpEQ) { - return ErrorExactDivRemainder; + return ir_add_error(ira, source_instr, buf_sprintf("exact division had a remainder")); } } break; @@ -11951,13 +12032,51 @@ static int ir_eval_math_op(ZigType *type_entry, ConstExprValue *op1_val, if (!bigint_fits_in_bits(&out_val->data.x_bigint, type_entry->data.integral.bit_count, type_entry->data.integral.is_signed)) { - return ErrorOverflow; + return ir_add_error(ira, source_instr, buf_sprintf("operation caused overflow")); } } out_val->type = type_entry; out_val->special = ConstValSpecialStatic; - return 0; + return nullptr; +} + +// This works on operands that have already been checked to be comptime known. +static IrInstruction *ir_analyze_math_op(IrAnalyze *ira, IrInstruction *source_instr, + ZigType *type_entry, ConstExprValue *op1_val, IrBinOp op_id, ConstExprValue *op2_val) +{ + IrInstruction *result_instruction = ir_const(ira, source_instr, type_entry); + ConstExprValue *out_val = &result_instruction->value; + if (type_entry->id == ZigTypeIdVector) { + expand_undef_array(ira->codegen, op1_val); + expand_undef_array(ira->codegen, op2_val); + out_val->special = ConstValSpecialUndef; + expand_undef_array(ira->codegen, out_val); + size_t len = type_entry->data.vector.len; + ZigType *scalar_type = type_entry->data.vector.elem_type; + for (size_t i = 0; i < len; i += 1) { + ConstExprValue *scalar_op1_val = &op1_val->data.x_array.data.s_none.elements[i]; + ConstExprValue *scalar_op2_val = &op2_val->data.x_array.data.s_none.elements[i]; + ConstExprValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i]; + assert(scalar_op1_val->type == scalar_type); + assert(scalar_op2_val->type == scalar_type); + assert(scalar_out_val->type == scalar_type); + ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type, + scalar_op1_val, op_id, scalar_op2_val, scalar_out_val); + if (msg != nullptr) { + add_error_note(ira->codegen, msg, source_instr->source_node, + buf_sprintf("when computing vector element at index %" ZIG_PRI_usize, i)); + return ira->codegen->invalid_instruction; + } + } + out_val->type = type_entry; + out_val->special = ConstValSpecialStatic; + } else { + if (ir_eval_math_op_scalar(ira, source_instr, type_entry, op1_val, op_id, op2_val, out_val) != nullptr) { + return ira->codegen->invalid_instruction; + } + } + return ir_implicit_cast(ira, result_instruction, type_entry); } static IrInstruction *ir_analyze_bit_shift(IrAnalyze *ira, IrInstructionBinOp *bin_op_instruction) { @@ -12029,24 +12148,7 @@ static IrInstruction *ir_analyze_bit_shift(IrAnalyze *ira, IrInstructionBinOp *b if (op2_val == nullptr) return ira->codegen->invalid_instruction; - IrInstruction *result_instruction = ir_const(ira, &bin_op_instruction->base, op1->value.type); - - int err; - if ((err = ir_eval_math_op(op1->value.type, op1_val, op_id, op2_val, &result_instruction->value))) { - if (err == ErrorOverflow) { - ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("operation caused overflow")); - return ira->codegen->invalid_instruction; - } else if (err == ErrorShiftedOutOneBits) { - ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("exact shift shifted out 1 bits")); - return ira->codegen->invalid_instruction; - } else { - zig_unreachable(); - } - return ira->codegen->invalid_instruction; - } - - ir_num_lit_fits_in_other_type(ira, result_instruction, op1->value.type, false); - return result_instruction; + return ir_analyze_math_op(ira, &bin_op_instruction->base, op1->value.type, op1_val, op_id, op2_val); } else if (op1->value.type->id == ZigTypeIdComptimeInt) { ir_add_error(ira, &bin_op_instruction->base, buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known")); @@ -12292,30 +12394,7 @@ static IrInstruction *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp if (op2_val == nullptr) return ira->codegen->invalid_instruction; - IrInstruction *result_instruction = ir_const(ira, &instruction->base, resolved_type); - - int err; - if ((err = ir_eval_math_op(resolved_type, op1_val, op_id, op2_val, &result_instruction->value))) { - if (err == ErrorDivByZero) { - ir_add_error(ira, &instruction->base, buf_sprintf("division by zero")); - return ira->codegen->invalid_instruction; - } else if (err == ErrorOverflow) { - ir_add_error(ira, &instruction->base, buf_sprintf("operation caused overflow")); - return ira->codegen->invalid_instruction; - } else if (err == ErrorExactDivRemainder) { - ir_add_error(ira, &instruction->base, buf_sprintf("exact division had a remainder")); - return ira->codegen->invalid_instruction; - } else if (err == ErrorNegativeDenominator) { - ir_add_error(ira, &instruction->base, buf_sprintf("negative denominator")); - return ira->codegen->invalid_instruction; - } else { - zig_unreachable(); - } - return ira->codegen->invalid_instruction; - } - - ir_num_lit_fits_in_other_type(ira, result_instruction, resolved_type, false); - return result_instruction; + return ir_analyze_math_op(ira, &instruction->base, resolved_type, op1_val, op_id, op2_val); } IrInstruction *result = ir_build_bin_op(&ira->new_irb, instruction->base.scope, @@ -18745,10 +18824,7 @@ static IrInstruction *ir_analyze_instruction_vector_type(IrAnalyze *ira, IrInstr if (type_is_invalid(elem_type)) return ira->codegen->invalid_instruction; - if (elem_type->id != ZigTypeIdInt && - elem_type->id != ZigTypeIdFloat && - get_codegen_ptr_type(elem_type) == nullptr) - { + if (!is_valid_vector_elem_type(elem_type)) { ir_add_error(ira, instruction->elem_type, buf_sprintf("vector element type must be integer, float, or pointer; '%s' is invalid", buf_ptr(&elem_type->name))); @@ -20345,6 +20421,17 @@ static IrInstruction *ir_analyze_instruction_ptr_cast(IrAnalyze *ira, IrInstruct return ir_analyze_ptr_cast(ira, &instruction->base, ptr, dest_type, dest_type_value); } +static void buf_write_value_bytes_array(CodeGen *codegen, uint8_t *buf, ConstExprValue *val, size_t len) { + size_t buf_i = 0; + // TODO optimize the buf case + expand_undef_array(codegen, val); + for (size_t elem_i = 0; elem_i < val->type->data.array.len; elem_i += 1) { + ConstExprValue *elem = &val->data.x_array.data.s_none.elements[elem_i]; + buf_write_value_bytes(codegen, &buf[buf_i], elem); + buf_i += type_size(codegen, elem->type); + } +} + static void buf_write_value_bytes(CodeGen *codegen, uint8_t *buf, ConstExprValue *val) { if (val->special == ConstValSpecialUndef) val->special = ConstValSpecialStatic; @@ -20390,26 +20477,9 @@ static void buf_write_value_bytes(CodeGen *codegen, uint8_t *buf, ConstExprValue zig_unreachable(); } case ZigTypeIdArray: - { - size_t buf_i = 0; - // TODO optimize the buf case - expand_undef_array(codegen, val); - for (size_t elem_i = 0; elem_i < val->type->data.array.len; elem_i += 1) { - ConstExprValue *elem = &val->data.x_array.data.s_none.elements[elem_i]; - buf_write_value_bytes(codegen, &buf[buf_i], elem); - buf_i += type_size(codegen, elem->type); - } - } - return; - case ZigTypeIdVector: { - size_t buf_i = 0; - for (uint32_t elem_i = 0; elem_i < val->type->data.vector.len; elem_i += 1) { - ConstExprValue *elem = &val->data.x_vector.elements[elem_i]; - buf_write_value_bytes(codegen, &buf[buf_i], elem); - buf_i += type_size(codegen, elem->type); - } - return; - } + return buf_write_value_bytes_array(codegen, buf, val, val->type->data.array.len); + case ZigTypeIdVector: + return buf_write_value_bytes_array(codegen, buf, val, val->type->data.vector.len); case ZigTypeIdStruct: zig_panic("TODO buf_write_value_bytes struct type"); case ZigTypeIdOptional: @@ -20426,6 +20496,31 @@ static void buf_write_value_bytes(CodeGen *codegen, uint8_t *buf, ConstExprValue zig_unreachable(); } +static Error buf_read_value_bytes_array(IrAnalyze *ira, CodeGen *codegen, AstNode *source_node, uint8_t *buf, + ConstExprValue *val, ZigType *elem_type, size_t len) +{ + Error err; + uint64_t elem_size = type_size(codegen, elem_type); + + switch (val->data.x_array.special) { + case ConstArraySpecialNone: + val->data.x_array.data.s_none.elements = create_const_vals(len); + for (size_t i = 0; i < len; i++) { + ConstExprValue *elem = &val->data.x_array.data.s_none.elements[i]; + elem->special = ConstValSpecialStatic; + elem->type = elem_type; + if ((err = buf_read_value_bytes(ira, codegen, source_node, buf + (elem_size * i), elem))) + return err; + } + return ErrorNone; + case ConstArraySpecialUndef: + zig_panic("TODO buf_read_value_bytes ConstArraySpecialUndef array type"); + case ConstArraySpecialBuf: + zig_panic("TODO buf_read_value_bytes ConstArraySpecialBuf array type"); + } + zig_unreachable(); +} + static Error buf_read_value_bytes(IrAnalyze *ira, CodeGen *codegen, AstNode *source_node, uint8_t *buf, ConstExprValue *val) { Error err; assert(val->special == ConstValSpecialStatic); @@ -20464,42 +20559,12 @@ static Error buf_read_value_bytes(IrAnalyze *ira, CodeGen *codegen, AstNode *sou val->data.x_ptr.data.hard_coded_addr.addr = bigint_as_unsigned(&bn); return ErrorNone; } - case ZigTypeIdArray: { - uint64_t elem_size = type_size(codegen, val->type->data.array.child_type); - size_t len = val->type->data.array.len; - - switch (val->data.x_array.special) { - case ConstArraySpecialNone: - val->data.x_array.data.s_none.elements = create_const_vals(len); - for (size_t i = 0; i < len; i++) { - ConstExprValue *elem = &val->data.x_array.data.s_none.elements[i]; - elem->special = ConstValSpecialStatic; - elem->type = val->type->data.array.child_type; - if ((err = buf_read_value_bytes(ira, codegen, source_node, buf + (elem_size * i), elem))) - return err; - } - return ErrorNone; - case ConstArraySpecialUndef: - zig_panic("TODO buf_read_value_bytes ConstArraySpecialUndef array type"); - case ConstArraySpecialBuf: - zig_panic("TODO buf_read_value_bytes ConstArraySpecialBuf array type"); - } - zig_unreachable(); - } - case ZigTypeIdVector: { - uint64_t elem_size = type_size(codegen, val->type->data.vector.elem_type); - uint32_t len = val->type->data.vector.len; - - val->data.x_vector.elements = create_const_vals(len); - for (uint32_t i = 0; i < len; i += 1) { - ConstExprValue *elem = &val->data.x_vector.elements[i]; - elem->special = ConstValSpecialStatic; - elem->type = val->type->data.vector.elem_type; - if ((err = buf_read_value_bytes(ira, codegen, source_node, buf + (elem_size * i), elem))) - return err; - } - return ErrorNone; - } + case ZigTypeIdArray: + return buf_read_value_bytes_array(ira, codegen, source_node, buf, val, val->type->data.array.child_type, + val->type->data.array.len); + case ZigTypeIdVector: + return buf_read_value_bytes_array(ira, codegen, source_node, buf, val, val->type->data.vector.elem_type, + val->type->data.vector.len); case ZigTypeIdEnum: switch (val->type->data.enumeration.layout) { case ContainerLayoutAuto: @@ -21634,6 +21699,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio case IrInstructionIdDeclVarGen: case IrInstructionIdPtrCastGen: case IrInstructionIdCmpxchgGen: + case IrInstructionIdArrayToVector: + case IrInstructionIdVectorToArray: zig_unreachable(); case IrInstructionIdReturn: @@ -22129,6 +22196,8 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdFromBytes: case IrInstructionIdToBytes: case IrInstructionIdEnumToInt: + case IrInstructionIdVectorToArray: + case IrInstructionIdArrayToVector: return false; case IrInstructionIdAsm: -- cgit v1.2.3