aboutsummaryrefslogtreecommitdiff
path: root/src/codegen/spirv.zig
diff options
context:
space:
mode:
authorRobin Voetter <robin@voetter.nl>2024-04-06 13:37:25 +0200
committerGitHub <noreply@github.com>2024-04-06 13:37:25 +0200
commit39420838061a9049fbc889212836a9d4d2ab9af4 (patch)
treede835335172000e497871f9593bac17bcff882c0 /src/codegen/spirv.zig
parent3eeb70540d7f40526b4f4549deb6e2bc792bb3b2 (diff)
parent436f53f55d3191bfa56418d98130d763fa5a6b22 (diff)
downloadzig-39420838061a9049fbc889212836a9d4d2ab9af4.tar.gz
zig-39420838061a9049fbc889212836a9d4d2ab9af4.zip
Merge pull request #18984 from alichraghi/vector
spirv: implement `@divFloor`, `@floor`, `@mod` and `@mulWithOverflow`
Diffstat (limited to 'src/codegen/spirv.zig')
-rw-r--r--src/codegen/spirv.zig236
1 files changed, 198 insertions, 38 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 3df43ae236..a47497d89d 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -1016,7 +1016,7 @@ const DeclGen = struct {
const elem_ty = Type.fromInterned(array_type.child);
const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
- const constituents = try self.gpa.alloc(IdRef, @as(u32, @intCast(ty.arrayLenIncludingSentinel(mod))));
+ const constituents = try self.gpa.alloc(IdRef, @intCast(ty.arrayLenIncludingSentinel(mod)));
defer self.gpa.free(constituents);
switch (aggregate.storage) {
@@ -1736,7 +1736,6 @@ const DeclGen = struct {
.EnumLiteral,
.ComptimeFloat,
.ComptimeInt,
- .Type,
=> unreachable, // Must be comptime.
else => |tag| return self.todo("Implement zig type '{}'", .{tag}),
@@ -2316,21 +2315,23 @@ const DeclGen = struct {
.sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
.mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
+
.abs => try self.airAbs(inst),
+ .floor => try self.airFloor(inst),
+
+ .div_floor => try self.airDivFloor(inst),
.div_float,
.div_float_optimized,
- // TODO: Check that this is the right operation.
.div_trunc,
- .div_trunc_optimized,
- => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv),
- // TODO: Check if this is the right operation
- .rem,
- .rem_optimized,
- => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem),
+ .div_trunc_optimized => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv),
+ .rem, .rem_optimized => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem),
+ .mod, .mod_optimized => try self.airArithOp(inst, .OpFMod, .OpSMod, .OpSMod),
+
.add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan),
.sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
+ .mul_with_overflow => try self.airMulOverflow(inst),
.shl_with_overflow => try self.airShlOverflow(inst),
.mul_add => try self.airMulAdd(inst),
@@ -2340,7 +2341,7 @@ const DeclGen = struct {
.splat => try self.airSplat(inst),
.reduce, .reduce_optimized => try self.airReduce(inst),
- .shuffle => try self.airShuffle(inst),
+ .shuffle => try self.airShuffle(inst),
.ptr_add => try self.airPtrAdd(inst),
.ptr_sub => try self.airPtrSub(inst),
@@ -2661,6 +2662,95 @@ const DeclGen = struct {
}
}
+    /// Lowers the AIR `div_floor` instruction (`@divFloor`).
+    ///
+    /// Integers: when the operands have the same sign (or either is zero in a
+    /// way that leaves the xor non-negative) the truncating division is already
+    /// the floored result; otherwise the result is
+    /// `-((abs(a) + abs(b) - 1) / abs(b))`. Both candidates are emitted and the
+    /// final value is picked with `OpSelect`.
+    /// Floats: emits the division followed by `floor`.
+    ///
+    /// Returns the result id, or the error produced by `self.todo` for
+    /// composite integers, which are not implemented yet.
+    fn airDivFloor(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
+        const lhs_id = try self.resolve(bin_op.lhs);
+        const rhs_id = try self.resolve(bin_op.rhs);
+        const ty = self.typeOfIndex(inst);
+        const ty_ref = try self.resolveType(ty, .direct);
+        const info = self.arithmeticTypeInfo(ty);
+        switch (info.class) {
+            // Report the missing feature instead of invoking `unreachable` on a
+            // reachable path, like the other arithmetic lowerings do.
+            .composite_integer => return self.todo("@divFloor for composite integers", .{}),
+            .integer, .strange_integer => {
+                const zero_id = try self.constInt(ty_ref, 0);
+                const one_id = try self.constInt(ty_ref, 1);
+
+                // (a ^ b) >= 0
+                // The comparison must include zero: `a ^ b` is 0 exactly when
+                // a == b, and in that case the quotient is positive (1), so the
+                // plain division has to be selected.
+                const bin_bitwise_id = try self.binOpSimple(ty, lhs_id, rhs_id, .OpBitwiseXor);
+                const is_positive_id = try self.cmp(.gte, Type.bool, ty, bin_bitwise_id, zero_id);
+
+                // a / b
+                const positive_div_id = try self.arithOp(ty, lhs_id, rhs_id, .OpFDiv, .OpSDiv, .OpUDiv);
+
+                // - (abs(a) + abs(b) - 1) / abs(b)
+                const lhs_abs = try self.abs(ty, ty, lhs_id);
+                const rhs_abs = try self.abs(ty, ty, rhs_id);
+                const negative_div_lhs = try self.arithOp(
+                    ty,
+                    try self.arithOp(ty, lhs_abs, rhs_abs, .OpFAdd, .OpIAdd, .OpIAdd),
+                    one_id,
+                    .OpFSub,
+                    .OpISub,
+                    .OpISub,
+                );
+                const negative_div_id = try self.arithOp(ty, negative_div_lhs, rhs_abs, .OpFDiv, .OpSDiv, .OpUDiv);
+                const negated_negative_div_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpSNegate, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = negated_negative_div_id,
+                    .operand = negative_div_id,
+                });
+
+                // Pick the plain quotient when the signs agree, the rounded-away
+                // negated quotient otherwise.
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = result_id,
+                    .condition = is_positive_id,
+                    .object_1 = positive_div_id,
+                    .object_2 = negated_negative_div_id,
+                });
+                return result_id;
+            },
+            .float => {
+                const div_id = try self.arithOp(ty, lhs_id, rhs_id, .OpFDiv, .OpSDiv, .OpUDiv);
+                return try self.floor(ty, div_id);
+            },
+            .bool => unreachable,
+        }
+    }
+
+    /// Lowers the AIR `floor` instruction: resolves the unary operand and
+    /// forwards it, together with the instruction's result type, to `floor`.
+    fn airFloor(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        const operand_ref = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
+        const value_id = try self.resolve(operand_ref);
+        return try self.floor(self.typeOfIndex(inst), value_id);
+    }
+
+    /// Emits an `OpExtInst` that applies the target's extended-instruction-set
+    /// floor operation to `operand_id`, producing a value of type `ty` in
+    /// direct representation.
+    ///
+    /// The instruction set and instruction number depend on the target OS tag:
+    /// number 25 from `OpenCL.std` for OpenCL, number 8 from `GLSL.std.450`
+    /// for Vulkan. Any other OS tag is treated as unreachable.
+    fn floor(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
+        const os_tag = self.getTarget().os.tag;
+        const result_ty_ref = try self.resolveType(ty, .direct);
+
+        // Instruction number of the floor operation within the chosen set.
+        const floor_inst: Word = switch (os_tag) {
+            .opencl => 25,
+            .vulkan => 8,
+            else => unreachable,
+        };
+        // Id of the imported extended instruction set.
+        const inst_set_id = switch (os_tag) {
+            .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"),
+            .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"),
+            else => unreachable,
+        };
+
+        const floored_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
+            .id_result_type = self.typeId(result_ty_ref),
+            .id_result = floored_id,
+            .set = inst_set_id,
+            .instruction = .{ .inst = floor_inst },
+            .id_ref_4 = &.{operand_id},
+        });
+        return floored_id;
+    }
+
fn airArithOp(
self: *DeclGen,
inst: Air.Inst.Index,
@@ -2668,7 +2758,6 @@ const DeclGen = struct {
comptime sop: Opcode,
comptime uop: Opcode,
) !?IdRef {
-
// LHS and RHS are guaranteed to have the same type, and AIR guarantees
// the result to be the same as the LHS and RHS, which matches SPIR-V.
const ty = self.typeOfIndex(inst);
@@ -2700,8 +2789,8 @@ const DeclGen = struct {
return self.todo("binary operations for composite integers", .{});
},
.integer, .strange_integer => switch (info.signedness) {
- .signed => @as(usize, 1),
- .unsigned => @as(usize, 2),
+ .signed => 1,
+ .unsigned => 2,
},
.float => 0,
.bool => unreachable,
@@ -2737,12 +2826,16 @@ const DeclGen = struct {
}
fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
- const target = self.getTarget();
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_id = try self.resolve(ty_op.operand);
// Note: operand_ty may be signed, while ty is always unsigned!
const operand_ty = self.typeOf(ty_op.operand);
const result_ty = self.typeOfIndex(inst);
+ return try self.abs(result_ty, operand_ty, operand_id);
+ }
+
+ fn abs(self: *DeclGen, result_ty: Type, operand_ty: Type, operand_id: IdRef) !IdRef {
+ const target = self.getTarget();
const operand_info = self.arithmeticTypeInfo(operand_ty);
var wip = try self.elementWise(result_ty, false);
@@ -2907,6 +3000,61 @@ const DeclGen = struct {
);
}
+    /// Lowers the AIR `mul_with_overflow` instruction (`@mulWithOverflow`).
+    ///
+    /// The product and the overflow flag are computed element by element and
+    /// packed into the result struct `{ product, overflow_bit }`. Overflow is
+    /// detected with `(a != 0) and (x / a != b)`, where `x` is the product.
+    ///
+    /// Returns the id of the constructed struct, or the error produced by
+    /// `self.todo` for composite integers.
+    fn airMulOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
+        const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+        const lhs = try self.resolve(extra.lhs);
+        const rhs = try self.resolve(extra.rhs);
+
+        const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(extra.lhs);
+        const ov_ty = result_ty.structFieldType(1, self.module);
+
+        const info = self.arithmeticTypeInfo(operand_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("overflow ops for composite integers", .{}),
+            .strange_integer, .integer => {},
+            .float, .bool => unreachable,
+        }
+
+        var wip_result = try self.elementWise(operand_ty, true);
+        defer wip_result.deinit();
+        var wip_ov = try self.elementWise(ov_ty, true);
+        defer wip_ov.deinit();
+
+        const zero_id = try self.constInt(wip_result.ty_ref, 0);
+        const one_id = try self.constInt(wip_result.ty_ref, 1);
+        const zero_ov_id = try self.constInt(wip_ov.ty_ref, 0);
+        const one_ov_id = try self.constInt(wip_ov.ty_ref, 1);
+
+        for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
+            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+            const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i);
+
+            result_id.* = try self.arithOp(wip_result.ty, lhs_elem_id, rhs_elem_id, .OpFMul, .OpIMul, .OpIMul);
+
+            // (a != 0) and (x / a != b)
+            const not_zero_id = try self.cmp(.neq, Type.bool, wip_result.ty, lhs_elem_id, zero_id);
+
+            // The logical and below does not short-circuit, so the division is
+            // emitted even when a == 0. SPIR-V leaves the result of a division
+            // by zero undefined, so divide by 1 in that case; the quotient is
+            // masked out by `not_zero_id` anyway.
+            const divisor_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                .id_result_type = wip_result.ty_id,
+                .id_result = divisor_id,
+                .condition = not_zero_id,
+                .object_1 = lhs_elem_id,
+                .object_2 = one_id,
+            });
+
+            // NOTE(review): for signed operands `x / a` can still hit the
+            // minInt / -1 case, whose result SPIR-V also leaves undefined;
+            // confirm whether that needs separate handling.
+            const res_rhs_id = try self.arithOp(wip_result.ty, result_id.*, divisor_id, .OpFDiv, .OpSDiv, .OpUDiv);
+            const res_rhs_not_rhs_id = try self.cmp(.neq, Type.bool, wip_result.ty, res_rhs_id, rhs_elem_id);
+            const cond_id = try self.binOpSimple(Type.bool, not_zero_id, res_rhs_not_rhs_id, .OpLogicalAnd);
+
+            // Materialize the overflow bit as 1 or 0 of the overflow field type.
+            ov_id.* = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                .id_result_type = wip_ov.ty_id,
+                .id_result = ov_id.*,
+                .condition = cond_id,
+                .object_1 = one_ov_id,
+                .object_2 = zero_ov_id,
+            });
+        }
+
+        return try self.constructStruct(
+            result_ty,
+            &.{ operand_ty, ov_ty },
+            &.{ try wip_result.finalize(), try wip_ov.finalize() },
+        );
+    }
+
fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
const mod = self.module;
const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
@@ -3692,19 +3840,22 @@ const DeclGen = struct {
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_ty = self.typeOf(ty_op.operand);
const operand_id = try self.resolve(ty_op.operand);
- const operand_info = self.arithmeticTypeInfo(operand_ty);
- const dest_ty = self.typeOfIndex(inst);
- const dest_ty_id = try self.resolveTypeId(dest_ty);
+ const result_ty = self.typeOfIndex(inst);
+ const result_ty_ref = try self.resolveType(result_ty, .direct);
+ return try self.floatFromInt(result_ty_ref, operand_ty, operand_id);
+ }
+ fn floatFromInt(self: *DeclGen, result_ty_ref: CacheRef, operand_ty: Type, operand_id: IdRef) !IdRef {
+ const operand_info = self.arithmeticTypeInfo(operand_ty);
const result_id = self.spv.allocId();
switch (operand_info.signedness) {
.signed => try self.func.body.emit(self.spv.gpa, .OpConvertSToF, .{
- .id_result_type = dest_ty_id,
+ .id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.signed_value = operand_id,
}),
.unsigned => try self.func.body.emit(self.spv.gpa, .OpConvertUToF, .{
- .id_result_type = dest_ty_id,
+ .id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.unsigned_value = operand_id,
}),
@@ -3715,19 +3866,22 @@ const DeclGen = struct {
fn airIntFromFloat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const operand_id = try self.resolve(ty_op.operand);
- const dest_ty = self.typeOfIndex(inst);
- const dest_info = self.arithmeticTypeInfo(dest_ty);
- const dest_ty_id = try self.resolveTypeId(dest_ty);
+ const result_ty = self.typeOfIndex(inst);
+ return try self.intFromFloat(result_ty, operand_id);
+ }
+ fn intFromFloat(self: *DeclGen, result_ty: Type, operand_id: IdRef) !IdRef {
+ const result_info = self.arithmeticTypeInfo(result_ty);
+ const result_ty_ref = try self.resolveType(result_ty, .direct);
const result_id = self.spv.allocId();
- switch (dest_info.signedness) {
+ switch (result_info.signedness) {
.signed => try self.func.body.emit(self.spv.gpa, .OpConvertFToS, .{
- .id_result_type = dest_ty_id,
+ .id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.float_value = operand_id,
}),
.unsigned => try self.func.body.emit(self.spv.gpa, .OpConvertFToU, .{
- .id_result_type = dest_ty_id,
+ .id_result_type = self.typeId(result_ty_ref),
.id_result = result_id,
.float_value = operand_id,
}),
@@ -5237,20 +5391,21 @@ const DeclGen = struct {
fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
const mod = self.module;
+ const target = self.getTarget();
const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
const cond_ty = self.typeOf(pl_op.operand);
const cond = try self.resolve(pl_op.operand);
- const cond_indirect = try self.convertToIndirect(cond_ty, cond);
+ var cond_indirect = try self.convertToIndirect(cond_ty, cond);
const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload);
const cond_words: u32 = switch (cond_ty.zigTypeTag(mod)) {
- .Bool => 1,
+ .Bool, .ErrorSet => 1,
.Int => blk: {
const bits = cond_ty.intInfo(mod).bits;
const backing_bits = self.backingIntBits(bits) orelse {
return self.todo("implement composite int switch", .{});
};
- break :blk if (backing_bits <= 32) @as(u32, 1) else 2;
+ break :blk if (backing_bits <= 32) 1 else 2;
},
.Enum => blk: {
const int_ty = cond_ty.intTagType(mod);
@@ -5258,10 +5413,14 @@ const DeclGen = struct {
const backing_bits = self.backingIntBits(int_info.bits) orelse {
return self.todo("implement composite int switch", .{});
};
- break :blk if (backing_bits <= 32) @as(u32, 1) else 2;
+ break :blk if (backing_bits <= 32) 1 else 2;
+ },
+ .Pointer => blk: {
+ cond_indirect = try self.intFromPtr(cond_indirect);
+ break :blk target.ptrBitWidth() / 32;
},
- .ErrorSet => 1,
- else => return self.todo("implement switch for type {s}", .{@tagName(cond_ty.zigTypeTag(mod))}), // TODO: Figure out which types apply here, and work around them as we can only do integers.
+ // TODO: Figure out which types apply here, and work around them as we can only do integers.
+ else => return self.todo("implement switch for type {s}", .{@tagName(cond_ty.zigTypeTag(mod))}),
};
const num_cases = switch_br.data.cases_len;
@@ -5308,7 +5467,7 @@ const DeclGen = struct {
for (0..num_cases) |case_i| {
// SPIR-V needs a literal here, whose width depends on the case condition.
const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
- const items = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[case.end..][0..case.data.items_len]));
+ const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
extra_index = case.end + case.data.items_len + case_body.len;
@@ -5316,13 +5475,14 @@ const DeclGen = struct {
for (items) |item| {
const value = (try self.air.value(item, mod)) orelse unreachable;
- const int_val = switch (cond_ty.zigTypeTag(mod)) {
- .Bool, .Int => if (cond_ty.isSignedInt(mod)) @as(u64, @bitCast(value.toSignedInt(mod))) else value.toUnsignedInt(mod),
+ const int_val: u64 = switch (cond_ty.zigTypeTag(mod)) {
+ .Bool, .Int => if (cond_ty.isSignedInt(mod)) @bitCast(value.toSignedInt(mod)) else value.toUnsignedInt(mod),
.Enum => blk: {
// TODO: figure out if cond_ty is correct (something with enum literals)
break :blk (try value.intFromEnum(cond_ty, mod)).toUnsignedInt(mod); // TODO: composite integer constants
},
.ErrorSet => value.getErrorInt(mod),
+ .Pointer => value.toUnsignedInt(mod),
else => unreachable,
};
const int_lit: spec.LiteralContextDependentNumber = switch (cond_words) {
@@ -5438,14 +5598,14 @@ const DeclGen = struct {
const extra = self.air.extraData(Air.Asm, ty_pl.payload);
const is_volatile = @as(u1, @truncate(extra.data.flags >> 31)) != 0;
- const clobbers_len = @as(u31, @truncate(extra.data.flags));
+ const clobbers_len: u31 = @truncate(extra.data.flags);
if (!is_volatile and self.liveness.isUnused(inst)) return null;
var extra_i: usize = extra.end;
- const outputs = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[extra_i..][0..extra.data.outputs_len]));
+ const outputs: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_i..][0..extra.data.outputs_len]);
extra_i += outputs.len;
- const inputs = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[extra_i..][0..extra.data.inputs_len]));
+ const inputs: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_i..][0..extra.data.inputs_len]);
extra_i += inputs.len;
if (outputs.len > 1) {
@@ -5567,7 +5727,7 @@ const DeclGen = struct {
const mod = self.module;
const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
const extra = self.air.extraData(Air.Call, pl_op.payload);
- const args = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[extra.end..][0..extra.data.args_len]));
+ const args: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra.end..][0..extra.data.args_len]);
const callee_ty = self.typeOf(pl_op.operand);
const zig_fn_ty = switch (callee_ty.zigTypeTag(mod)) {
.Fn => callee_ty,