diff options
| author | Robin Voetter <robin@voetter.nl> | 2023-04-10 20:34:15 +0200 |
|---|---|---|
| committer | Robin Voetter <robin@voetter.nl> | 2023-05-11 20:31:50 +0200 |
| commit | 0bae2caaf382dfb168ee404e3ffb717975f8289b (patch) | |
| tree | d3601779e10b0a4a2368498eadf62fcb4558a808 /src/codegen/spirv.zig | |
| parent | dfecf89d06dc2caad41ff54b05240506ea2c47e8 (diff) | |
| download | zig-0bae2caaf382dfb168ee404e3ffb717975f8289b.tar.gz zig-0bae2caaf382dfb168ee404e3ffb717975f8289b.zip | |
spirv: lower air try
Implements code generation for the try air tag. This commit also adds
a utility `errorUnionLayout` function that helps keeping the layout
of a spir-v error union consistent.
Diffstat (limited to 'src/codegen/spirv.zig')
| -rw-r--r-- | src/codegen/spirv.zig | 126 |
1 files changed, 105 insertions, 21 deletions
diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 7c12335575..d4d8ed312e 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -765,21 +765,18 @@ pub const DeclGen = struct { const is_pl = val.errorUnionIsPayload(); const error_val = if (!is_pl) val else Value.initTag(.zero); - if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { + const eu_layout = dg.errorUnionLayout(payload_ty); + if (!eu_layout.payload_has_bits) { return try self.lower(Type.anyerror, error_val); } - const payload_align = payload_ty.abiAlignment(target); - const error_align = Type.anyerror.abiAlignment(target); - const payload_size = payload_ty.abiSize(target); const error_size = Type.anyerror.abiAlignment(target); const ty_size = ty.abiSize(target); const padding = ty_size - payload_size - error_size; - const payload_val = if (val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef); - if (error_align > payload_align) { + if (eu_layout.error_first) { try self.lower(Type.anyerror, error_val); try self.lower(payload_ty, payload_val); } else { @@ -1277,18 +1274,16 @@ pub const DeclGen = struct { .ErrorUnion => { const payload_ty = ty.errorUnionPayload(); const error_ty_ref = try self.resolveType(Type.anyerror, .indirect); - if (!payload_ty.hasRuntimeBitsIgnoreComptime()) { + + const eu_layout = self.errorUnionLayout(payload_ty); + if (!eu_layout.payload_has_bits) { return error_ty_ref; } const payload_ty_ref = try self.resolveType(payload_ty, .indirect); - const payload_align = payload_ty.abiAlignment(target); - const error_align = Type.anyerror.abiAlignment(target); - var members = std.BoundedArray(SpvType.Payload.Struct.Member, 2){}; - // Similar to unions, we're going to put the most aligned member first. - if (error_align > payload_align) { + if (eu_layout.error_first) { // Put the error first members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" }); members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" }); @@ -1336,6 +1331,34 @@ pub const DeclGen = struct { }; } + const ErrorUnionLayout = struct { + payload_has_bits: bool, + error_first: bool, + + fn errorFieldIndex(self: @This()) u32 { + assert(self.payload_has_bits); + return if (self.error_first) 0 else 1; + } + + fn payloadFieldIndex(self: @This()) u32 { + assert(self.payload_has_bits); + return if (self.error_first) 1 else 0; + } + }; + + fn errorUnionLayout(self: *DeclGen, payload_ty: Type) ErrorUnionLayout { + const target = self.getTarget(); + + const error_align = Type.anyerror.abiAlignment(target); + const payload_align = payload_ty.abiAlignment(target); + + const error_first = error_align > payload_align; + return .{ + .payload_has_bits = payload_ty.hasRuntimeBitsIgnoreComptime(), + .error_first = error_first, + }; + } + /// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure. /// In order to be able to run tests, we "temporarily" lower test kernels into separate entry- /// points. The test executor will then be able to invoke these to run the tests. @@ -1585,6 +1608,7 @@ pub const DeclGen = struct { .loop => return self.airLoop(inst), .ret => return self.airRet(inst), .ret_load => return self.airRetLoad(inst), + .@"try" => try self.airTry(inst), .switch_br => return self.airSwitchBr(inst), .unreach => return self.airUnreach(), @@ -1752,16 +1776,15 @@ pub const DeclGen = struct { const operand_ty_id = try self.resolveTypeId(operand_ty); const result_type_id = try self.resolveTypeId(result_ty); - const overflow_member_ty = try self.intType(.unsigned, info.bits); - const overflow_member_ty_id = self.typeId(overflow_member_ty); + const overflow_member_ty_ref = try self.intType(.unsigned, info.bits); const op_result_id = blk: { // Construct the SPIR-V result type. // It is almost the same as the zig one, except that the fields must be the same type // and they must be unsigned. const overflow_result_ty_ref = try self.spv.simpleStructType(&.{ - .{ .ty = overflow_member_ty, .name = "res" }, - .{ .ty = overflow_member_ty, .name = "ov" }, + .{ .ty = overflow_member_ty_ref, .name = "res" }, + .{ .ty = overflow_member_ty_ref, .name = "ov" }, }); const result_id = self.spv.allocId(); try self.func.body.emit(self.spv.gpa, .OpIAddCarry, .{ @@ -1775,8 +1798,8 @@ pub const DeclGen = struct { // Now convert the SPIR-V flavor result into a Zig-flavor result. // First, extract the two fields. - const unsigned_result = try self.extractField(overflow_member_ty_id, op_result_id, 0); - const overflow = try self.extractField(overflow_member_ty_id, op_result_id, 1); + const unsigned_result = try self.extractField(overflow_member_ty_ref, op_result_id, 0); + const overflow = try self.extractField(overflow_member_ty_ref, op_result_id, 1); // We need to convert the results to the types that Zig expects here. // The `result` is the same type except unsigned, so we can just bitcast that. @@ -1954,15 +1977,16 @@ pub const DeclGen = struct { return result_id; } - fn extractField(self: *DeclGen, result_ty: IdResultType, object: IdRef, field: u32) !IdRef { + fn extractField(self: *DeclGen, result_ty_ref: SpvType.Ref, object: IdRef, field: u32) !IdRef { const result_id = self.spv.allocId(); const indexes = [_]u32{field}; try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{ - .id_result_type = result_ty, + .id_result_type = self.typeId(result_ty_ref), .id_result = result_id, .composite = object, .indexes = &indexes, }); + // TODO: Convert bools, direct structs should have their field types as indirect values. return result_id; } @@ -1970,7 +1994,7 @@ pub const DeclGen = struct { if (self.liveness.isUnused(inst)) return null; const ty_op = self.air.instructions.items(.data)[inst].ty_op; return try self.extractField( - try self.resolveTypeId(self.air.typeOfIndex(inst)), + try self.resolveType(self.air.typeOfIndex(inst), .direct), try self.resolve(ty_op.operand), field, ); @@ -2451,6 +2475,66 @@ pub const DeclGen = struct { }); } + fn airTry(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { + const pl_op = self.air.instructions.items(.data)[inst].pl_op; + const err_union_id = try self.resolve(pl_op.operand); + const extra = self.air.extraData(Air.Try, pl_op.payload); + const body = self.air.extra[extra.end..][0..extra.data.body_len]; + + const err_union_ty = self.air.typeOf(pl_op.operand); + const payload_ty = self.air.typeOfIndex(inst); + + const err_ty_ref = try self.resolveType(Type.anyerror, .direct); + const payload_ty_ref = try self.resolveType(payload_ty, .direct); + const bool_ty_ref = try self.resolveType(Type.bool, .direct); + + const eu_layout = self.errorUnionLayout(payload_ty); + + if (!err_union_ty.errorUnionSet().errorSetIsEmpty()) { + const err_id = if (eu_layout.payload_has_bits) + try self.extractField(err_ty_ref, err_union_id, eu_layout.errorFieldIndex()) + else + err_union_id; + + const zero_id = try self.constInt(err_ty_ref, 0); + const is_err_id = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{ + .id_result_type = self.typeId(bool_ty_ref), + .id_result = is_err_id, + .operand_1 = err_id, + .operand_2 = zero_id, + }); + + // When there is an error, we must evaluate `body`. Otherwise we must continue + // with the current body. + // Just generate a new block here, then generate a new block inline for the remainder of the body. + + const err_block = self.spv.allocId(); + const ok_block = self.spv.allocId(); + + // TODO: Merge block + try self.func.body.emit(self.spv.gpa, .OpBranchConditional, .{ + .condition = is_err_id, + .true_label = err_block, + .false_label = ok_block, + }); + + try self.beginSpvBlock(err_block); + try self.genBody(body); + + try self.beginSpvBlock(ok_block); + // Now just extract the payload, if required. + } + if (self.liveness.isUnused(inst)) { + return null; + } + if (!eu_layout.payload_has_bits) { + return null; + } + + return try self.extractField(payload_ty_ref, err_union_id, eu_layout.payloadFieldIndex()); + } + fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void { const target = self.getTarget(); const pl_op = self.air.instructions.items(.data)[inst].pl_op; |
