diff options
| author | Frank Denis <124872+jedisct1@users.noreply.github.com> | 2025-10-15 14:03:56 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-15 14:03:56 +0200 |
| commit | 6669885aa2a33228baa29daba3b14985158d866f (patch) | |
| tree | 857aa08a89d005ae9944808ebf1fc8b93db1594d /lib/std | |
| parent | 70c21fdbab064ca060e5f111010836845ca35930 (diff) | |
| download | zig-6669885aa2a33228baa29daba3b14985158d866f.tar.gz zig-6669885aa2a33228baa29daba3b14985158d866f.zip | |
Faster BLAKE3 implementation (#25574)
This is a rewrite of the BLAKE3 implementation, with vectorization.
On Apple Silicon, the new implementation is about twice as fast as the previous one.
With AVX2, it is more than 4 times faster.
With AVX512, it is more than 7.5x faster than the previous implementation (from 678 MB/s to 5086 MB/s).
Diffstat (limited to 'lib/std')
| -rw-r--r-- | lib/std/crypto/blake3.zig | 1261 |
1 file changed, 880 insertions, 381 deletions
diff --git a/lib/std/crypto/blake3.zig b/lib/std/crypto/blake3.zig index a840a30632..3b056b02d0 100644 --- a/lib/std/crypto/blake3.zig +++ b/lib/std/crypto/blake3.zig @@ -1,391 +1,833 @@ -// Translated from BLAKE3 reference implementation. -// Source: https://github.com/BLAKE3-team/BLAKE3 - -const std = @import("../std.zig"); +const std = @import("std"); const builtin = @import("builtin"); const fmt = std.fmt; -const math = std.math; const mem = std.mem; -const testing = std.testing; -const ChunkIterator = struct { - slice: []u8, - chunk_len: usize, +const Vec4 = @Vector(4, u32); +const Vec8 = @Vector(8, u32); +const Vec16 = @Vector(16, u32); - fn init(slice: []u8, chunk_len: usize) ChunkIterator { - return ChunkIterator{ - .slice = slice, - .chunk_len = chunk_len, - }; +const chunk_length = 1024; +const max_depth = 54; + +pub const simd_degree = std.simd.suggestVectorLength(u32) orelse 1; +pub const max_simd_degree = simd_degree; +const max_simd_degree_or_2 = if (max_simd_degree > 2) max_simd_degree else 2; + +const iv: [8]u32 = .{ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, + 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +}; + +const msg_schedule: [7][16]u8 = .{ + .{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, + .{ 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 }, + .{ 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 }, + .{ 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 }, + .{ 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 }, + .{ 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 }, + .{ 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 }, +}; + +const Flags = packed struct(u8) { + chunk_start: bool = false, + chunk_end: bool = false, + parent: bool = false, + root: bool = false, + keyed_hash: bool = false, + derive_key_context: bool = false, + derive_key_material: bool = false, + reserved: bool = false, + + fn toInt(self: Flags) u8 { + return @bitCast(self); } - fn next(self: *ChunkIterator) ?[]u8 { - const 
next_chunk = self.slice[0..@min(self.chunk_len, self.slice.len)]; - self.slice = self.slice[next_chunk.len..]; - return if (next_chunk.len > 0) next_chunk else null; + fn with(self: Flags, other: Flags) Flags { + return @bitCast(self.toInt() | other.toInt()); } }; -const OUT_LEN: usize = 32; -const KEY_LEN: usize = 32; -const BLOCK_LEN: usize = 64; -const CHUNK_LEN: usize = 1024; +const rotr = std.math.rotr; -const IV = [8]u32{ - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, -}; +inline fn rotr32(w: u32, c: u5) u32 { + return rotr(u32, w, c); +} -const MSG_SCHEDULE = [7][16]u8{ - [_]u8{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, - [_]u8{ 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 }, - [_]u8{ 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 }, - [_]u8{ 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 }, - [_]u8{ 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 }, - [_]u8{ 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 }, - [_]u8{ 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 }, -}; +inline fn load32(bytes: []const u8) u32 { + return mem.readInt(u32, bytes[0..4], .little); +} -// These are the internal flags that we use to domain separate root/non-root, -// chunk/parent, and chunk beginning/middle/end. These get set at the high end -// of the block flags word in the compression function, so their values start -// high and go down. 
-const CHUNK_START: u8 = 1 << 0; -const CHUNK_END: u8 = 1 << 1; -const PARENT: u8 = 1 << 2; -const ROOT: u8 = 1 << 3; -const KEYED_HASH: u8 = 1 << 4; -const DERIVE_KEY_CONTEXT: u8 = 1 << 5; -const DERIVE_KEY_MATERIAL: u8 = 1 << 6; - -const CompressVectorized = struct { - const Lane = @Vector(4, u32); - const Rows = [4]Lane; - - fn g(comptime even: bool, rows: *Rows, m: Lane) void { - rows[0] +%= rows[1] +% m; - rows[3] ^= rows[0]; - rows[3] = math.rotr(Lane, rows[3], if (even) 8 else 16); - rows[2] +%= rows[3]; - rows[1] ^= rows[2]; - rows[1] = math.rotr(Lane, rows[1], if (even) 7 else 12); - } - - fn diagonalize(rows: *Rows) void { - rows[0] = @shuffle(u32, rows[0], undefined, [_]i32{ 3, 0, 1, 2 }); - rows[3] = @shuffle(u32, rows[3], undefined, [_]i32{ 2, 3, 0, 1 }); - rows[2] = @shuffle(u32, rows[2], undefined, [_]i32{ 1, 2, 3, 0 }); - } - - fn undiagonalize(rows: *Rows) void { - rows[0] = @shuffle(u32, rows[0], undefined, [_]i32{ 1, 2, 3, 0 }); - rows[3] = @shuffle(u32, rows[3], undefined, [_]i32{ 2, 3, 0, 1 }); - rows[2] = @shuffle(u32, rows[2], undefined, [_]i32{ 3, 0, 1, 2 }); - } - - fn compress( - chaining_value: [8]u32, - block_words: [16]u32, - block_len: u32, - counter: u64, - flags: u8, - ) [16]u32 { - const md = Lane{ @as(u32, @truncate(counter)), @as(u32, @truncate(counter >> 32)), block_len, @as(u32, flags) }; - var rows = Rows{ chaining_value[0..4].*, chaining_value[4..8].*, IV[0..4].*, md }; - - var m = Rows{ block_words[0..4].*, block_words[4..8].*, block_words[8..12].*, block_words[12..16].* }; - var t0 = @shuffle(u32, m[0], m[1], [_]i32{ 0, 2, (-1 - 0), (-1 - 2) }); - g(false, &rows, t0); - var t1 = @shuffle(u32, m[0], m[1], [_]i32{ 1, 3, (-1 - 1), (-1 - 3) }); - g(true, &rows, t1); - diagonalize(&rows); - var t2 = @shuffle(u32, m[2], m[3], [_]i32{ 0, 2, (-1 - 0), (-1 - 2) }); - t2 = @shuffle(u32, t2, undefined, [_]i32{ 3, 0, 1, 2 }); - g(false, &rows, t2); - var t3 = @shuffle(u32, m[2], m[3], [_]i32{ 1, 3, (-1 - 1), (-1 - 3) }); - t3 = 
@shuffle(u32, t3, undefined, [_]i32{ 3, 0, 1, 2 }); - g(true, &rows, t3); - undiagonalize(&rows); - m = Rows{ t0, t1, t2, t3 }; - - var i: usize = 0; - while (i < 6) : (i += 1) { - t0 = @shuffle(u32, m[0], m[1], [_]i32{ 2, 1, (-1 - 1), (-1 - 3) }); - t0 = @shuffle(u32, t0, undefined, [_]i32{ 1, 2, 3, 0 }); - g(false, &rows, t0); - t1 = @shuffle(u32, m[2], m[3], [_]i32{ 2, 2, (-1 - 3), (-1 - 3) }); - var tt = @shuffle(u32, m[0], undefined, [_]i32{ 3, 3, 0, 0 }); - t1 = @shuffle(u32, tt, t1, [_]i32{ 0, (-1 - 1), 2, (-1 - 3) }); - g(true, &rows, t1); - diagonalize(&rows); - t2 = @shuffle(u32, m[3], m[1], [_]i32{ 0, 1, (-1 - 0), (-1 - 1) }); - tt = @shuffle(u32, t2, m[2], [_]i32{ 0, 1, 2, (-1 - 3) }); - t2 = @shuffle(u32, tt, undefined, [_]i32{ 0, 2, 3, 1 }); - g(false, &rows, t2); - t3 = @shuffle(u32, m[1], m[3], [_]i32{ 2, (-1 - 2), 3, (-1 - 3) }); - tt = @shuffle(u32, m[2], t3, [_]i32{ 0, (-1 - 0), 1, (-1 - 1) }); - t3 = @shuffle(u32, tt, undefined, [_]i32{ 2, 3, 1, 0 }); - g(true, &rows, t3); - undiagonalize(&rows); - m = Rows{ t0, t1, t2, t3 }; - } +inline fn store32(bytes: []u8, w: u32) void { + mem.writeInt(u32, bytes[0..4], w, .little); +} - rows[0] ^= rows[2]; - rows[1] ^= rows[3]; - rows[2] ^= @Vector(4, u32){ chaining_value[0], chaining_value[1], chaining_value[2], chaining_value[3] }; - rows[3] ^= @Vector(4, u32){ chaining_value[4], chaining_value[5], chaining_value[6], chaining_value[7] }; +fn loadKeyWords(key: [Blake3.key_length]u8) [8]u32 { + var key_words: [8]u32 = undefined; + for (0..8) |i| { + key_words[i] = load32(key[i * 4 ..][0..4]); + } + return key_words; +} - return @as([16]u32, @bitCast(rows)); +fn storeCvWords(cv_words: [8]u32) [Blake3.digest_length]u8 { + var bytes: [Blake3.digest_length]u8 = undefined; + for (0..8) |i| { + store32(bytes[i * 4 ..][0..4], cv_words[i]); } -}; + return bytes; +} -const CompressGeneric = struct { - fn g(state: *[16]u32, comptime a: usize, comptime b: usize, comptime c: usize, comptime d: usize, mx: u32, my: u32) 
void { - state[a] +%= state[b] +% mx; - state[d] = math.rotr(u32, state[d] ^ state[a], 16); - state[c] +%= state[d]; - state[b] = math.rotr(u32, state[b] ^ state[c], 12); - state[a] +%= state[b] +% my; - state[d] = math.rotr(u32, state[d] ^ state[a], 8); - state[c] +%= state[d]; - state[b] = math.rotr(u32, state[b] ^ state[c], 7); - } - - fn round(state: *[16]u32, msg: [16]u32, schedule: [16]u8) void { - // Mix the columns. - g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); - g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); - g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); - g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); - - // Mix the diagonals. - g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); - g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); - g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); - g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); - } - - fn compress( - chaining_value: [8]u32, - block_words: [16]u32, - block_len: u32, - counter: u64, - flags: u8, - ) [16]u32 { - var state = [16]u32{ - chaining_value[0], - chaining_value[1], - chaining_value[2], - chaining_value[3], - chaining_value[4], - chaining_value[5], - chaining_value[6], - chaining_value[7], - IV[0], - IV[1], - IV[2], - IV[3], - @as(u32, @truncate(counter)), - @as(u32, @truncate(counter >> 32)), - block_len, - flags, - }; - for (MSG_SCHEDULE) |schedule| { - round(&state, block_words, schedule); +fn loadCvWords(bytes: [Blake3.digest_length]u8) [8]u32 { + var cv_words: [8]u32 = undefined; + for (0..8) |i| { + cv_words[i] = load32(bytes[i * 4 ..][0..4]); + } + return cv_words; +} + +inline fn counterLow(counter: u64) u32 { + return @truncate(counter); +} + +inline fn counterHigh(counter: u64) u32 { + return @truncate(counter >> 32); +} + +fn highestOne(x: u64) u6 { + if (x == 0) return 0; + return @intCast(63 - @clz(x)); +} + +fn roundDownToPowerOf2(x: u64) u64 { + return @as(u64, 1) << highestOne(x | 1); 
+} + +inline fn g(state: *[16]u32, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) void { + state[a] +%= state[b] +% x; + state[d] = rotr32(state[d] ^ state[a], 16); + state[c] +%= state[d]; + state[b] = rotr32(state[b] ^ state[c], 12); + state[a] +%= state[b] +% y; + state[d] = rotr32(state[d] ^ state[a], 8); + state[c] +%= state[d]; + state[b] = rotr32(state[b] ^ state[c], 7); +} + +inline fn roundFn(state: *[16]u32, msg: *const [16]u32, round: usize) void { + const schedule = &msg_schedule[round]; + + g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]); + g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]); + g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]); + g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]); + + g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]); + g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]); + g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]); + g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]); +} + +fn compressPre(state: *[16]u32, cv: *const [8]u32, block: []const u8, block_len: u8, counter: u64, flags: Flags) void { + var block_words: [16]u32 = undefined; + for (0..16) |i| { + block_words[i] = load32(block[i * 4 ..][0..4]); + } + + for (0..8) |i| { + state[i] = cv[i]; + } + for (0..4) |i| { + state[i + 8] = iv[i]; + } + state[12] = counterLow(counter); + state[13] = counterHigh(counter); + state[14] = @as(u32, block_len); + state[15] = @as(u32, flags.toInt()); + + for (0..7) |round| { + roundFn(state, &block_words, round); + } +} + +fn compressInPlace(cv: *[8]u32, block: []const u8, block_len: u8, counter: u64, flags: Flags) void { + var state: [16]u32 = undefined; + compressPre(&state, cv, block, block_len, counter, flags); + for (0..8) |i| { + cv[i] = state[i] ^ state[i + 8]; + } +} + +fn compressXof(cv: *const [8]u32, block: []const u8, block_len: u8, counter: u64, flags: Flags, out: *[64]u8) void { + var state: [16]u32 = undefined; + compressPre(&state, 
cv, block, block_len, counter, flags); + + for (0..8) |i| { + store32(out[i * 4 ..][0..4], state[i] ^ state[i + 8]); + } + for (0..8) |i| { + store32(out[(i + 8) * 4 ..][0..4], state[i + 8] ^ cv[i]); + } +} + +fn hashOne(input: []const u8, blocks: usize, key: [8]u32, counter: u64, flags: Flags, flags_start: Flags, flags_end: Flags) [Blake3.digest_length]u8 { + var cv = key; + var block_flags = flags.with(flags_start); + var inp = input; + var remaining_blocks = blocks; + + while (remaining_blocks > 0) { + if (remaining_blocks == 1) { + block_flags = block_flags.with(flags_end); } - for (chaining_value, 0..) |_, i| { - state[i] ^= state[i + 8]; - state[i + 8] ^= chaining_value[i]; + compressInPlace(&cv, inp[0..Blake3.block_length], Blake3.block_length, counter, block_flags); + inp = inp[Blake3.block_length..]; + remaining_blocks -= 1; + block_flags = flags; + } + + return storeCvWords(cv); +} + +fn hashManyPortable(inputs: [][*]const u8, num_inputs: usize, blocks: usize, key: [8]u32, counter_arg: u64, increment_counter: bool, flags: Flags, flags_start: Flags, flags_end: Flags, out: []u8) void { + var counter = counter_arg; + for (0..num_inputs) |i| { + const input = inputs[i][0 .. 
blocks * Blake3.block_length]; + const result = hashOne(input, blocks, key, counter, flags, flags_start, flags_end); + @memcpy(out[i * Blake3.digest_length ..][0..Blake3.digest_length], &result); + if (increment_counter) { + counter += 1; } - return state; } -}; +} -const compress = if (builtin.cpu.arch == .x86_64) - CompressVectorized.compress -else - CompressGeneric.compress; +fn transposeNxN(comptime Vec: type, comptime n: comptime_int, vecs: *[n]Vec) void { + const temp: [n]Vec = vecs.*; -fn first8Words(words: [16]u32) [8]u32 { - return @as(*const [8]u32, @ptrCast(&words)).*; + inline for (0..n) |i| { + inline for (0..n) |j| { + vecs[i][j] = temp[j][i]; + } + } } -fn wordsFromLittleEndianBytes(comptime count: usize, bytes: [count * 4]u8) [count]u32 { - var words: [count]u32 = undefined; - for (&words, 0..) |*word, i| { - word.* = mem.readInt(u32, bytes[4 * i ..][0..4], .little); +fn transposeMsg(comptime Vec: type, comptime n: comptime_int, inputs: [n][*]const u8, block_offset: usize, out: *[16]Vec) void { + const info = @typeInfo(Vec); + if (info != .vector) @compileError("transposeMsg requires a vector type"); + if (info.vector.len != n) @compileError("vector width must match N"); + + var temp: [n][16]u32 = undefined; + + for (0..n) |i| { + const block = inputs[i] + block_offset; + for (0..16) |j| { + temp[i][j] = load32(block[j * 4 ..][0..4]); + } + } + + for (0..16) |j| { + var result: Vec = undefined; + inline for (0..n) |i| { + result[i] = temp[i][j]; + } + out[j] = result; } - return words; } -// Each chunk or parent node can produce either an 8-word chaining value or, by -// setting the ROOT flag, any number of final output bytes. The Output struct -// captures the state just prior to choosing between those two possibilities. 
-const Output = struct { - input_chaining_value: [8]u32 align(16), - block_words: [16]u32 align(16), - block_len: u32, +fn roundFnVec(comptime Vec: type, v: *[16]Vec, m: *const [16]Vec, r: usize) void { + const schedule = &msg_schedule[r]; + + // Column round - first half + inline for (0..4) |i| { + v[i] +%= m[schedule[i * 2]]; + } + inline for (0..4) |i| { + v[i] +%= v[i + 4]; + } + inline for (0..4) |i| { + v[i + 12] ^= v[i]; + } + inline for (0..4) |i| { + v[i + 12] = rotr(Vec, v[i + 12], 16); + } + inline for (0..4) |i| { + v[i + 8] +%= v[i + 12]; + } + inline for (0..4) |i| { + v[i + 4] ^= v[i + 8]; + } + inline for (0..4) |i| { + v[i + 4] = rotr(Vec, v[i + 4], 12); + } + + // Column round - second half + inline for (0..4) |i| { + v[i] +%= m[schedule[i * 2 + 1]]; + } + inline for (0..4) |i| { + v[i] +%= v[i + 4]; + } + inline for (0..4) |i| { + v[i + 12] ^= v[i]; + } + inline for (0..4) |i| { + v[i + 12] = rotr(Vec, v[i + 12], 8); + } + inline for (0..4) |i| { + v[i + 8] +%= v[i + 12]; + } + inline for (0..4) |i| { + v[i + 4] ^= v[i + 8]; + } + inline for (0..4) |i| { + v[i + 4] = rotr(Vec, v[i + 4], 7); + } + + // Diagonal round - first half + inline for (0..4) |i| { + v[i] +%= m[schedule[i * 2 + 8]]; + } + const b_indices = [4]u8{ 5, 6, 7, 4 }; + inline for (0..4) |i| { + v[i] +%= v[b_indices[i]]; + } + const d_indices = [4]u8{ 15, 12, 13, 14 }; + inline for (0..4) |i| { + v[d_indices[i]] ^= v[i]; + } + inline for (0..4) |i| { + v[d_indices[i]] = rotr(Vec, v[d_indices[i]], 16); + } + const c_indices = [4]u8{ 10, 11, 8, 9 }; + inline for (0..4) |i| { + v[c_indices[i]] +%= v[d_indices[i]]; + } + inline for (0..4) |i| { + v[b_indices[i]] ^= v[c_indices[i]]; + } + inline for (0..4) |i| { + v[b_indices[i]] = rotr(Vec, v[b_indices[i]], 12); + } + + // Diagonal round - second half + inline for (0..4) |i| { + v[i] +%= m[schedule[i * 2 + 9]]; + } + inline for (0..4) |i| { + v[i] +%= v[b_indices[i]]; + } + inline for (0..4) |i| { + v[d_indices[i]] ^= v[i]; + } + 
inline for (0..4) |i| { + v[d_indices[i]] = rotr(Vec, v[d_indices[i]], 8); + } + inline for (0..4) |i| { + v[c_indices[i]] +%= v[d_indices[i]]; + } + inline for (0..4) |i| { + v[b_indices[i]] ^= v[c_indices[i]]; + } + inline for (0..4) |i| { + v[b_indices[i]] = rotr(Vec, v[b_indices[i]], 7); + } +} + +fn hashVec( + comptime Vec: type, + comptime n: comptime_int, + inputs: [n][*]const u8, + blocks: usize, + key: [8]u32, counter: u64, - flags: u8, + increment_counter: bool, + flags: Flags, + flags_start: Flags, + flags_end: Flags, + out: *[n * Blake3.digest_length]u8, +) void { + var h_vecs: [8]Vec = undefined; + for (0..8) |i| { + h_vecs[i] = @splat(key[i]); + } - fn chainingValue(self: *const Output) [8]u32 { - return first8Words(compress( - self.input_chaining_value, - self.block_words, - self.block_len, - self.counter, - self.flags, - )); - } - - fn rootOutputBytes(self: *const Output, output: []u8) void { - var out_block_it = ChunkIterator.init(output, 2 * OUT_LEN); - var output_block_counter: usize = 0; - while (out_block_it.next()) |out_block| { - const words = compress( - self.input_chaining_value, - self.block_words, - self.block_len, - output_block_counter, - self.flags | ROOT, - ); - var out_word_it = ChunkIterator.init(out_block, 4); - var word_counter: usize = 0; - while (out_word_it.next()) |out_word| { - var word_bytes: [4]u8 = undefined; - mem.writeInt(u32, &word_bytes, words[word_counter], .little); - @memcpy(out_word, word_bytes[0..out_word.len]); - word_counter += 1; + const counter_low_vec = if (increment_counter) blk: { + var result: Vec = undefined; + inline for (0..n) |i| { + result[i] = counterLow(counter + i); + } + break :blk result; + } else @as(Vec, @splat(counterLow(counter))); + + const counter_high_vec = if (increment_counter) blk: { + var result: Vec = undefined; + inline for (0..n) |i| { + result[i] = counterHigh(counter + i); + } + break :blk result; + } else @as(Vec, @splat(counterHigh(counter))); + + var block_flags = 
flags.with(flags_start); + + for (0..blocks) |block| { + if (block + 1 == blocks) { + block_flags = block_flags.with(flags_end); + } + + const block_len_vec: Vec = @splat(Blake3.block_length); + const block_flags_vec: Vec = @splat(@as(u32, block_flags.toInt())); + + var msg_vecs: [16]Vec = undefined; + transposeMsg(Vec, n, inputs, block * Blake3.block_length, &msg_vecs); + + var v: [16]Vec = .{ + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + @splat(iv[0]), @splat(iv[1]), @splat(iv[2]), @splat(iv[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + + inline for (0..7) |r| { + roundFnVec(Vec, &v, &msg_vecs, r); + } + + inline for (0..8) |i| { + h_vecs[i] = v[i] ^ v[i + 8]; + } + + block_flags = flags; + } + + // Output serialization - different strategies for different widths + switch (n) { + 4 => { + // Special interleaved pattern for Vec4 + var out_vecs = [4]Vec{ h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3] }; + transposeNxN(Vec, 4, &out_vecs); + inline for (0..4) |i| { + mem.writeInt(u32, out[0 * 16 + i * 4 ..][0..4], out_vecs[0][i], .little); } - output_block_counter += 1; + inline for (0..4) |i| { + mem.writeInt(u32, out[2 * 16 + i * 4 ..][0..4], out_vecs[1][i], .little); + } + inline for (0..4) |i| { + mem.writeInt(u32, out[4 * 16 + i * 4 ..][0..4], out_vecs[2][i], .little); + } + inline for (0..4) |i| { + mem.writeInt(u32, out[6 * 16 + i * 4 ..][0..4], out_vecs[3][i], .little); + } + + out_vecs = [4]Vec{ h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7] }; + transposeNxN(Vec, 4, &out_vecs); + inline for (0..4) |i| { + mem.writeInt(u32, out[1 * 16 + i * 4 ..][0..4], out_vecs[0][i], .little); + } + inline for (0..4) |i| { + mem.writeInt(u32, out[3 * 16 + i * 4 ..][0..4], out_vecs[1][i], .little); + } + inline for (0..4) |i| { + mem.writeInt(u32, out[5 * 16 + i * 4 ..][0..4], out_vecs[2][i], .little); + } + inline for (0..4) |i| { + mem.writeInt(u32, out[7 * 16 + i * 4 ..][0..4], out_vecs[3][i], 
.little); + } + }, + 8 => { + // Linear pattern with transpose for Vec8 + var out_vecs = [8]Vec{ h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7] }; + transposeNxN(Vec, 8, &out_vecs); + inline for (0..8) |i| { + mem.writeInt(u32, out[0 * 32 + i * 4 ..][0..4], out_vecs[0][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[1 * 32 + i * 4 ..][0..4], out_vecs[1][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[2 * 32 + i * 4 ..][0..4], out_vecs[2][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[3 * 32 + i * 4 ..][0..4], out_vecs[3][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[4 * 32 + i * 4 ..][0..4], out_vecs[4][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[5 * 32 + i * 4 ..][0..4], out_vecs[5][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[6 * 32 + i * 4 ..][0..4], out_vecs[6][i], .little); + } + inline for (0..8) |i| { + mem.writeInt(u32, out[7 * 32 + i * 4 ..][0..4], out_vecs[7][i], .little); + } + }, + 16 => { + // Direct lane-by-lane output for Vec16 (no transpose) + inline for (0..16) |lane| { + const hash_offset = lane * Blake3.digest_length; + inline for (0..8) |word_idx| { + const word = h_vecs[word_idx][lane]; + out[hash_offset + word_idx * 4 + 0] = @truncate(word); + out[hash_offset + word_idx * 4 + 1] = @truncate(word >> 8); + out[hash_offset + word_idx * 4 + 2] = @truncate(word >> 16); + out[hash_offset + word_idx * 4 + 3] = @truncate(word >> 24); + } + } + }, + else => @compileError("Unsupported SIMD width"), + } +} + +fn hashManySimd( + inputs: [][*]const u8, + num_inputs: usize, + blocks: usize, + key: [8]u32, + counter: u64, + increment_counter: bool, + flags: Flags, + flags_start: Flags, + flags_end: Flags, + out: []u8, +) void { + var remaining = num_inputs; + var inp = inputs.ptr; + var out_ptr = out.ptr; + var cnt = counter; + + const simd_deg = comptime simd_degree; + + if (comptime 
simd_deg >= 16) { + while (remaining >= 16) { + const sixteen_inputs = [16][*]const u8{ + inp[0], inp[1], inp[2], inp[3], + inp[4], inp[5], inp[6], inp[7], + inp[8], inp[9], inp[10], inp[11], + inp[12], inp[13], inp[14], inp[15], + }; + + var simd_out: [16 * Blake3.digest_length]u8 = undefined; + hashVec(Vec16, 16, sixteen_inputs, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, &simd_out); + + @memcpy(out_ptr[0 .. 16 * Blake3.digest_length], &simd_out); + + if (increment_counter) cnt += 16; + inp += 16; + remaining -= 16; + out_ptr += 16 * Blake3.digest_length; } } -}; + + if (comptime simd_deg >= 8) { + while (remaining >= 8) { + const eight_inputs = [8][*]const u8{ + inp[0], inp[1], inp[2], inp[3], + inp[4], inp[5], inp[6], inp[7], + }; + + var simd_out: [8 * Blake3.digest_length]u8 = undefined; + hashVec(Vec8, 8, eight_inputs, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, &simd_out); + + @memcpy(out_ptr[0 .. 8 * Blake3.digest_length], &simd_out); + + if (increment_counter) cnt += 8; + inp += 8; + remaining -= 8; + out_ptr += 8 * Blake3.digest_length; + } + } + + if (comptime simd_deg >= 4) { + while (remaining >= 4) { + const four_inputs = [4][*]const u8{ + inp[0], + inp[1], + inp[2], + inp[3], + }; + + var simd_out: [4 * Blake3.digest_length]u8 = undefined; + hashVec(Vec4, 4, four_inputs, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, &simd_out); + + @memcpy(out_ptr[0 .. 4 * Blake3.digest_length], &simd_out); + + if (increment_counter) cnt += 4; + inp += 4; + remaining -= 4; + out_ptr += 4 * Blake3.digest_length; + } + } + + if (remaining > 0) { + hashManyPortable(inp[0..remaining], remaining, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, out_ptr[0 .. 
remaining * Blake3.digest_length]); + } +} + +fn hashMany(inputs: [][*]const u8, num_inputs: usize, blocks: usize, key: [8]u32, counter: u64, increment_counter: bool, flags: Flags, flags_start: Flags, flags_end: Flags, out: []u8) void { + if (comptime max_simd_degree >= 4) { + hashManySimd(inputs, num_inputs, blocks, key, counter, increment_counter, flags, flags_start, flags_end, out); + } else { + hashManyPortable(inputs, num_inputs, blocks, key, counter, increment_counter, flags, flags_start, flags_end, out); + } +} + +fn compressChunksParallel(input: []const u8, key: [8]u32, chunk_counter: u64, flags: Flags, out: []u8) usize { + var chunks_array: [max_simd_degree][*]const u8 = undefined; + var input_position: usize = 0; + var chunks_array_len: usize = 0; + + while (input.len - input_position >= chunk_length) { + chunks_array[chunks_array_len] = input[input_position..].ptr; + input_position += chunk_length; + chunks_array_len += 1; + } + + hashMany(chunks_array[0..chunks_array_len], chunks_array_len, chunk_length / Blake3.block_length, key, chunk_counter, true, flags, .{ .chunk_start = true }, .{ .chunk_end = true }, out); + + if (input.len > input_position) { + const counter = chunk_counter + @as(u64, chunks_array_len); + var chunk_state = ChunkState.init(key, flags); + chunk_state.chunk_counter = counter; + chunk_state.update(input[input_position..]); + const output = chunk_state.output(); + const cv = output.chainingValue(); + const cv_bytes = storeCvWords(cv); + @memcpy(out[chunks_array_len * Blake3.digest_length ..][0..Blake3.digest_length], &cv_bytes); + return chunks_array_len + 1; + } else { + return chunks_array_len; + } +} + +fn compressParentsParallel(child_chaining_values: []const u8, num_chaining_values: usize, key: [8]u32, flags: Flags, out: []u8) usize { + var parents_array: [max_simd_degree_or_2][*]const u8 = undefined; + var parents_array_len: usize = 0; + + while (num_chaining_values - (2 * parents_array_len) >= 2) { + 
parents_array[parents_array_len] = child_chaining_values[2 * parents_array_len * Blake3.digest_length ..].ptr; + parents_array_len += 1; + } + + hashMany(parents_array[0..parents_array_len], parents_array_len, 1, key, 0, false, flags.with(.{ .parent = true }), .{}, .{}, out); + + if (num_chaining_values > 2 * parents_array_len) { + @memcpy(out[parents_array_len * Blake3.digest_length ..][0..Blake3.digest_length], child_chaining_values[2 * parents_array_len * Blake3.digest_length ..][0..Blake3.digest_length]); + return parents_array_len + 1; + } else { + return parents_array_len; + } +} + +fn compressSubtreeWide(input: []const u8, key: [8]u32, chunk_counter: u64, flags: Flags, out: []u8) usize { + if (input.len <= max_simd_degree * chunk_length) { + return compressChunksParallel(input, key, chunk_counter, flags, out); + } + + const left_input_len = leftSubtreeLen(input.len); + const right_input = input[left_input_len..]; + const right_chunk_counter = chunk_counter + @as(u64, left_input_len / chunk_length); + + var cv_array: [2 * max_simd_degree_or_2 * Blake3.digest_length]u8 = undefined; + var degree: usize = max_simd_degree; + if (left_input_len > chunk_length and degree == 1) { + degree = 2; + } + const right_cvs = cv_array[degree * Blake3.digest_length ..]; + + const left_n = compressSubtreeWide(input[0..left_input_len], key, chunk_counter, flags, cv_array[0..]); + const right_n = compressSubtreeWide(right_input, key, right_chunk_counter, flags, right_cvs); + + if (left_n == 1) { + @memcpy(out[0 .. 2 * Blake3.digest_length], cv_array[0 .. 
2 * Blake3.digest_length]); + return 2; + } + + const num_chaining_values = left_n + right_n; + return compressParentsParallel(&cv_array, num_chaining_values, key, flags, out); +} + +fn compressSubtreeToParentNode(input: []const u8, key: [8]u32, chunk_counter: u64, flags: Flags, out: *[2 * Blake3.digest_length]u8) void { + var cv_array: [max_simd_degree_or_2 * Blake3.digest_length]u8 = undefined; + var num_cvs = compressSubtreeWide(input, key, chunk_counter, flags, &cv_array); + + if (max_simd_degree_or_2 > 2) { + var out_array: [max_simd_degree_or_2 * Blake3.digest_length / 2]u8 = undefined; + while (num_cvs > 2) { + num_cvs = compressParentsParallel(&cv_array, num_cvs, key, flags, &out_array); + @memcpy(cv_array[0 .. num_cvs * Blake3.digest_length], out_array[0 .. num_cvs * Blake3.digest_length]); + } + } + + @memcpy(out, cv_array[0 .. 2 * Blake3.digest_length]); +} + +fn leftSubtreeLen(input_len: usize) usize { + const full_chunks = (input_len - 1) / chunk_length; + return @intCast(roundDownToPowerOf2(full_chunks) * chunk_length); +} + +fn parentOutput(parent_block: []const u8, key: [8]u32, flags: Flags) Output { + var block: [Blake3.block_length]u8 = undefined; + @memcpy(&block, parent_block[0..Blake3.block_length]); + return Output{ + .input_cv = key, + .block = block, + .block_len = Blake3.block_length, + .counter = 0, + .flags = flags.with(.{ .parent = true }), + }; +} + +fn parentOutputFromCvs(left_cv: [8]u32, right_cv: [8]u32, key: [8]u32, flags: Flags) Output { + var block: [Blake3.block_length]u8 align(16) = undefined; + for (0..8) |i| { + store32(block[i * 4 ..][0..4], left_cv[i]); + store32(block[(i + 8) * 4 ..][0..4], right_cv[i]); + } + return Output{ + .input_cv = key, + .block = block, + .block_len = Blake3.block_length, + .counter = 0, + .flags = flags.with(.{ .parent = true }), + }; +} const ChunkState = struct { - chaining_value: [8]u32 align(16), + cv: [8]u32 align(16), chunk_counter: u64, - block: [BLOCK_LEN]u8 align(16) = [_]u8{0} ** 
BLOCK_LEN, - block_len: u8 = 0, - blocks_compressed: u8 = 0, - flags: u8, + buf: [Blake3.block_length]u8 align(16), + buf_len: u8, + blocks_compressed: u8, + flags: Flags, - fn init(key: [8]u32, chunk_counter: u64, flags: u8) ChunkState { + fn init(key: [8]u32, flags: Flags) ChunkState { return ChunkState{ - .chaining_value = key, - .chunk_counter = chunk_counter, + .cv = key, + .chunk_counter = 0, + .buf = [_]u8{0} ** Blake3.block_length, + .buf_len = 0, + .blocks_compressed = 0, .flags = flags, }; } + fn reset(self: *ChunkState, key: [8]u32, chunk_counter: u64) void { + self.cv = key; + self.chunk_counter = chunk_counter; + self.blocks_compressed = 0; + self.buf = [_]u8{0} ** Blake3.block_length; + self.buf_len = 0; + } + fn len(self: *const ChunkState) usize { - return BLOCK_LEN * @as(usize, self.blocks_compressed) + @as(usize, self.block_len); - } - - fn fillBlockBuf(self: *ChunkState, input: []const u8) []const u8 { - const want = BLOCK_LEN - self.block_len; - const take = @min(want, input.len); - @memcpy(self.block[self.block_len..][0..take], input[0..take]); - self.block_len += @as(u8, @truncate(take)); - return input[take..]; - } - - fn startFlag(self: *const ChunkState) u8 { - return if (self.blocks_compressed == 0) CHUNK_START else 0; - } - - fn update(self: *ChunkState, input_slice: []const u8) void { - var input = input_slice; - while (input.len > 0) { - // If the block buffer is full, compress it and clear it. More - // input is coming, so this compression is not CHUNK_END. 
- if (self.block_len == BLOCK_LEN) { - const block_words = wordsFromLittleEndianBytes(16, self.block); - self.chaining_value = first8Words(compress( - self.chaining_value, - block_words, - BLOCK_LEN, - self.chunk_counter, - self.flags | self.startFlag(), - )); + return (Blake3.block_length * @as(usize, self.blocks_compressed)) + @as(usize, self.buf_len); + } + + fn fillBuf(self: *ChunkState, input: []const u8) usize { + const take = @min(Blake3.block_length - @as(usize, self.buf_len), input.len); + @memcpy(self.buf[self.buf_len..][0..take], input[0..take]); + self.buf_len += @intCast(take); + return take; + } + + fn maybeStartFlag(self: *const ChunkState) Flags { + return if (self.blocks_compressed == 0) .{ .chunk_start = true } else .{}; + } + + fn update(self: *ChunkState, input: []const u8) void { + var inp = input; + + while (inp.len > 0) { + if (self.buf_len == Blake3.block_length) { + compressInPlace(&self.cv, &self.buf, Blake3.block_length, self.chunk_counter, self.flags.with(self.maybeStartFlag())); self.blocks_compressed += 1; - self.block = [_]u8{0} ** BLOCK_LEN; - self.block_len = 0; + self.buf = [_]u8{0} ** Blake3.block_length; + self.buf_len = 0; } - // Copy input bytes into the block buffer. 
- input = self.fillBlockBuf(input); + const take = self.fillBuf(inp); + inp = inp[take..]; } } fn output(self: *const ChunkState) Output { - const block_words = wordsFromLittleEndianBytes(16, self.block); + const block_flags = self.flags.with(self.maybeStartFlag()).with(.{ .chunk_end = true }); return Output{ - .input_chaining_value = self.chaining_value, - .block_words = block_words, - .block_len = self.block_len, + .input_cv = self.cv, + .block = self.buf, + .block_len = self.buf_len, .counter = self.chunk_counter, - .flags = self.flags | self.startFlag() | CHUNK_END, + .flags = block_flags, }; } }; -fn parentOutput( - left_child_cv: [8]u32, - right_child_cv: [8]u32, - key: [8]u32, - flags: u8, -) Output { - var block_words: [16]u32 align(16) = undefined; - block_words[0..8].* = left_child_cv; - block_words[8..].* = right_child_cv; - return Output{ - .input_chaining_value = key, - .block_words = block_words, - .block_len = BLOCK_LEN, // Always BLOCK_LEN (64) for parent nodes. - .counter = 0, // Always 0 for parent nodes. 
- .flags = PARENT | flags, - }; -} +const Output = struct { + input_cv: [8]u32 align(16), + block: [Blake3.block_length]u8 align(16), + block_len: u8, + counter: u64, + flags: Flags, -fn parentCv( - left_child_cv: [8]u32, - right_child_cv: [8]u32, - key: [8]u32, - flags: u8, -) [8]u32 { - return parentOutput(left_child_cv, right_child_cv, key, flags).chainingValue(); -} + fn chainingValue(self: *const Output) [8]u32 { + var cv_words = self.input_cv; + compressInPlace(&cv_words, &self.block, self.block_len, self.counter, self.flags); + return cv_words; + } + + fn rootBytes(self: *const Output, seek: u64, out: []u8) void { + if (out.len == 0) return; + + var output_block_counter = seek / 64; + const offset_within_block = @as(usize, @intCast(seek % 64)); + var out_remaining = out; + + if (offset_within_block > 0) { + var wide_buf: [64]u8 = undefined; + compressXof(&self.input_cv, &self.block, self.block_len, output_block_counter, self.flags.with(.{ .root = true }), &wide_buf); + const available_bytes = 64 - offset_within_block; + const bytes = @min(out_remaining.len, available_bytes); + @memcpy(out_remaining[0..bytes], wide_buf[offset_within_block..][0..bytes]); + out_remaining = out_remaining[bytes..]; + output_block_counter += 1; + } -/// An incremental hasher that can accept any number of writes. + while (out_remaining.len >= 64) { + compressXof(&self.input_cv, &self.block, self.block_len, output_block_counter, self.flags.with(.{ .root = true }), out_remaining[0..64]); + out_remaining = out_remaining[64..]; + output_block_counter += 1; + } + + if (out_remaining.len > 0) { + var wide_buf: [64]u8 = undefined; + compressXof(&self.input_cv, &self.block, self.block_len, output_block_counter, self.flags.with(.{ .root = true }), &wide_buf); + @memcpy(out_remaining, wide_buf[0..out_remaining.len]); + } + } +}; + +/// BLAKE3 is a cryptographic hash function that produces a 256-bit digest by default but also supports extendable output. 
pub const Blake3 = struct { + pub const block_length = 64; + pub const digest_length = 32; + pub const key_length = 32; + pub const Options = struct { key: ?[digest_length]u8 = null }; pub const KdfOptions = struct {}; - chunk_state: ChunkState, key: [8]u32, - cv_stack: [54][8]u32 = undefined, // Space for 54 subtree chaining values: - cv_stack_len: u8 = 0, // 2^54 * CHUNK_LEN = 2^64 - flags: u8, - - pub const block_length = BLOCK_LEN; - pub const digest_length = OUT_LEN; - pub const key_length = KEY_LEN; - - fn init_internal(key: [8]u32, flags: u8) Blake3 { - return Blake3{ - .chunk_state = ChunkState.init(key, 0, flags), - .key = key, - .flags = flags, - }; - } + chunk: ChunkState, + cv_stack_len: u8, + cv_stack: [max_depth + 1][8]u32, /// Construct a new `Blake3` for the hash function, with an optional key pub fn init(options: Options) Blake3 { if (options.key) |key| { - const key_words = wordsFromLittleEndianBytes(8, key); - return Blake3.init_internal(key_words, KEYED_HASH); + const key_words = loadKeyWords(key); + return init_internal(key_words, .{ .keyed_hash = true }); } else { - return Blake3.init_internal(IV, 0); + return init_internal(iv, .{}); } } @@ -393,12 +835,12 @@ pub const Blake3 = struct { /// string should be hardcoded, globally unique, and application-specific. 
pub fn initKdf(context: []const u8, options: KdfOptions) Blake3 { _ = options; - var context_hasher = Blake3.init_internal(IV, DERIVE_KEY_CONTEXT); + var context_hasher = init_internal(iv, .{ .derive_key_context = true }); context_hasher.update(context); - var context_key: [KEY_LEN]u8 = undefined; - context_hasher.final(context_key[0..]); - const context_key_words = wordsFromLittleEndianBytes(8, context_key); - return Blake3.init_internal(context_key_words, DERIVE_KEY_MATERIAL); + var context_key: [key_length]u8 = undefined; + context_hasher.final(&context_key); + const context_key_words = loadKeyWords(context_key); + return init_internal(context_key_words, .{ .derive_key_material = true }); } pub fn hash(b: []const u8, out: []u8, options: Options) void { @@ -407,78 +849,135 @@ pub const Blake3 = struct { d.final(out); } - fn pushCv(self: *Blake3, cv: [8]u32) void { - self.cv_stack[self.cv_stack_len] = cv; - self.cv_stack_len += 1; + fn init_internal(key: [8]u32, flags: Flags) Blake3 { + return Blake3{ + .key = key, + .chunk = ChunkState.init(key, flags), + .cv_stack_len = 0, + .cv_stack = undefined, + }; } - fn popCv(self: *Blake3) [8]u32 { - self.cv_stack_len -= 1; - return self.cv_stack[self.cv_stack_len]; - } - - // Section 5.1.2 of the BLAKE3 spec explains this algorithm in more detail. - fn addChunkChainingValue(self: *Blake3, first_cv: [8]u32, total_chunks: u64) void { - // This chunk might complete some subtrees. For each completed subtree, - // its left child will be the current top entry in the CV stack, and - // its right child will be the current value of `new_cv`. Pop each left - // child off the stack, merge it with `new_cv`, and overwrite `new_cv` - // with the result. After all these merges, push the final value of - // `new_cv` onto the stack. The number of completed subtrees is given - // by the number of trailing 0-bits in the new total number of chunks. 
- var new_cv = first_cv; - var chunk_counter = total_chunks; - while (chunk_counter & 1 == 0) { - new_cv = parentCv(self.popCv(), new_cv, self.key, self.flags); - chunk_counter >>= 1; + fn mergeCvStack(self: *Blake3, total_len: u64) void { + const post_merge_stack_len = @as(u8, @intCast(@popCount(total_len))); + while (self.cv_stack_len > post_merge_stack_len) { + const left_cv = self.cv_stack[self.cv_stack_len - 2]; + const right_cv = self.cv_stack[self.cv_stack_len - 1]; + const output = parentOutputFromCvs(left_cv, right_cv, self.key, self.chunk.flags); + const cv = output.chainingValue(); + self.cv_stack[self.cv_stack_len - 2] = cv; + self.cv_stack_len -= 1; } - self.pushCv(new_cv); + } + + fn pushCv(self: *Blake3, new_cv: [8]u32, chunk_counter: u64) void { + self.mergeCvStack(chunk_counter); + self.cv_stack[self.cv_stack_len] = new_cv; + self.cv_stack_len += 1; } /// Add input to the hash state. This can be called any number of times. - pub fn update(self: *Blake3, input_slice: []const u8) void { - var input = input_slice; - while (input.len > 0) { - // If the current chunk is complete, finalize it and reset the - // chunk state. More input is coming, so this chunk is not ROOT. 
- if (self.chunk_state.len() == CHUNK_LEN) { - const chunk_cv = self.chunk_state.output().chainingValue(); - const total_chunks = self.chunk_state.chunk_counter + 1; - self.addChunkChainingValue(chunk_cv, total_chunks); - self.chunk_state = ChunkState.init(self.key, total_chunks, self.flags); + pub fn update(self: *Blake3, input: []const u8) void { + if (input.len == 0) return; + + var inp = input; + + if (self.chunk.len() > 0) { + const take = @min(chunk_length - self.chunk.len(), inp.len); + self.chunk.update(inp[0..take]); + inp = inp[take..]; + if (inp.len > 0) { + const output = self.chunk.output(); + const chunk_cv = output.chainingValue(); + self.pushCv(chunk_cv, self.chunk.chunk_counter); + self.chunk.reset(self.key, self.chunk.chunk_counter + 1); + } else { + return; } + } + + while (inp.len > chunk_length) { + var subtree_len = roundDownToPowerOf2(inp.len); + const count_so_far = self.chunk.chunk_counter * chunk_length; - // Compress input bytes into the current chunk state. - const want = CHUNK_LEN - self.chunk_state.len(); - const take = @min(want, input.len); - self.chunk_state.update(input[0..take]); - input = input[take..]; + while ((subtree_len - 1) & count_so_far != 0) { + subtree_len /= 2; + } + + const subtree_chunks = subtree_len / chunk_length; + if (subtree_len <= chunk_length) { + var chunk_state = ChunkState.init(self.key, self.chunk.flags); + chunk_state.chunk_counter = self.chunk.chunk_counter; + chunk_state.update(inp[0..@intCast(subtree_len)]); + const output = chunk_state.output(); + const cv = output.chainingValue(); + self.pushCv(cv, chunk_state.chunk_counter); + } else { + var cv_pair: [2 * digest_length]u8 = undefined; + compressSubtreeToParentNode(inp[0..@intCast(subtree_len)], self.key, self.chunk.chunk_counter, self.chunk.flags, &cv_pair); + const left_cv = loadCvWords(cv_pair[0..digest_length].*); + const right_cv = loadCvWords(cv_pair[digest_length..][0..digest_length].*); + self.pushCv(left_cv, self.chunk.chunk_counter); + 
self.pushCv(right_cv, self.chunk.chunk_counter + (subtree_chunks / 2)); + } + self.chunk.chunk_counter += subtree_chunks; + inp = inp[@intCast(subtree_len)..]; + } + + if (inp.len > 0) { + self.chunk.update(inp); + self.mergeCvStack(self.chunk.chunk_counter); } } /// Finalize the hash and write any number of output bytes. - pub fn final(self: *const Blake3, out_slice: []u8) void { - // Starting with the Output from the current chunk, compute all the - // parent chaining values along the right edge of the tree, until we - // have the root Output. - var output = self.chunk_state.output(); - var parent_nodes_remaining: usize = self.cv_stack_len; - while (parent_nodes_remaining > 0) { - parent_nodes_remaining -= 1; - output = parentOutput( - self.cv_stack[parent_nodes_remaining], - output.chainingValue(), - self.key, - self.flags, - ); + pub fn final(self: *const Blake3, out: []u8) void { + self.finalizeSeek(0, out); + } + + /// Finalize the hash and write any number of output bytes, starting at a given seek position. + /// This is an XOF (extendable-output function) extension. 
+ pub fn finalizeSeek(self: *const Blake3, seek: u64, out: []u8) void { + if (out.len == 0) return; + + if (self.cv_stack_len == 0) { + const output = self.chunk.output(); + output.rootBytes(seek, out); + return; } - output.rootOutputBytes(out_slice); + + var output: Output = undefined; + var cvs_remaining: usize = undefined; + + if (self.chunk.len() > 0) { + cvs_remaining = self.cv_stack_len; + output = self.chunk.output(); + } else { + cvs_remaining = self.cv_stack_len - 2; + const left_cv = self.cv_stack[cvs_remaining]; + const right_cv = self.cv_stack[cvs_remaining + 1]; + output = parentOutputFromCvs(left_cv, right_cv, self.key, self.chunk.flags); + } + + while (cvs_remaining > 0) { + cvs_remaining -= 1; + const left_cv = self.cv_stack[cvs_remaining]; + const right_cv = output.chainingValue(); + output = parentOutputFromCvs(left_cv, right_cv, self.key, self.chunk.flags); + } + + output.rootBytes(seek, out); + } + + pub fn reset(self: *Blake3) void { + self.chunk.reset(self.key, 0); + self.cv_stack_len = 0; } }; // Use named type declarations to workaround crash with anonymous structs (issue #4373). const ReferenceTest = struct { - key: *const [KEY_LEN]u8, + key: *const [Blake3.key_length]u8, context_string: []const u8, cases: []const ReferenceTestCase, }; @@ -663,7 +1162,7 @@ fn testBlake3(hasher: *Blake3, input_len: usize, expected_hex: [262]u8) !void { // Compare to expected value var expected_bytes: [expected_hex.len / 2]u8 = undefined; _ = fmt.hexToBytes(expected_bytes[0..], expected_hex[0..]) catch unreachable; - try testing.expectEqual(actual_bytes, expected_bytes); + try std.testing.expectEqual(actual_bytes, expected_bytes); // Restore initial state hasher.* = initial_state; |
