aboutsummaryrefslogtreecommitdiff
path: root/src/codegen
diff options
context:
space:
mode:
authorAndrew Kelley <andrew@ziglang.org>2022-03-07 04:00:45 -0500
committerGitHub <noreply@github.com>2022-03-07 04:00:45 -0500
commit8c32d989c995f8675f1824fb084245b833b26223 (patch)
tree81f80b17835931b5fa0a877e9c554225f75ea7a9 /src/codegen
parent6547da8f97b94453fb08f582c2c7ce4eb1782a80 (diff)
parent3c1ebf95567db0c844c2618c5b8971d62c27352f (diff)
downloadzig-8c32d989c995f8675f1824fb084245b833b26223.tar.gz
zig-8c32d989c995f8675f1824fb084245b833b26223.zip
Merge pull request #11054 from schmee/mul-add
Implement `@mulAdd` for scalar floats
Diffstat (limited to 'src/codegen')
-rw-r--r--src/codegen/c.zig32
-rw-r--r--src/codegen/llvm.zig143
2 files changed, 172 insertions, 3 deletions
diff --git a/src/codegen/c.zig b/src/codegen/c.zig
index ba7bb6fa3a..2a10a8094a 100644
--- a/src/codegen/c.zig
+++ b/src/codegen/c.zig
@@ -16,6 +16,7 @@ const trace = @import("../tracy.zig").trace;
const LazySrcLoc = Module.LazySrcLoc;
const Air = @import("../Air.zig");
const Liveness = @import("../Liveness.zig");
+const CType = @import("../type.zig").CType;
const Mutability = enum { Const, Mut };
const BigIntConst = std.math.big.int.Const;
@@ -1635,6 +1636,8 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
.trunc_float,
=> |tag| return f.fail("TODO: C backend: implement unary op for tag '{s}'", .{@tagName(tag)}),
+ .mul_add => try airMulAdd(f, inst),
+
.add_with_overflow => try airAddWithOverflow(f, inst),
.sub_with_overflow => try airSubWithOverflow(f, inst),
.mul_with_overflow => try airMulWithOverflow(f, inst),
@@ -3621,6 +3624,35 @@ fn airWasmMemoryGrow(f: *Function, inst: Air.Inst.Index) !CValue {
return local;
}
+fn airMulAdd(f: *Function, inst: Air.Inst.Index) !CValue {
+ if (f.liveness.isUnused(inst)) return CValue.none;
+ const pl_op = f.air.instructions.items(.data)[inst].pl_op;
+ const extra = f.air.extraData(Air.Bin, pl_op.payload).data;
+ const inst_ty = f.air.typeOfIndex(inst);
+ const mulend1 = try f.resolveInst(extra.lhs);
+ const mulend2 = try f.resolveInst(extra.rhs);
+ const addend = try f.resolveInst(pl_op.operand);
+ const writer = f.object.writer();
+ const target = f.object.dg.module.getTarget();
+ const fn_name = switch (inst_ty.floatBits(target)) {
+ 16, 32 => "fmaf",
+ 64 => "fma",
+ 80 => if (CType.longdouble.sizeInBits(target) == 80) "fmal" else "__fmax",
+ 128 => if (CType.longdouble.sizeInBits(target) == 128) "fmal" else "fmaq",
+ else => unreachable,
+ };
+ const local = try f.allocLocal(inst_ty, .Const);
+ try writer.writeAll(" = ");
+ try writer.print("{s}(", .{fn_name});
+ try f.writeCValue(writer, mulend1);
+ try writer.writeAll(", ");
+ try f.writeCValue(writer, mulend2);
+ try writer.writeAll(", ");
+ try f.writeCValue(writer, addend);
+ try writer.writeAll(");\n");
+ return local;
+}
+
fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 {
return switch (order) {
.Unordered => "memory_order_relaxed",
diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index 446876dfe5..c837d9a00d 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -18,6 +18,7 @@ const target_util = @import("../target.zig");
const Value = @import("../value.zig").Value;
const Type = @import("../type.zig").Type;
const LazySrcLoc = Module.LazySrcLoc;
+const CType = @import("../type.zig").CType;
const Error = error{ OutOfMemory, CodegenFail };
@@ -2189,6 +2190,7 @@ pub const FuncGen = struct {
.min => try self.airMin(inst),
.max => try self.airMax(inst),
.slice => try self.airSlice(inst),
+ .mul_add => try self.airMulAdd(inst),
.add_with_overflow => try self.airOverflow(inst, "llvm.sadd.with.overflow", "llvm.uadd.with.overflow"),
.sub_with_overflow => try self.airOverflow(inst, "llvm.ssub.with.overflow", "llvm.usub.with.overflow"),
@@ -3842,6 +3844,46 @@ pub const FuncGen = struct {
return overflow_bit;
}
+ fn airMulAdd(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+ if (self.liveness.isUnused(inst)) return null;
+
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+
+ const mulend1 = try self.resolveInst(extra.lhs);
+ const mulend2 = try self.resolveInst(extra.rhs);
+ const addend = try self.resolveInst(pl_op.operand);
+
+ const ty = self.air.typeOfIndex(inst);
+ const llvm_ty = try self.dg.llvmType(ty);
+ const target = self.dg.module.getTarget();
+
+ const Strat = union(enum) {
+ intrinsic,
+ libc: [*:0]const u8,
+ };
+ const strat: Strat = switch (ty.floatBits(target)) {
+ 16, 32, 64 => Strat.intrinsic,
+ 80 => if (CType.longdouble.sizeInBits(target) == 80) Strat{ .intrinsic = {} } else Strat{ .libc = "__fmax" },
+ // LLVM always lowers the fma builtin for f128 to fmal, which is for `long double`.
+ // On some targets this will be correct; on others it will be incorrect.
+ 128 => if (CType.longdouble.sizeInBits(target) == 128) Strat{ .intrinsic = {} } else Strat{ .libc = "fmaq" },
+ else => unreachable,
+ };
+
+ const llvm_fn = switch (strat) {
+ .intrinsic => self.getIntrinsic("llvm.fma", &.{llvm_ty}),
+ .libc => |fn_name| self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: {
+ const param_types = [_]*const llvm.Type{ llvm_ty, llvm_ty, llvm_ty };
+ const fn_type = llvm.functionType(llvm_ty, &param_types, param_types.len, .False);
+ break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type);
+ },
+ };
+
+ const params = [_]*const llvm.Value{ mulend1, mulend2, addend };
+ return self.builder.buildCall(llvm_fn, &params, params.len, .C, .Auto, "");
+ }
+
fn airShlWithOverflow(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
if (self.liveness.isUnused(inst))
return null;
@@ -4020,8 +4062,15 @@ pub const FuncGen = struct {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const operand = try self.resolveInst(ty_op.operand);
- const dest_llvm_ty = try self.dg.llvmType(self.air.typeOfIndex(inst));
-
+ const operand_ty = self.air.typeOf(ty_op.operand);
+ const dest_ty = self.air.typeOfIndex(inst);
+ const target = self.dg.module.getTarget();
+ const dest_bits = dest_ty.floatBits(target);
+ const src_bits = operand_ty.floatBits(target);
+ if (!backendSupportsF80(target) and (src_bits == 80 or dest_bits == 80)) {
+ return softF80TruncOrExt(self, operand, src_bits, dest_bits);
+ }
+ const dest_llvm_ty = try self.dg.llvmType(dest_ty);
return self.builder.buildFPTrunc(operand, dest_llvm_ty, "");
}
@@ -4031,8 +4080,15 @@ pub const FuncGen = struct {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const operand = try self.resolveInst(ty_op.operand);
+ const operand_ty = self.air.typeOf(ty_op.operand);
+ const dest_ty = self.air.typeOfIndex(inst);
+ const target = self.dg.module.getTarget();
+ const dest_bits = dest_ty.floatBits(target);
+ const src_bits = operand_ty.floatBits(target);
+ if (!backendSupportsF80(target) and (src_bits == 80 or dest_bits == 80)) {
+ return softF80TruncOrExt(self, operand, src_bits, dest_bits);
+ }
const dest_llvm_ty = try self.dg.llvmType(self.air.typeOfIndex(inst));
-
return self.builder.buildFPExt(operand, dest_llvm_ty, "");
}
@@ -5064,6 +5120,87 @@ pub const FuncGen = struct {
return null;
}
+ fn softF80TruncOrExt(
+ self: *FuncGen,
+ operand: *const llvm.Value,
+ src_bits: u16,
+ dest_bits: u16,
+ ) !?*const llvm.Value {
+ const target = self.dg.module.getTarget();
+
+ var param_llvm_ty: *const llvm.Type = self.context.intType(80);
+ var ret_llvm_ty: *const llvm.Type = param_llvm_ty;
+ var fn_name: [*:0]const u8 = undefined;
+ var arg = operand;
+ var final_cast: ?*const llvm.Type = null;
+
+ assert(src_bits == 80 or dest_bits == 80);
+
+ if (src_bits == 80) switch (dest_bits) {
+ 16 => {
+ // See corresponding condition at definition of
+ // __truncxfhf2 in compiler-rt.
+ if (target.cpu.arch.isAARCH64()) {
+ ret_llvm_ty = self.context.halfType();
+ } else {
+ ret_llvm_ty = self.context.intType(16);
+ final_cast = self.context.halfType();
+ }
+ fn_name = "__truncxfhf2";
+ },
+ 32 => {
+ ret_llvm_ty = self.context.floatType();
+ fn_name = "__truncxfsf2";
+ },
+ 64 => {
+ ret_llvm_ty = self.context.doubleType();
+ fn_name = "__truncxfdf2";
+ },
+ 80 => return operand,
+ 128 => {
+ ret_llvm_ty = self.context.fp128Type();
+ fn_name = "__extendxftf2";
+ },
+ else => unreachable,
+ } else switch (src_bits) {
+ 16 => {
+ // See corresponding condition at definition of
+ // __extendhfxf2 in compiler-rt.
+ param_llvm_ty = if (target.cpu.arch.isAARCH64())
+ self.context.halfType()
+ else
+ self.context.intType(16);
+ arg = self.builder.buildBitCast(arg, param_llvm_ty, "");
+ fn_name = "__extendhfxf2";
+ },
+ 32 => {
+ param_llvm_ty = self.context.floatType();
+ fn_name = "__extendsfxf2";
+ },
+ 64 => {
+ param_llvm_ty = self.context.doubleType();
+ fn_name = "__extenddfxf2";
+ },
+ 80 => return operand,
+ 128 => {
+ param_llvm_ty = self.context.fp128Type();
+ fn_name = "__trunctfxf2";
+ },
+ else => unreachable,
+ }
+
+ const llvm_fn = self.dg.object.llvm_module.getNamedFunction(fn_name) orelse f: {
+ const param_types = [_]*const llvm.Type{param_llvm_ty};
+ const fn_type = llvm.functionType(ret_llvm_ty, &param_types, param_types.len, .False);
+ break :f self.dg.object.llvm_module.addFunction(fn_name, fn_type);
+ };
+
+ var args: [1]*const llvm.Value = .{arg};
+ const result = self.builder.buildCall(llvm_fn, &args, args.len, .C, .Auto, "");
+ const final_cast_llvm_ty = final_cast orelse return result;
+ return self.builder.buildBitCast(result, final_cast_llvm_ty, "");
+ }
+
fn getErrorNameTable(self: *FuncGen) !*const llvm.Value {
if (self.dg.object.error_name_table) |table| {
return table;