aboutsummaryrefslogtreecommitdiff
path: root/src/codegen/spirv.zig
diff options
context:
space:
mode:
authorRobin Voetter <robin@voetter.nl>2024-03-17 18:18:35 +0100
committerRobin Voetter <robin@voetter.nl>2024-03-18 19:13:50 +0100
commit335ff5a5f422256765892dfb4ffebeeb2b9d581b (patch)
tree7dcc7156472aa7972a1635fd7a2c25fdd3b25382 /src/codegen/spirv.zig
parent8ed134243ac9b3d1286153f95495176875472669 (diff)
downloadzig-335ff5a5f422256765892dfb4ffebeeb2b9d581b.tar.gz
zig-335ff5a5f422256765892dfb4ffebeeb2b9d581b.zip
spirv: fix optional comparison
Diffstat (limited to 'src/codegen/spirv.zig')
-rw-r--r--src/codegen/spirv.zig79
1 files changed, 60 insertions, 19 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 038265f838..7117ae7e7a 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -3307,35 +3307,76 @@ const DeclGen = struct {
else
try self.convertToDirect(Type.bool, rhs_id);
- const valid_cmp_id = try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
- return valid_cmp_id;
+ return try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
}
- // TODO: Should we short circuit here? It shouldn't affect correctness, but
- // perhaps it will generate more efficient code.
+ // a = lhs_valid
+ // b = rhs_valid
+ // c = lhs_pl == rhs_pl
+ //
+ // For op == .eq we have:
+ // a == b && a -> c
+ // = a == b && (!a || c)
+ //
+ // For op == .neq we have
+ // a == b && a -> c
+ // = !(a == b && a -> c)
+ // = a != b || !(a -> c
+ // = a != b || !(!a || c)
+ // = a != b || a && !c
const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
- const pl_cmp_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
-
- // op == .eq => lhs_valid == rhs_valid && lhs_pl == rhs_pl
- // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
-
- const result_id = self.spv.allocId();
- const args = .{
- .id_result_type = self.typeId(bool_ty_ref),
- .id_result = result_id,
- .operand_1 = valid_cmp_id,
- .operand_2 = pl_cmp_id,
- };
switch (op) {
- .eq => try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, args),
- .neq => try self.func.body.emit(self.spv.gpa, .OpLogicalOr, args),
+ .eq => {
+ const valid_eq_id = try self.cmp(.eq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
+ const pl_eq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
+ const lhs_not_valid_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = lhs_not_valid_id,
+ .operand = lhs_valid_id,
+ });
+ const impl_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = impl_id,
+ .operand_1 = lhs_not_valid_id,
+ .operand_2 = pl_eq_id,
+ });
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = result_id,
+ .operand_1 = valid_eq_id,
+ .operand_2 = impl_id,
+ });
+ return result_id;
+ },
+ .neq => {
+ const valid_neq_id = try self.cmp(.neq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
+ const pl_neq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
+
+ const impl_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = impl_id,
+ .operand_1 = lhs_valid_id,
+ .operand_2 = pl_neq_id,
+ });
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+ .id_result_type = self.typeId(bool_ty_ref),
+ .id_result = result_id,
+ .operand_1 = valid_neq_id,
+ .operand_2 = impl_id,
+ });
+ return result_id;
+ },
else => unreachable,
}
- return result_id;
},
.Vector => {
var wip = try self.elementWise(result_ty, true);