aboutsummaryrefslogtreecommitdiff
path: root/src/ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/ir.cpp')
-rw-r--r--src/ir.cpp526
1 files changed, 465 insertions, 61 deletions
diff --git a/src/ir.cpp b/src/ir.cpp
index b74a99b37d..5a4a53b804 100644
--- a/src/ir.cpp
+++ b/src/ir.cpp
@@ -747,6 +747,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestErr *) {
return IrInstructionIdTestErr;
}
+static constexpr IrInstructionId ir_instruction_id(IrInstructionMulAdd *) {
+ return IrInstructionIdMulAdd;
+}
+
static constexpr IrInstructionId ir_instruction_id(IrInstructionUnwrapErrCode *) {
return IrInstructionIdUnwrapErrCode;
}
@@ -987,8 +991,8 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionMarkErrRetTraceP
return IrInstructionIdMarkErrRetTracePtr;
}
-static constexpr IrInstructionId ir_instruction_id(IrInstructionSqrt *) {
- return IrInstructionIdSqrt;
+static constexpr IrInstructionId ir_instruction_id(IrInstructionFloatOp *) {
+ return IrInstructionIdFloatOp;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionCheckRuntimeScope *) {
@@ -2308,6 +2312,75 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode
return &instruction->base;
}
+
+//TODO Powi, Pow, minnum, maxnum, maximum, minimum, copysign,
+// lround, llround, lrint, llrint
+// So far this is only non-complicated type functions.
+const char *float_op_to_name(BuiltinFnId op, bool llvm_name) {
+ const bool b = llvm_name;
+
+ switch (op) {
+ case BuiltinFnIdSqrt:
+ return "sqrt";
+ case BuiltinFnIdSin:
+ return "sin";
+ case BuiltinFnIdCos:
+ return "cos";
+ case BuiltinFnIdExp:
+ return "exp";
+ case BuiltinFnIdExp2:
+ return "exp2";
+ case BuiltinFnIdLn:
+ return b ? "log" : "ln";
+ case BuiltinFnIdLog10:
+ return "log10";
+ case BuiltinFnIdLog2:
+ return "log2";
+ case BuiltinFnIdFabs:
+ return "fabs";
+ case BuiltinFnIdFloor:
+ return "floor";
+ case BuiltinFnIdCeil:
+ return "ceil";
+ case BuiltinFnIdTrunc:
+ return "trunc";
+ case BuiltinFnIdNearbyInt:
+ return b ? "nearbyint" : "nearbyInt";
+ case BuiltinFnIdRound:
+ return "round";
+ default:
+ zig_unreachable();
+ }
+}
+
+static IrInstruction *ir_build_float_op(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op1, BuiltinFnId op) {
+ IrInstructionFloatOp *instruction = ir_build_instruction<IrInstructionFloatOp>(irb, scope, source_node);
+ instruction->type = type;
+ instruction->op1 = op1;
+ instruction->op = op;
+
+ if (type != nullptr) ir_ref_instruction(type, irb->current_basic_block);
+ ir_ref_instruction(op1, irb->current_basic_block);
+
+ return &instruction->base;
+}
+
+static IrInstruction *ir_build_mul_add(IrBuilder *irb, Scope *scope, AstNode *source_node,
+ IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *op3) {
+ IrInstructionMulAdd *instruction = ir_build_instruction<IrInstructionMulAdd>(irb, scope, source_node);
+ instruction->type_value = type_value;
+ instruction->op1 = op1;
+ instruction->op2 = op2;
+ instruction->op3 = op3;
+
+ ir_ref_instruction(type_value, irb->current_basic_block);
+ ir_ref_instruction(op1, irb->current_basic_block);
+ ir_ref_instruction(op2, irb->current_basic_block);
+ ir_ref_instruction(op3, irb->current_basic_block);
+
+ return &instruction->base;
+}
+
static IrInstruction *ir_build_align_of(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type_value) {
IrInstructionAlignOf *instruction = ir_build_instruction<IrInstructionAlignOf>(irb, scope, source_node);
instruction->type_value = type_value;
@@ -3013,17 +3086,6 @@ static IrInstruction *ir_build_mark_err_ret_trace_ptr(IrBuilder *irb, Scope *sco
return &instruction->base;
}
-static IrInstruction *ir_build_sqrt(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op) {
- IrInstructionSqrt *instruction = ir_build_instruction<IrInstructionSqrt>(irb, scope, source_node);
- instruction->type = type;
- instruction->op = op;
-
- if (type != nullptr) ir_ref_instruction(type, irb->current_basic_block);
- ir_ref_instruction(op, irb->current_basic_block);
-
- return &instruction->base;
-}
-
static IrInstruction *ir_build_has_decl(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *container, IrInstruction *name)
{
@@ -4028,6 +4090,33 @@ static IrInstruction *ir_gen_overflow_op(IrBuilder *irb, Scope *scope, AstNode *
return ir_build_overflow_op(irb, scope, node, op, type_value, op1, op2, result_ptr, nullptr);
}
+static IrInstruction *ir_gen_mul_add(IrBuilder *irb, Scope *scope, AstNode *node) {
+ assert(node->type == NodeTypeFnCallExpr);
+
+ AstNode *type_node = node->data.fn_call_expr.params.at(0);
+ AstNode *op1_node = node->data.fn_call_expr.params.at(1);
+ AstNode *op2_node = node->data.fn_call_expr.params.at(2);
+ AstNode *op3_node = node->data.fn_call_expr.params.at(3);
+
+ IrInstruction *type_value = ir_gen_node(irb, type_node, scope);
+ if (type_value == irb->codegen->invalid_instruction)
+ return irb->codegen->invalid_instruction;
+
+ IrInstruction *op1 = ir_gen_node(irb, op1_node, scope);
+ if (op1 == irb->codegen->invalid_instruction)
+ return irb->codegen->invalid_instruction;
+
+ IrInstruction *op2 = ir_gen_node(irb, op2_node, scope);
+ if (op2 == irb->codegen->invalid_instruction)
+ return irb->codegen->invalid_instruction;
+
+ IrInstruction *op3 = ir_gen_node(irb, op3_node, scope);
+ if (op3 == irb->codegen->invalid_instruction)
+ return irb->codegen->invalid_instruction;
+
+ return ir_build_mul_add(irb, scope, node, type_value, op1, op2, op3);
+}
+
static IrInstruction *ir_gen_this(IrBuilder *irb, Scope *orig_scope, AstNode *node) {
for (Scope *it_scope = orig_scope; it_scope != nullptr; it_scope = it_scope->parent) {
if (it_scope->id == ScopeIdDecls) {
@@ -4353,6 +4442,19 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
return ir_lval_wrap(irb, scope, bin_op, lval);
}
case BuiltinFnIdSqrt:
+ case BuiltinFnIdSin:
+ case BuiltinFnIdCos:
+ case BuiltinFnIdExp:
+ case BuiltinFnIdExp2:
+ case BuiltinFnIdLn:
+ case BuiltinFnIdLog2:
+ case BuiltinFnIdLog10:
+ case BuiltinFnIdFabs:
+ case BuiltinFnIdFloor:
+ case BuiltinFnIdCeil:
+ case BuiltinFnIdTrunc:
+ case BuiltinFnIdNearbyInt:
+ case BuiltinFnIdRound:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
@@ -4364,7 +4466,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
if (arg1_value == irb->codegen->invalid_instruction)
return arg1_value;
- IrInstruction *ir_sqrt = ir_build_sqrt(irb, scope, node, arg0_value, arg1_value);
+ IrInstruction *ir_sqrt = ir_build_float_op(irb, scope, node, arg0_value, arg1_value, builtin_fn->id);
return ir_lval_wrap(irb, scope, ir_sqrt, lval);
}
case BuiltinFnIdTruncate:
@@ -4687,6 +4789,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
return ir_lval_wrap(irb, scope, ir_gen_overflow_op(irb, scope, node, IrOverflowOpMul), lval);
case BuiltinFnIdShlWithOverflow:
return ir_lval_wrap(irb, scope, ir_gen_overflow_op(irb, scope, node, IrOverflowOpShl), lval);
+ case BuiltinFnIdMulAdd:
+ return ir_lval_wrap(irb, scope, ir_gen_mul_add(irb, scope, node), lval);
case BuiltinFnIdTypeName:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -21187,6 +21291,125 @@ static IrInstruction *ir_analyze_instruction_overflow_op(IrAnalyze *ira, IrInstr
return result;
}
+static void ir_eval_mul_add(IrAnalyze *ira, IrInstructionMulAdd *source_instr, ZigType *float_type,
+ ConstExprValue *op1, ConstExprValue *op2, ConstExprValue *op3, ConstExprValue *out_val) {
+ if (float_type->id == ZigTypeIdComptimeFloat) {
+ f128M_mulAdd(&out_val->data.x_bigfloat.value, &op1->data.x_bigfloat.value, &op2->data.x_bigfloat.value,
+ &op3->data.x_bigfloat.value);
+ } else if (float_type->id == ZigTypeIdFloat) {
+ switch (float_type->data.floating.bit_count) {
+ case 16:
+ out_val->data.x_f16 = f16_mulAdd(op1->data.x_f16, op2->data.x_f16, op3->data.x_f16);
+ break;
+ case 32:
+ out_val->data.x_f32 = fmaf(op1->data.x_f32, op2->data.x_f32, op3->data.x_f32);
+ break;
+ case 64:
+ out_val->data.x_f64 = fma(op1->data.x_f64, op2->data.x_f64, op3->data.x_f64);
+ break;
+ case 128:
+ f128M_mulAdd(&op1->data.x_f128, &op2->data.x_f128, &op3->data.x_f128, &out_val->data.x_f128);
+ break;
+ default:
+ zig_unreachable();
+ }
+ } else {
+ zig_unreachable();
+ }
+}
+
+static IrInstruction *ir_analyze_instruction_mul_add(IrAnalyze *ira, IrInstructionMulAdd *instruction) {
+ IrInstruction *type_value = instruction->type_value->child;
+ if (type_is_invalid(type_value->value.type))
+ return ira->codegen->invalid_instruction;
+
+ ZigType *expr_type = ir_resolve_type(ira, type_value);
+ if (type_is_invalid(expr_type))
+ return ira->codegen->invalid_instruction;
+
+ // Only allow float types, and vectors of floats.
+ ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
+ if (float_type->id != ZigTypeIdFloat) {
+ ir_add_error(ira, type_value,
+ buf_sprintf("expected float or vector of float type, found '%s'", buf_ptr(&float_type->name)));
+ return ira->codegen->invalid_instruction;
+ }
+
+ IrInstruction *op1 = instruction->op1->child;
+ if (type_is_invalid(op1->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, expr_type);
+ if (type_is_invalid(casted_op1->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *op2 = instruction->op2->child;
+ if (type_is_invalid(op2->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *casted_op2 = ir_implicit_cast(ira, op2, expr_type);
+ if (type_is_invalid(casted_op2->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *op3 = instruction->op3->child;
+ if (type_is_invalid(op3->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *casted_op3 = ir_implicit_cast(ira, op3, expr_type);
+ if (type_is_invalid(casted_op3->value.type))
+ return ira->codegen->invalid_instruction;
+
+ if (instr_is_comptime(casted_op1) &&
+ instr_is_comptime(casted_op2) &&
+ instr_is_comptime(casted_op3)) {
+ ConstExprValue *op1_const = ir_resolve_const(ira, casted_op1, UndefBad);
+ if (!op1_const)
+ return ira->codegen->invalid_instruction;
+ ConstExprValue *op2_const = ir_resolve_const(ira, casted_op2, UndefBad);
+ if (!op2_const)
+ return ira->codegen->invalid_instruction;
+ ConstExprValue *op3_const = ir_resolve_const(ira, casted_op3, UndefBad);
+ if (!op3_const)
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *result = ir_const(ira, &instruction->base, expr_type);
+ ConstExprValue *out_val = &result->value;
+
+ if (expr_type->id == ZigTypeIdVector) {
+ expand_undef_array(ira->codegen, op1_const);
+ expand_undef_array(ira->codegen, op2_const);
+ expand_undef_array(ira->codegen, op3_const);
+ out_val->special = ConstValSpecialUndef;
+ expand_undef_array(ira->codegen, out_val);
+ size_t len = expr_type->data.vector.len;
+ for (size_t i = 0; i < len; i += 1) {
+ ConstExprValue *float_operand_op1 = &op1_const->data.x_array.data.s_none.elements[i];
+ ConstExprValue *float_operand_op2 = &op2_const->data.x_array.data.s_none.elements[i];
+ ConstExprValue *float_operand_op3 = &op3_const->data.x_array.data.s_none.elements[i];
+ ConstExprValue *float_out_val = &out_val->data.x_array.data.s_none.elements[i];
+ assert(float_operand_op1->type == float_type);
+ assert(float_operand_op2->type == float_type);
+ assert(float_operand_op3->type == float_type);
+ assert(float_out_val->type == float_type);
+ ir_eval_mul_add(ira, instruction, float_type,
+ op1_const, op2_const, op3_const, float_out_val);
+ float_out_val->type = float_type;
+ }
+ out_val->type = expr_type;
+ out_val->special = ConstValSpecialStatic;
+ } else {
+ ir_eval_mul_add(ira, instruction, float_type, op1_const, op2_const, op3_const, out_val);
+ }
+ return result;
+ }
+
+ IrInstruction *result = ir_build_mul_add(&ira->new_irb,
+ instruction->base.scope, instruction->base.source_node,
+ type_value, casted_op1, casted_op2, casted_op3);
+ result->value.type = expr_type;
+ return result;
+}
+
static IrInstruction *ir_analyze_instruction_test_err(IrAnalyze *ira, IrInstructionTestErr *instruction) {
IrInstruction *value = instruction->value->child;
if (type_is_invalid(value->value.type))
@@ -23048,70 +23271,248 @@ static IrInstruction *ir_analyze_instruction_mark_err_ret_trace_ptr(IrAnalyze *i
return result;
}
-static IrInstruction *ir_analyze_instruction_sqrt(IrAnalyze *ira, IrInstructionSqrt *instruction) {
- ZigType *float_type = ir_resolve_type(ira, instruction->type->child);
- if (type_is_invalid(float_type))
- return ira->codegen->invalid_instruction;
+static void ir_eval_float_op(IrAnalyze *ira, IrInstructionFloatOp *source_instr, ZigType *float_type,
+ ConstExprValue *op, ConstExprValue *out_val) {
+ assert(ira && source_instr && float_type && out_val && op);
+ assert(float_type->id == ZigTypeIdFloat ||
+ float_type->id == ZigTypeIdComptimeFloat);
- IrInstruction *op = instruction->op->child;
- if (type_is_invalid(op->value.type))
+ BuiltinFnId fop = source_instr->op;
+ unsigned bits;
+
+ if (float_type->id == ZigTypeIdComptimeFloat) {
+ bits = 128;
+ } else if (float_type->id == ZigTypeIdFloat)
+ bits = float_type->data.floating.bit_count;
+
+ switch (bits) {
+ case 16: {
+ switch (fop) {
+ case BuiltinFnIdSqrt:
+ out_val->data.x_f16 = f16_sqrt(op->data.x_f16);
+ break;
+ case BuiltinFnIdSin:
+ case BuiltinFnIdCos:
+ case BuiltinFnIdExp:
+ case BuiltinFnIdExp2:
+ case BuiltinFnIdLn:
+ case BuiltinFnIdLog10:
+ case BuiltinFnIdLog2:
+ case BuiltinFnIdFabs:
+ case BuiltinFnIdFloor:
+ case BuiltinFnIdCeil:
+ case BuiltinFnIdTrunc:
+ case BuiltinFnIdNearbyInt:
+ case BuiltinFnIdRound:
+ zig_panic("unimplemented f16 builtin");
+ default:
+ zig_unreachable();
+ };
+ break;
+ };
+ case 32: {
+ switch (fop) {
+ case BuiltinFnIdSqrt:
+ out_val->data.x_f32 = sqrtf(op->data.x_f32);
+ break;
+ case BuiltinFnIdSin:
+ out_val->data.x_f32 = sinf(op->data.x_f32);
+ break;
+ case BuiltinFnIdCos:
+ out_val->data.x_f32 = cosf(op->data.x_f32);
+ break;
+ case BuiltinFnIdExp:
+ out_val->data.x_f32 = expf(op->data.x_f32);
+ break;
+ case BuiltinFnIdExp2:
+ out_val->data.x_f32 = exp2f(op->data.x_f32);
+ break;
+ case BuiltinFnIdLn:
+ out_val->data.x_f32 = logf(op->data.x_f32);
+ break;
+ case BuiltinFnIdLog10:
+ out_val->data.x_f32 = log10f(op->data.x_f32);
+ break;
+ case BuiltinFnIdLog2:
+ out_val->data.x_f32 = log2f(op->data.x_f32);
+ break;
+ case BuiltinFnIdFabs:
+ out_val->data.x_f32 = fabsf(op->data.x_f32);
+ break;
+ case BuiltinFnIdFloor:
+ out_val->data.x_f32 = floorf(op->data.x_f32);
+ break;
+ case BuiltinFnIdCeil:
+ out_val->data.x_f32 = ceilf(op->data.x_f32);
+ break;
+ case BuiltinFnIdTrunc:
+ out_val->data.x_f32 = truncf(op->data.x_f32);
+ break;
+ case BuiltinFnIdNearbyInt:
+ out_val->data.x_f32 = nearbyintf(op->data.x_f32);
+ break;
+ case BuiltinFnIdRound:
+ out_val->data.x_f32 = roundf(op->data.x_f32);
+ break;
+ default:
+ zig_unreachable();
+ };
+ break;
+ };
+ case 64: {
+ switch (fop) {
+ case BuiltinFnIdSqrt:
+ out_val->data.x_f64 = sqrt(op->data.x_f64);
+ break;
+ case BuiltinFnIdSin:
+ out_val->data.x_f64 = sin(op->data.x_f64);
+ break;
+ case BuiltinFnIdCos:
+ out_val->data.x_f64 = cos(op->data.x_f64);
+ break;
+ case BuiltinFnIdExp:
+ out_val->data.x_f64 = exp(op->data.x_f64);
+ break;
+ case BuiltinFnIdExp2:
+ out_val->data.x_f64 = exp2(op->data.x_f64);
+ break;
+ case BuiltinFnIdLn:
+ out_val->data.x_f64 = log(op->data.x_f64);
+ break;
+ case BuiltinFnIdLog10:
+ out_val->data.x_f64 = log10(op->data.x_f64);
+ break;
+ case BuiltinFnIdLog2:
+ out_val->data.x_f64 = log2(op->data.x_f64);
+ break;
+ case BuiltinFnIdFabs:
+ out_val->data.x_f64 = fabs(op->data.x_f64);
+ break;
+ case BuiltinFnIdFloor:
+ out_val->data.x_f64 = floor(op->data.x_f64);
+ break;
+ case BuiltinFnIdCeil:
+ out_val->data.x_f64 = ceil(op->data.x_f64);
+ break;
+ case BuiltinFnIdTrunc:
+ out_val->data.x_f64 = trunc(op->data.x_f64);
+ break;
+ case BuiltinFnIdNearbyInt:
+ out_val->data.x_f64 = nearbyint(op->data.x_f64);
+ break;
+ case BuiltinFnIdRound:
+ out_val->data.x_f64 = round(op->data.x_f64);
+ break;
+ default:
+ zig_unreachable();
+ }
+ break;
+ };
+ case 128: {
+ float128_t *out, *in;
+ if (float_type->id == ZigTypeIdComptimeFloat) {
+ out = &out_val->data.x_bigfloat.value;
+ in = &op->data.x_bigfloat.value;
+ } else {
+ out = &out_val->data.x_f128;
+ in = &op->data.x_f128;
+ }
+ switch (fop) {
+ case BuiltinFnIdSqrt:
+ f128M_sqrt(in, out);
+ break;
+ case BuiltinFnIdNearbyInt:
+ case BuiltinFnIdSin:
+ case BuiltinFnIdCos:
+ case BuiltinFnIdExp:
+ case BuiltinFnIdExp2:
+ case BuiltinFnIdLn:
+ case BuiltinFnIdLog10:
+ case BuiltinFnIdLog2:
+ case BuiltinFnIdFabs:
+ case BuiltinFnIdFloor:
+ case BuiltinFnIdCeil:
+ case BuiltinFnIdTrunc:
+ case BuiltinFnIdRound:
+ zig_panic("unimplemented f128 builtin");
+ default:
+ zig_unreachable();
+ }
+ break;
+ };
+ default:
+ zig_unreachable();
+ }
+}
+
+static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstructionFloatOp *instruction) {
+ IrInstruction *type = instruction->type->child;
+ if (type_is_invalid(type->value.type))
+ return ira->codegen->invalid_instruction;
+
+ ZigType *expr_type = ir_resolve_type(ira, type);
+ if (type_is_invalid(expr_type))
return ira->codegen->invalid_instruction;
- bool ok_type = float_type->id == ZigTypeIdComptimeFloat || float_type->id == ZigTypeIdFloat;
- if (!ok_type) {
- ir_add_error(ira, instruction->type, buf_sprintf("@sqrt does not support type '%s'", buf_ptr(&float_type->name)));
+ // Only allow float types, and vectors of floats.
+ ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
+ if (float_type->id != ZigTypeIdFloat && float_type->id != ZigTypeIdComptimeFloat) {
+ ir_add_error(ira, instruction->type, buf_sprintf("@%s does not support type '%s'", float_op_to_name(instruction->op, false), buf_ptr(&float_type->name)));
return ira->codegen->invalid_instruction;
}
- IrInstruction *casted_op = ir_implicit_cast(ira, op, float_type);
- if (type_is_invalid(casted_op->value.type))
+ IrInstruction *op1 = instruction->op1->child;
+ if (type_is_invalid(op1->value.type))
return ira->codegen->invalid_instruction;
- if (instr_is_comptime(casted_op)) {
- ConstExprValue *val = ir_resolve_const(ira, casted_op, UndefBad);
- if (!val)
+ IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, float_type);
+ if (type_is_invalid(casted_op1->value.type))
+ return ira->codegen->invalid_instruction;
+
+ if (instr_is_comptime(casted_op1)) {
+ // Our comptime 16-bit and 128-bit support is quite limited.
+ if ((float_type->id == ZigTypeIdComptimeFloat ||
+ float_type->data.floating.bit_count == 16 ||
+ float_type->data.floating.bit_count == 128) &&
+ instruction->op != BuiltinFnIdSqrt) {
+ ir_add_error(ira, instruction->type, buf_sprintf("@%s does not support type '%s'", float_op_to_name(instruction->op, false), buf_ptr(&float_type->name)));
+ return ira->codegen->invalid_instruction;
+ }
+
+ ConstExprValue *op1_const = ir_resolve_const(ira, casted_op1, UndefBad);
+ if (!op1_const)
return ira->codegen->invalid_instruction;
- IrInstruction *result = ir_const(ira, &instruction->base, float_type);
+ IrInstruction *result = ir_const(ira, &instruction->base, expr_type);
ConstExprValue *out_val = &result->value;
- if (float_type->id == ZigTypeIdComptimeFloat) {
- bigfloat_sqrt(&out_val->data.x_bigfloat, &val->data.x_bigfloat);
- } else if (float_type->id == ZigTypeIdFloat) {
- switch (float_type->data.floating.bit_count) {
- case 16:
- out_val->data.x_f16 = f16_sqrt(val->data.x_f16);
- break;
- case 32:
- out_val->data.x_f32 = sqrtf(val->data.x_f32);
- break;
- case 64:
- out_val->data.x_f64 = sqrt(val->data.x_f64);
- break;
- case 128:
- f128M_sqrt(&val->data.x_f128, &out_val->data.x_f128);
- break;
- default:
- zig_unreachable();
+ if (expr_type->id == ZigTypeIdVector) {
+ expand_undef_array(ira->codegen, op1_const);
+ out_val->special = ConstValSpecialUndef;
+ expand_undef_array(ira->codegen, out_val);
+ size_t len = expr_type->data.vector.len;
+ for (size_t i = 0; i < len; i += 1) {
+ ConstExprValue *float_operand_op1 = &op1_const->data.x_array.data.s_none.elements[i];
+ ConstExprValue *float_out_val = &out_val->data.x_array.data.s_none.elements[i];
+ assert(float_operand_op1->type == float_type);
+ assert(float_out_val->type == float_type);
+ ir_eval_float_op(ira, instruction, float_type,
+ op1_const, float_out_val);
+ float_out_val->type = float_type;
}
+ out_val->type = expr_type;
+ out_val->special = ConstValSpecialStatic;
} else {
- zig_unreachable();
+ ir_eval_float_op(ira, instruction, float_type, op1_const, out_val);
}
-
return result;
}
ir_assert(float_type->id == ZigTypeIdFloat, &instruction->base);
- if (float_type->data.floating.bit_count != 16 &&
- float_type->data.floating.bit_count != 32 &&
- float_type->data.floating.bit_count != 64) {
- ir_add_error(ira, instruction->type, buf_sprintf("compiler TODO: add implementation of sqrt for '%s'", buf_ptr(&float_type->name)));
- return ira->codegen->invalid_instruction;
- }
- IrInstruction *result = ir_build_sqrt(&ira->new_irb, instruction->base.scope,
- instruction->base.source_node, nullptr, casted_op);
- result->value.type = float_type;
+ IrInstruction *result = ir_build_float_op(&ira->new_irb, instruction->base.scope,
+ instruction->base.source_node, nullptr, casted_op1, instruction->op);
+ result->value.type = expr_type;
return result;
}
@@ -23596,8 +23997,10 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
return ir_analyze_instruction_merge_err_ret_traces(ira, (IrInstructionMergeErrRetTraces *)instruction);
case IrInstructionIdMarkErrRetTracePtr:
return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction);
- case IrInstructionIdSqrt:
- return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction);
+ case IrInstructionIdFloatOp:
+ return ir_analyze_instruction_float_op(ira, (IrInstructionFloatOp *)instruction);
+ case IrInstructionIdMulAdd:
+ return ir_analyze_instruction_mul_add(ira, (IrInstructionMulAdd *)instruction);
case IrInstructionIdIntToErr:
return ir_analyze_instruction_int_to_err(ira, (IrInstructionIntToErr *)instruction);
case IrInstructionIdErrToInt:
@@ -23836,7 +24239,8 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdCoroFree:
case IrInstructionIdCoroPromise:
case IrInstructionIdPromiseResultType:
- case IrInstructionIdSqrt:
+ case IrInstructionIdFloatOp:
+ case IrInstructionIdMulAdd:
case IrInstructionIdAtomicLoad:
case IrInstructionIdIntCast:
case IrInstructionIdFloatCast: