aboutsummaryrefslogtreecommitdiff
path: root/src/codegen
diff options
context:
space:
mode:
authorRobin Voetter <robin@voetter.nl>2024-01-19 01:12:56 +0100
committerRobin Voetter <robin@voetter.nl>2024-02-04 19:09:27 +0100
commit761594e2260eb780ab1861568e38a7066a7513df (patch)
treec5b5edf66d371f19708c65ccfb12d2ae36824ef6 /src/codegen
parent2f815853dcae49bbfd109675cde1f4097b75c8cc (diff)
downloadzig-761594e2260eb780ab1861568e38a7066a7513df.tar.gz
zig-761594e2260eb780ab1861568e38a7066a7513df.zip
spirv: reduce, reduce_optimized
Diffstat (limited to 'src/codegen')
-rw-r--r--src/codegen/spirv.zig75
1 files changed, 74 insertions, 1 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 286b45f973..33e068032c 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -2187,6 +2187,7 @@ const DeclGen = struct {
.sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
.shl_with_overflow => try self.airShlOverflow(inst),
+ .reduce, .reduce_optimized => try self.airReduce(inst),
.shuffle => try self.airShuffle(inst),
.ptr_add => try self.airPtrAdd(inst),
@@ -2388,9 +2389,14 @@ const DeclGen = struct {
const lhs_id = try self.resolve(bin_op.lhs);
const rhs_id = try self.resolve(bin_op.rhs);
const result_ty = self.typeOfIndex(inst);
- const result_ty_ref = try self.resolveType(result_ty, .direct);
+ return try self.minMax(result_ty, op, lhs_id, rhs_id);
+ }
+
+ fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
+ const result_ty_ref = try self.resolveType(result_ty, .direct);
const info = try self.arithmeticTypeInfo(result_ty);
+
// TODO: Use fmin for OpenCL
const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
const selection_id = switch (info.class) {
@@ -2758,6 +2764,73 @@ const DeclGen = struct {
);
}
+ fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+ if (self.liveness.isUnused(inst)) return null;
+ const mod = self.module;
+ const reduce = self.air.instructions.items(.data)[@intFromEnum(inst)].reduce;
+ const operand = try self.resolve(reduce.operand);
+ const operand_ty = self.typeOf(reduce.operand);
+ const scalar_ty = operand_ty.scalarType(mod);
+ const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+ const scalar_ty_id = self.typeId(scalar_ty_ref);
+
+ const info = try self.arithmeticTypeInfo(operand_ty);
+
+ var result_id = try self.extractField(scalar_ty, operand, 0);
+ const len = operand_ty.vectorLen(mod);
+
+ switch (reduce.operation) {
+ .Min, .Max => |op| {
+ const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt;
+ for (1..len) |i| {
+ const lhs = result_id;
+ const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+ result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs);
+ }
+
+ return result_id;
+ },
+ else => {},
+ }
+
+ const opcode: Opcode = switch (info.class) {
+ .bool => switch (reduce.operation) {
+ .And => .OpLogicalAnd,
+ .Or => .OpLogicalOr,
+ .Xor => .OpLogicalNotEqual,
+ else => unreachable,
+ },
+ .strange_integer, .integer => switch (reduce.operation) {
+ .And => .OpBitwiseAnd,
+ .Or => .OpBitwiseOr,
+ .Xor => .OpBitwiseXor,
+ .Add => .OpIAdd,
+ .Mul => .OpIMul,
+ else => unreachable,
+ },
+ .float => switch (reduce.operation) {
+ .Add => .OpFAdd,
+ .Mul => .OpFMul,
+ else => unreachable,
+ },
+ .composite_integer => unreachable, // TODO
+ };
+
+ for (1..len) |i| {
+ const lhs = result_id;
+ const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+ result_id = self.spv.allocId();
+
+ try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
+ self.func.body.writeOperand(spec.IdResultType, scalar_ty_id);
+ self.func.body.writeOperand(spec.IdResult, result_id);
+ self.func.body.writeOperand(spec.IdResultType, lhs);
+ self.func.body.writeOperand(spec.IdResultType, rhs);
+ }
+
+ return result_id;
+ }
+
fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
const mod = self.module;
if (self.liveness.isUnused(inst)) return null;