about summary refs log tree commit diff
path: root/test/stage2/nvptx.zig
blob: 7182092be78254670ff7534bdd2ba99d7b25f7a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
const std = @import("std");
const TestContext = @import("../../src/test.zig").TestContext;

/// Registers the NVPTX (nvptx64 / cuda) backend test cases with `ctx`.
///
/// Every case is created through `addPtx` and submitted only via
/// `Case.compiles`, so no run-time output expectations are registered here;
/// each source just has to compile (presumably with no errors — see
/// `TestContext.Case.compiles`).
pub fn addCases(ctx: *TestContext) !void {
    {
        // `addPtx` returns a pointer into `ctx.cases.items`; the binding is
        // never reassigned, so it is `const`. Finish with this case before the
        // next `addPtx` call, which may grow the list.
        const case = addPtx(ctx, "nvptx: simple addition and subtraction");

        // A kernel calling a non-exported helper and mutating a local.
        case.compiles(
            \\fn add(a: i32, b: i32) i32 {
            \\    return a + b;
            \\}
            \\
            \\pub export fn add_and_substract(a: i32, out: *i32) callconv(.PtxKernel) void {
            \\    const x = add(a, 7);
            \\    var y = add(2, 0);
            \\    y -= x;
            \\    out.* = y;
            \\}
        );
    }

    {
        const case = addPtx(ctx, "nvptx: read special registers");

        // Reads the `%tid.x` special register through inline asm and uses the
        // result to index slice parameters.
        case.compiles(
            \\fn threadIdX() usize {
            \\     var tid = asm volatile ("mov.u32 \t$0, %tid.x;"
            \\         : [ret] "=r" (-> u32),
            \\     );
            \\     return @as(usize, tid);
            \\}
            \\
            \\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void {
            \\    const i = threadIdX();
            \\    out[i] = a[i] + 7;
            \\}
        );
    }

    {
        const case = addPtx(ctx, "nvptx: address spaces");

        // A module-level variable declared in the `.global` address space,
        // read and written from a kernel.
        case.compiles(
            \\var x: i32 addrspace(.global) = 0;
            \\
            \\pub export fn increment(out: *i32) callconv(.PtxKernel) void {
            \\    x += 1;
            \\    out.* = x;
            \\}
        );
    }
}

/// Cross-compilation target shared by every case in this file: the 64-bit
/// NVPTX architecture with the CUDA OS tag. All other `CrossTarget` fields
/// keep their default values.
const nvptx_target: std.zig.CrossTarget = .{
    .os_tag = .cuda,
    .cpu_arch = .nvptx64,
};

/// Adds a new test case named `name` to `ctx.cases` and returns a pointer to it.
///
/// The case targets `nvptx_target` and is built with the LLVM backend. It emits
/// an object file (`.output_mode = .Obj`), does not link libc, and starts with
/// empty `updates` and `files` lists that use `ctx.cases.allocator`.
/// Panics with "out of memory" if the case list cannot grow.
///
/// The returned pointer refers to the newly added last element of
/// `ctx.cases.items`. NOTE(review): this assumes `ctx.cases` is a
/// `std.ArrayList`. If so, any later append to it (e.g. the next `addPtx`
/// call) may reallocate and invalidate this pointer, so configure the case
/// before adding another.
pub fn addPtx(
    ctx: *TestContext,
    name: []const u8,
) *TestContext.Case {
    const allocator = ctx.cases.allocator;

    // Reserve the new slot first, then initialise it in place.
    const new_case = ctx.cases.addOne() catch @panic("out of memory");
    new_case.* = .{
        .name = name,
        .target = nvptx_target,
        .updates = std.ArrayList(TestContext.Update).init(allocator),
        .output_mode = .Obj,
        .files = std.ArrayList(TestContext.File).init(allocator),
        .link_libc = false,
        .backend = .llvm,
    };
    return new_case;
}