diff options
| author | Robin Voetter <robin@voetter.nl> | 2024-01-15 23:38:43 +0100 |
|---|---|---|
| committer | Robin Voetter <robin@voetter.nl> | 2024-02-04 19:09:18 +0100 |
| commit | 403c6262bb4c9087f1d0138fc83fe4dd979864ad (patch) | |
| tree | 339f7d3d1c6c318a789129472299bda7653737e5 | |
| parent | cb9e20da00a2c33706e2c7bf2008887c6c72a896 (diff) | |
| download | zig-403c6262bb4c9087f1d0138fc83fe4dd979864ad.tar.gz zig-403c6262bb4c9087f1d0138fc83fe4dd979864ad.zip | |
spirv: use new vector stuff for arithOp and shift
| -rw-r--r-- | src/codegen/spirv.zig | 156 |
1 files changed, 82 insertions, 74 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index a8bc385f7a..a3b8a6c8f6 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -1782,6 +1782,19 @@ const DeclGen = struct { wip.dg.gpa.free(wip.results); } + /// Return the scalar type of an input vector. This type is expected to be a vector + /// if `wip.is_vector`, and a scalar otherwise. + fn scalarType(wip: WipElementWise, ty: Type) Type { + const mod = wip.dg.module; + if (wip.is_vector) { + assert(ty.isVector(mod)); + return ty.childType(mod); + } else { + assert(!ty.isVector(mod)); + return ty; + } + } + /// Utility function to extract the element at a particular index in an /// input vector. This type is expected to be a vector if `wip.is_vector`, and /// a scalar otherwise. @@ -1789,7 +1802,7 @@ const DeclGen = struct { const mod = wip.dg.module; if (wip.is_vector) { assert(ty.isVector(mod)); - return try wip.dg.extractField(ty, value, @intCast(index)); + return try wip.dg.extractField(ty.childType(mod), value, @intCast(index)); } else { assert(!ty.isVector(mod)); assert(index == 0); @@ -2331,36 +2344,45 @@ const DeclGen = struct { const lhs_id = try self.resolve(bin_op.lhs); const rhs_id = try self.resolve(bin_op.rhs); const result_ty = self.typeOfIndex(inst); - const result_ty_ref = try self.resolveType(result_ty, .direct); - - const result_id = self.spv.allocId(); // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that, // so just manually upcast it if required. - const shift_ty_ref = try self.resolveType(self.typeOf(bin_op.rhs), .direct); - const shift_id = if (shift_ty_ref != result_ty_ref) blk: { - const shift_id = self.spv.allocId(); - try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ - .id_result_type = self.typeId(result_ty_ref), - .id_result = shift_id, - .unsigned_value = rhs_id, - }); - break :blk shift_id; - } else rhs_id; + // TODO(robin) - const args = .{ - .id_result_type = self.typeId(result_ty_ref), - .id_result = result_id, - .base = lhs_id, - .shift = shift_id, - }; + var wip = try self.elementWise(result_ty); + defer wip.deinit(); - if (result_ty.isSignedInt(mod)) { - try self.func.body.emit(self.spv.gpa, signed, args); - } else { - try self.func.body.emit(self.spv.gpa, unsigned, args); + const shift_ty = wip.scalarType(self.typeOf(bin_op.rhs)); + const shift_ty_ref = try self.resolveType(shift_ty, .direct); + + for (0..wip.results.len) |i| { + const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i); + const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i); + + const shift_id = if (shift_ty_ref != wip.result_ty_ref) blk: { + const shift_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ + .id_result_type = wip.scalar_ty_id, + .id_result = shift_id, + .unsigned_value = rhs_elem_id, + }); + break :blk shift_id; + } else rhs_elem_id; + + const args = .{ + .id_result_type = wip.scalar_ty_id, + .id_result = wip.allocId(i), + .base = lhs_elem_id, + .shift = shift_id, + }; + + if (result_ty.isSignedInt(mod)) { + try self.func.body.emit(self.spv.gpa, signed, args); + } else { + try self.func.body.emit(self.spv.gpa, unsigned, args); + } } - return result_id; + return try wip.finalize(); } fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef { @@ -2483,35 +2505,14 @@ const DeclGen = struct { fn arithOp( self: *DeclGen, ty: Type, - lhs_id_: IdRef, - rhs_id_: IdRef, + lhs_id: IdRef, + rhs_id: IdRef, comptime fop: Opcode, comptime sop: Opcode, comptime uop: Opcode, /// true if this operation holds under modular arithmetic. comptime modular: bool, ) !IdRef { - var rhs_id = rhs_id_; - var lhs_id = lhs_id_; - - const mod = self.module; - const result_ty_ref = try self.resolveType(ty, .direct); - - if (ty.isVector(mod)) { - const child_ty = ty.childType(mod); - const vector_len = ty.vectorLen(mod); - const constituents = try self.gpa.alloc(IdRef, vector_len); - defer self.gpa.free(constituents); - - for (constituents, 0..) |*constituent, i| { - const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i)); - const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i)); - constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular); - } - - return self.constructArray(ty, constituents); - } - // Binary operations are generally applicable to both scalar and vector operations // in SPIR-V, but int and float versions of operations require different opcodes. const info = try self.arithmeticTypeInfo(ty); @@ -2520,17 +2521,7 @@ const DeclGen = struct { .composite_integer => { return self.todo("binary operations for composite integers", .{}); }, - .strange_integer => blk: { - if (!modular) { - lhs_id = try self.normalizeInt(result_ty_ref, lhs_id, info); - rhs_id = try self.normalizeInt(result_ty_ref, rhs_id, info); - } - break :blk switch (info.signedness) { - .signed => @as(usize, 1), - .unsigned => @as(usize, 2), - }; - }, - .integer => switch (info.signedness) { + .integer, .strange_integer => switch (info.signedness) { .signed => @as(usize, 1), .unsigned => @as(usize, 2), }, @@ -2538,24 +2529,41 @@ const DeclGen = struct { .bool => unreachable, }; - const result_id = self.spv.allocId(); - const operands = .{ - .id_result_type = self.typeId(result_ty_ref), - .id_result = result_id, - .operand_1 = lhs_id, - .operand_2 = rhs_id, - }; + var wip = try self.elementWise(ty); + defer wip.deinit(); + for (0..wip.results.len) |i| { + const lhs_elem_id = try wip.elementAt(ty, lhs_id, i); + const rhs_elem_id = try wip.elementAt(ty, rhs_id, i); - switch (opcode_index) { - 0 => try self.func.body.emit(self.spv.gpa, fop, operands), - 1 => try self.func.body.emit(self.spv.gpa, sop, operands), - 2 => try self.func.body.emit(self.spv.gpa, uop, operands), - else => unreachable, + const lhs_norm_id = if (modular and info.class == .strange_integer) + try self.normalizeInt(wip.scalar_ty_ref, lhs_elem_id, info) + else + lhs_elem_id; + + const rhs_norm_id = if (modular and info.class == .strange_integer) + try self.normalizeInt(wip.scalar_ty_ref, rhs_elem_id, info) + else + rhs_elem_id; + + const operands = .{ + .id_result_type = wip.scalar_ty_id, + .id_result = wip.allocId(i), + .operand_1 = lhs_norm_id, + .operand_2 = rhs_norm_id, + }; + + switch (opcode_index) { + 0 => try self.func.body.emit(self.spv.gpa, fop, operands), + 1 => try self.func.body.emit(self.spv.gpa, sop, operands), + 2 => try self.func.body.emit(self.spv.gpa, uop, operands), + else => unreachable, + } + + // TODO: Trap on overflow? Probably going to be annoying. + // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap. } - // TODO: Trap on overflow? Probably going to be annoying. - // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap. - return result_id; + return try wip.finalize(); } fn airAddSubOverflow( |
