aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author gwenzek <gwenzek@users.noreply.github.com> 2022-02-21 20:05:27 +0100
committer GitHub <noreply@github.com> 2022-02-21 14:05:27 -0500
commit 628e9e6d040979bd0a2cba05e854014dee5a7d55 (patch)
tree b2d86569b54d792808b608402cc9d1ea8ec7d161
parent d8da9a01fcfebf14a9f262cabf36f1c0767d2e2b (diff)
download zig-628e9e6d040979bd0a2cba05e854014dee5a7d55.tar.gz
zig-628e9e6d040979bd0a2cba05e854014dee5a7d55.zip
enable Gpu address spaces (#10884)
-rw-r--r-- lib/std/builtin.zig 6
-rw-r--r-- src/Sema.zig 5
-rw-r--r-- src/codegen/llvm.zig 10
-rw-r--r-- test/cases.zig 1
-rw-r--r-- test/stage2/nvptx.zig 57
5 files changed, 78 insertions, 1 deletions
diff --git a/lib/std/builtin.zig b/lib/std/builtin.zig
index ad6b5f052b..dda8d846fa 100644
--- a/lib/std/builtin.zig
+++ b/lib/std/builtin.zig
@@ -157,6 +157,12 @@ pub const AddressSpace = enum {
gs,
fs,
ss,
+ // GPU address spaces
+ global,
+ constant,
+ param,
+ shared,
+ local,
};
/// This data structure is used by the Zig language code generation and
diff --git a/src/Sema.zig b/src/Sema.zig
index 93cbb8f2cf..91cddc18ff 100644
--- a/src/Sema.zig
+++ b/src/Sema.zig
@@ -18006,10 +18006,14 @@ pub fn analyzeAddrspace(
const address_space = addrspace_tv.val.toEnum(std.builtin.AddressSpace);
const target = sema.mod.getTarget();
const arch = target.cpu.arch;
+ const is_gpu = arch == .nvptx or arch == .nvptx64;
const supported = switch (address_space) {
.generic => true,
.gs, .fs, .ss => (arch == .i386 or arch == .x86_64) and ctx == .pointer,
+ // TODO: check that .shared and .local are left uninitialized
+ .global, .param, .shared, .local => is_gpu,
+ .constant => is_gpu and (ctx == .constant),
};
if (!supported) {
@@ -18020,7 +18024,6 @@ pub fn analyzeAddrspace(
.constant => "constant values",
.pointer => "pointers",
};
-
return sema.fail(
block,
src,
diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
index ed57562e4c..f40dbf41de 100644
--- a/src/codegen/llvm.zig
+++ b/src/codegen/llvm.zig
@@ -801,6 +801,16 @@ pub const DeclGen = struct {
.gs => llvm.address_space.x86.gs,
.fs => llvm.address_space.x86.fs,
.ss => llvm.address_space.x86.ss,
+ else => unreachable,
+ },
+ .nvptx, .nvptx64 => switch (address_space) {
+ .generic => llvm.address_space.default,
+ .global => llvm.address_space.nvptx.global,
+ .constant => llvm.address_space.nvptx.constant,
+ .param => llvm.address_space.nvptx.param,
+ .shared => llvm.address_space.nvptx.shared,
+ .local => llvm.address_space.nvptx.local,
+ else => unreachable,
},
else => switch (address_space) {
.generic => llvm.address_space.default,
diff --git a/test/cases.zig b/test/cases.zig
index 20dece3c7d..a65baeeef6 100644
--- a/test/cases.zig
+++ b/test/cases.zig
@@ -16,4 +16,5 @@ pub fn addCases(ctx: *TestContext) !void {
try @import("stage2/riscv64.zig").addCases(ctx);
try @import("stage2/plan9.zig").addCases(ctx);
try @import("stage2/x86_64.zig").addCases(ctx);
+ try @import("stage2/nvptx.zig").addCases(ctx);
}
diff --git a/test/stage2/nvptx.zig b/test/stage2/nvptx.zig
new file mode 100644
index 0000000000..95ca79d448
--- /dev/null
+++ b/test/stage2/nvptx.zig
@@ -0,0 +1,57 @@
+const std = @import("std");
+const TestContext = @import("../../src/test.zig").TestContext;
+
+const nvptx = std.zig.CrossTarget{ // Target passed to every case in this file: nvptx64 arch, cuda OS tag.
+ .cpu_arch = .nvptx64,
+ .os_tag = .cuda,
+};
+
+pub fn addCases(ctx: *TestContext) !void { // Registers three compile-only cases for the `nvptx` target on `ctx`.
+    { // Plain integer arithmetic inside a `.PtxKernel` entry point.
+        var case = ctx.exeUsingLlvmBackend("simple addition and subtraction", nvptx);
+
+        case.compiles(
+            \\fn add(a: i32, b: i32) i32 {
+            \\ return a + b;
+            \\}
+            \\
+            \\pub export fn main(a: i32, out: *i32) callconv(.PtxKernel) void {
+            \\ const x = add(a, 7);
+            \\ var y = add(2, 0);
+            \\ y -= x;
+            \\ out.* = y;
+            \\}
+        );
+    }
+
+    { // The local is named `thread_id` because a local must not shadow the function `tid`.
+        var case = ctx.exeUsingLlvmBackend("read special registers", nvptx);
+
+        case.compiles(
+            \\fn tid() usize {
+            \\ var thread_id = asm volatile ("mov.u32 \t$0, %tid.x;"
+            \\ : [ret] "=r" (-> u32),
+            \\ );
+            \\ return @as(usize, thread_id);
+            \\}
+            \\
+            \\pub export fn main(a: []const i32, out: []i32) callconv(.PtxKernel) void {
+            \\ const i = tid();
+            \\ out[i] = a[i] + 7;
+            \\}
+        );
+    }
+
+    { // The global is `i32` so it can be stored through `out: *i32`; `u32` would not coerce.
+        var case = ctx.exeUsingLlvmBackend("address spaces", nvptx);
+
+        case.compiles(
+            \\var x: i32 addrspace(.global) = 0;
+            \\
+            \\pub export fn increment(out: *i32) callconv(.PtxKernel) void {
+            \\ x += 1;
+            \\ out.* = x;
+            \\}
+        );
+    }
+}