aboutsummaryrefslogtreecommitdiff
path: root/src/ThreadPool.zig
blob: cf9c02fa5923b23d6777cb8b723b0da3c0090438 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// SPDX-License-Identifier: MIT
// Copyright (c) 2015-2020 Zig Contributors
// This file is part of [zig](https://ziglang.org/), which is MIT licensed.
// The MIT license requires this copyright notice to be included in all copies
// and substantial portions of the software.
const std = @import("std");
const ThreadPool = @This();

/// Guards `is_running`, `run_queue` and `idle_queue`. It is also held while
/// task closures are created and destroyed through `allocator`.
lock: std.Mutex = .{},
/// Cleared by `deinit`; a worker that finds the run queue empty exits once
/// this is false.
is_running: bool = true,
/// Allocator for the `threads` slice and for the per-task closures made by
/// `spawn`.
allocator: *std.mem.Allocator,
/// Number of leading entries in `threads` that hold started threads.
spawned: usize = 0,
/// Worker thread handles; only `threads[0..spawned]` are valid.
threads: []*std.Thread,
/// Tasks waiting to run. `spawn` prepends and workers pop from the front,
/// so the most recently spawned task is picked up first.
run_queue: RunQueue = .{},
/// Reset events of workers that are currently parked waiting for work.
idle_queue: IdleQueue = .{},

/// List of parked workers. Each node lives on the stack of the worker that
/// is waiting on it (see `runWorker`).
const IdleQueue = std.SinglyLinkedList(std.ResetEvent);
/// List of pending tasks. Each node is embedded in the heap-allocated
/// closure created by `spawn`.
const RunQueue = std.SinglyLinkedList(Runnable);
/// Type-erased task header stored in every `RunQueue` node.
const Runnable = struct {
    /// Entry point invoked by a worker with a pointer to this header; the
    /// closure in `spawn` recovers itself from it via `@fieldParentPtr`.
    runFn: fn (*Runnable) void,
};

/// Initializes the pool in place and, unless the build is single-threaded,
/// starts the worker threads.
///
/// The worker count is whatever `std.Thread.cpuCount()` reports, falling back
/// to one if the query fails, and never less than one. `allocator` is stored
/// in the pool and used here for the thread-handle slice.
///
/// Returns an error if allocating the handle slice or starting a thread
/// fails. In that case `deinit` has already run, so any workers that did
/// start have been joined and the slice has been freed.
pub fn init(self: *ThreadPool, allocator: *std.mem.Allocator) !void {
    self.* = ThreadPool{
        .allocator = allocator,
        .threads = &[_]*std.Thread{},
    };

    // Single-threaded builds run tasks inline in `spawn`; no workers needed.
    if (std.builtin.single_threaded) return;

    errdefer self.deinit();

    const worker_count = std.math.max(1, std.Thread.cpuCount() catch 1);
    self.threads = try allocator.alloc(*std.Thread, worker_count);

    // `spawned` doubles as the loop counter so that it always equals the
    // number of valid handles, which is what `deinit` relies on.
    while (self.spawned < worker_count) {
        const worker = try std.Thread.spawn(self, runWorker);
        self.threads[self.spawned] = worker;
        self.spawned += 1;
    }
}

/// Shuts the pool down: tells the workers to stop, wakes the parked ones,
/// joins every started thread and frees the thread-handle slice.
///
/// Workers check the run queue before the shutdown flag, so tasks that are
/// already queued are still executed before the workers exit.
pub fn deinit(self: *ThreadPool) void {
    // Flip the flag and wake all parked workers under the lock so that no
    // worker can park itself in between and miss the shutdown.
    const guard = self.lock.acquire();
    self.is_running = false;
    while (self.idle_queue.popFirst()) |parked| {
        parked.data.set();
    }
    guard.release();

    // Only the first `spawned` handles refer to started threads.
    const workers = self.threads[0..self.spawned];
    for (workers) |worker| {
        worker.wait();
    }

    self.allocator.free(self.threads);
}

/// Schedules `func` to be called with `args` on one of the pool's workers.
///
/// In single-threaded builds the call happens immediately on the calling
/// thread. Otherwise the arguments are copied into a heap-allocated closure
/// that is pushed onto the run queue, and one parked worker (if any) is
/// woken. The closure frees itself after `func` returns.
///
/// Whatever `func` returns is ignored. Returns an error only if allocating
/// the closure fails.
pub fn spawn(self: *ThreadPool, comptime func: anytype, args: anytype) !void {
    // No workers exist in single-threaded builds, so run the task right here.
    if (std.builtin.single_threaded) {
        const result = @call(.{}, func, args);
        return;
    }

    const Args = @TypeOf(args);
    const Task = struct {
        arguments: Args,
        pool: *ThreadPool,
        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },

        /// Worker-side entry point: recovers the enclosing task from the
        /// type-erased header, runs the user function, then frees the task.
        fn runFn(runnable: *Runnable) void {
            const node = @fieldParentPtr(RunQueue.Node, "data", runnable);
            const task = @fieldParentPtr(@This(), "run_node", node);
            const result = @call(.{}, func, task.arguments);

            // The pool lock is held around the allocator call, mirroring the
            // allocation in `spawn`, which also happens under the lock.
            const pool_lock = task.pool.lock.acquire();
            defer pool_lock.release();
            task.pool.allocator.destroy(task);
        }
    };

    const guard = self.lock.acquire();
    defer guard.release();

    const task = try self.allocator.create(Task);
    task.* = Task{
        .arguments = args,
        .pool = self,
    };
    self.run_queue.prepend(&task.run_node);

    // Wake at most one parked worker to pick the new task up.
    const parked = self.idle_queue.popFirst() orelse return;
    parked.data.set();
}

/// Body of every worker thread.
///
/// Repeatedly takes the pool lock and then does one of three things:
/// runs the next queued task (with the lock released), exits if the pool is
/// shutting down and no task is queued, or parks on a fresh reset event
/// until `spawn` or `deinit` wakes it.
fn runWorker(self: *ThreadPool) void {
    while (true) {
        const guard = self.lock.acquire();

        // Pending work takes priority, even while shutting down.
        if (self.run_queue.popFirst()) |task_node| {
            guard.release();
            const runnable = &task_node.data;
            (runnable.runFn)(runnable);
            continue;
        }

        // Nothing queued and the pool was told to stop: this worker is done.
        if (!self.is_running) {
            guard.release();
            return;
        }

        // Nothing to do yet. Register as idle before dropping the lock so a
        // concurrent `spawn` or `deinit` is guaranteed to see this node.
        var parked = IdleQueue.Node{ .data = std.ResetEvent.init() };
        self.idle_queue.prepend(&parked);
        guard.release();

        parked.data.wait();
        parked.data.deinit();
    }
}