diff options
| author | John Schmidt <john.schmidt.h@gmail.com> | 2022-03-15 23:25:38 +0100 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2022-03-16 20:11:05 -0700 |
| commit | c8ed813097ebb679e858a7764673f6236e638ea4 (patch) | |
| tree | 2f0c506bbebe70a70e9678110f87d43cc204147f /src | |
| parent | 312536540baf26728a56304811f63f01a7414b7a (diff) | |
| download | zig-c8ed813097ebb679e858a7764673f6236e638ea4.tar.gz zig-c8ed813097ebb679e858a7764673f6236e638ea4.zip | |
Implement `@mulAdd` for vectors
Diffstat (limited to 'src')
| -rw-r--r-- | src/Sema.zig | 124 | ||||
| -rw-r--r-- | src/codegen/llvm.zig | 55 |
2 files changed, 121 insertions, 58 deletions
diff --git a/src/Sema.zig b/src/Sema.zig index 3df23ca2df..cc52f6c549 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -14499,19 +14499,24 @@ fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air. const target = sema.mod.getTarget(); + const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1); + const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2); + const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend); + switch (ty.zigTypeTag()) { - .ComptimeFloat, .Float => { - const maybe_mulend1 = try sema.resolveMaybeUndefVal(block, mulend1_src, mulend1); - const maybe_mulend2 = try sema.resolveMaybeUndefVal(block, mulend2_src, mulend2); - const maybe_addend = try sema.resolveMaybeUndefVal(block, addend_src, addend); + .ComptimeFloat, .Float, .Vector => {}, + else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}), + } - const runtime_src = if (maybe_mulend1) |mulend1_val| rs: { - if (maybe_mulend2) |mulend2_val| { - if (mulend2_val.isUndef()) return sema.addConstUndef(ty); + const runtime_src = if (maybe_mulend1) |mulend1_val| rs: { + if (maybe_mulend2) |mulend2_val| { + if (mulend2_val.isUndef()) return sema.addConstUndef(ty); - if (maybe_addend) |addend_val| { - if (addend_val.isUndef()) return sema.addConstUndef(ty); + if (maybe_addend) |addend_val| { + if (addend_val.isUndef()) return sema.addConstUndef(ty); + switch (ty.zigTypeTag()) { + .ComptimeFloat, .Float => { const result_val = try Value.mulAdd( ty, mulend1_val, @@ -14521,47 +14526,70 @@ fn zirMulAdd(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air. target, ); return sema.addConstant(ty, result_val); - } else { - break :rs addend_src; - } - } else { - if (maybe_addend) |addend_val| { - if (addend_val.isUndef()) return sema.addConstUndef(ty); - } - break :rs mulend2_src; - } - } else rs: { - if (maybe_mulend2) |mulend2_val| { - if (mulend2_val.isUndef()) return sema.addConstUndef(ty); - } - if (maybe_addend) |addend_val| { - if (addend_val.isUndef()) return sema.addConstUndef(ty); - } - break :rs mulend1_src; - }; + }, + .Vector => { + const scalar_ty = ty.scalarType(); + switch (scalar_ty.zigTypeTag()) { + .ComptimeFloat, .Float => {}, + else => return sema.fail(block, src, "expected vector of floats, found vector of '{}'", .{scalar_ty}), + } - try sema.requireRuntimeBlock(block, runtime_src); - return block.addInst(.{ - .tag = .mul_add, - .data = .{ .pl_op = .{ - .operand = addend, - .payload = try sema.addExtra(Air.Bin{ - .lhs = mulend1, - .rhs = mulend2, - }), - } }, - }); - }, - .Vector => { - const scalar_ty = ty.scalarType(); - switch (scalar_ty.zigTypeTag()) { - .ComptimeFloat, .Float => {}, - else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{scalar_ty}), + const vec_len = ty.vectorLen(); + const result_ty = try Type.vector(sema.arena, vec_len, scalar_ty); + var mulend1_buf: Value.ElemValueBuffer = undefined; + var mulend2_buf: Value.ElemValueBuffer = undefined; + var addend_buf: Value.ElemValueBuffer = undefined; + const elems = try sema.arena.alloc(Value, vec_len); + for (elems) |*elem, i| { + const mulend1_elem_val = mulend1_val.elemValueBuffer(i, &mulend1_buf); + const mulend2_elem_val = mulend2_val.elemValueBuffer(i, &mulend2_buf); + const addend_elem_val = addend_val.elemValueBuffer(i, &addend_buf); + elem.* = try Value.mulAdd( + scalar_ty, + mulend1_elem_val, + mulend2_elem_val, + addend_elem_val, + sema.arena, + target, + ); + } + return sema.addConstant( + result_ty, + try Value.Tag.aggregate.create(sema.arena, elems), + ); + }, + else => unreachable, + } + } else { + break :rs addend_src; } - return sema.fail(block, src, "TODO: implement @mulAdd for vectors", .{}); - }, - else => return sema.fail(block, src, "expected vector of floats or float type, found '{}'", .{ty}), - } + } else { + if (maybe_addend) |addend_val| { + if (addend_val.isUndef()) return sema.addConstUndef(ty); + } + break :rs mulend2_src; + } + } else rs: { + if (maybe_mulend2) |mulend2_val| { + if (mulend2_val.isUndef()) return sema.addConstUndef(ty); + } + if (maybe_addend) |addend_val| { + if (addend_val.isUndef()) return sema.addConstUndef(ty); + } + break :rs mulend1_src; + }; + + try sema.requireRuntimeBlock(block, runtime_src); + return block.addInst(.{ + .tag = .mul_add, + .data = .{ .pl_op = .{ + .operand = addend, + .payload = try sema.addExtra(Air.Bin{ + .lhs = mulend1, + .rhs = mulend2, + }), + } }, + }); } fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 67eac94af3..b41611813e 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -5166,7 +5166,13 @@ pub const FuncGen = struct { intrinsic, libc: [*:0]const u8, }; - const strat: Strat = switch (ty.floatBits(target)) { + + const scalar_ty = if (ty.zigTypeTag() == .Vector) + ty.elemType() + else + ty; + + const strat: Strat = switch (scalar_ty.floatBits(target)) { 16, 32, 64 => Strat.intrinsic, 80 => if (CType.longdouble.sizeInBits(target) == 80) Strat{ .intrinsic = {} } else Strat{ .libc = "__fmax" }, // LLVM always lowers the fma builtin for f128 to fmal, which is for `long double`. @@ -5175,17 +5181,46 @@ pub const FuncGen = struct { else => unreachable, }; - const llvm_fn = switch (strat) { - .intrinsic => self.getIntrinsic("llvm.fma", &.{llvm_ty}), - .libc => |fn_name| self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: { - const param_types = [_]*const llvm.Type{ llvm_ty, llvm_ty, llvm_ty }; - const fn_type = llvm.functionType(llvm_ty, ¶m_types, param_types.len, .False); - break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type); + switch (strat) { + .intrinsic => { + const llvm_fn = self.getIntrinsic("llvm.fma", &.{llvm_ty}); + const params = [_]*const llvm.Value{ mulend1, mulend2, addend }; + return self.builder.buildCall(llvm_fn, ¶ms, params.len, .C, .Auto, ""); }, - }; + .libc => |fn_name| { + const scalar_llvm_ty = try self.dg.llvmType(scalar_ty); + const llvm_fn = self.dg.object.llvm_module.getNamedFunction(fn_name) orelse b: { + const param_types = [_]*const llvm.Type{ scalar_llvm_ty, scalar_llvm_ty, scalar_llvm_ty }; + const fn_type = llvm.functionType(scalar_llvm_ty, ¶m_types, param_types.len, .False); + break :b self.dg.object.llvm_module.addFunction(fn_name, fn_type); + }; + + if (ty.zigTypeTag() == .Vector) { + const llvm_i32 = self.context.intType(32); + const vector_llvm_ty = try self.dg.llvmType(ty); + + var i: usize = 0; + var vector = vector_llvm_ty.getUndef(); + while (i < ty.vectorLen()) : (i += 1) { + const index_i32 = llvm_i32.constInt(i, .False); + + const mulend1_elem = self.builder.buildExtractElement(mulend1, index_i32, ""); + const mulend2_elem = self.builder.buildExtractElement(mulend2, index_i32, ""); + const addend_elem = self.builder.buildExtractElement(addend, index_i32, ""); - const params = [_]*const llvm.Value{ mulend1, mulend2, addend }; - return self.builder.buildCall(llvm_fn, ¶ms, params.len, .C, .Auto, ""); + const params = [_]*const llvm.Value{ mulend1_elem, mulend2_elem, addend_elem }; + const mul_add = self.builder.buildCall(llvm_fn, ¶ms, params.len, .C, .Auto, ""); + + vector = self.builder.buildInsertElement(vector, mul_add, index_i32, ""); + } + + return vector; + } else { + const params = [_]*const llvm.Value{ mulend1, mulend2, addend }; + return self.builder.buildCall(llvm_fn, ¶ms, params.len, .C, .Auto, ""); + } + }, + } } fn airShlWithOverflow(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { |
