aboutsummaryrefslogtreecommitdiff
path: root/lib/std/compress/flate/Decompress.zig
diff options
context:
space:
mode:
Diffstat (limited to 'lib/std/compress/flate/Decompress.zig')
-rw-r--r--lib/std/compress/flate/Decompress.zig391
1 files changed, 156 insertions, 235 deletions
diff --git a/lib/std/compress/flate/Decompress.zig b/lib/std/compress/flate/Decompress.zig
index 47570beab2..d9b6a82a63 100644
--- a/lib/std/compress/flate/Decompress.zig
+++ b/lib/std/compress/flate/Decompress.zig
@@ -7,11 +7,10 @@ const Reader = std.Io.Reader;
const Container = flate.Container;
const Decompress = @This();
-const Token = @import("Token.zig");
+const token = @import("token.zig");
input: *Reader,
-next_bits: Bits,
-remaining_bits: std.math.Log2Int(Bits),
+consumed_bits: u3,
reader: Reader,
@@ -25,8 +24,6 @@ state: State,
err: ?Error,
-const Bits = usize;
-
const BlockType = enum(u2) {
stored = 0,
fixed = 1,
@@ -39,6 +36,8 @@ const State = union(enum) {
block_header,
stored_block: u16,
fixed_block,
+ fixed_block_literal: u8,
+ fixed_block_match: u16,
dynamic_block,
dynamic_block_literal: u8,
dynamic_block_match: u16,
@@ -87,8 +86,7 @@ pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress {
.end = 0,
},
.input = input,
- .next_bits = 0,
- .remaining_bits = 0,
+ .consumed_bits = 0,
.container_metadata = .init(container),
.lit_dec = .{},
.dst_dec = .{},
@@ -183,27 +181,25 @@ fn streamIndirectInner(d: *Decompress) Reader.Error!usize {
return 0;
}
-fn decodeLength(self: *Decompress, code: u8) !u16 {
- if (code > 28) return error.InvalidCode;
- const ml = Token.matchLength(code);
- return if (ml.extra_bits == 0) // 0 - 5 extra bits
- ml.base
- else
- ml.base + try self.takeBitsRuntime(ml.extra_bits);
+fn decodeLength(self: *Decompress, code_int: u5) !u16 {
+ if (code_int > 28) return error.InvalidCode;
+ const l: token.LenCode = .fromInt(code_int);
+ const base = l.base();
+ const extra = l.extraBits();
+ return token.min_length + (base | try self.takeBits(extra));
}
-fn decodeDistance(self: *Decompress, code: u8) !u16 {
- if (code > 29) return error.InvalidCode;
- const md = Token.matchDistance(code);
- return if (md.extra_bits == 0) // 0 - 13 extra bits
- md.base
- else
- md.base + try self.takeBitsRuntime(md.extra_bits);
+fn decodeDistance(self: *Decompress, code_int: u5) !u16 {
+ if (code_int > 29) return error.InvalidCode;
+ const d: token.DistCode = .fromInt(code_int);
+ const base = d.base();
+ const extra = d.extraBits();
+ return token.min_distance + (base | try self.takeBits(extra));
}
-// Decode code length symbol to code length. Writes decoded length into
-// lens slice starting at position pos. Returns number of positions
-// advanced.
+/// Decode code length symbol to code length. Writes decoded length into
+/// lens slice starting at position pos. Returns number of positions
+/// advanced.
fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usize {
if (pos >= lens.len)
return error.InvalidDynamicBlockHeader;
@@ -217,7 +213,7 @@ fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usiz
16 => {
// Copy the previous code length 3 - 6 times.
// The next 2 bits indicate repeat length
- const n: u8 = @as(u8, try self.takeBits(u2)) + 3;
+ const n: u8 = @as(u8, try self.takeIntBits(u2)) + 3;
if (pos == 0 or pos + n > lens.len)
return error.InvalidDynamicBlockHeader;
for (0..n) |i| {
@@ -226,17 +222,17 @@ fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usiz
return n;
},
// Repeat a code length of 0 for 3 - 10 times. (3 bits of length)
- 17 => return @as(u8, try self.takeBits(u3)) + 3,
+ 17 => return @as(u8, try self.takeIntBits(u3)) + 3,
// Repeat a code length of 0 for 11 - 138 times (7 bits of length)
- 18 => return @as(u8, try self.takeBits(u7)) + 11,
+ 18 => return @as(u8, try self.takeIntBits(u7)) + 11,
else => return error.InvalidDynamicBlockHeader,
}
}
fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol {
// Maximum code len is 15 bits.
- const sym = try decoder.find(@bitReverse(try self.peekBits(u15)));
- try self.tossBits(sym.code_bits);
+ const sym = try decoder.find(@bitReverse(try self.peekIntBitsShort(u15)));
+ try self.tossBitsShort(sym.code_bits);
return sym;
}
@@ -320,11 +316,11 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
.raw => continue :sw .block_header,
},
.block_header => {
- d.final_block = (try d.takeBits(u1)) != 0;
- const block_type: BlockType = @enumFromInt(try d.takeBits(u2));
+ d.final_block = (try d.takeIntBits(u1)) != 0;
+ const block_type: BlockType = @enumFromInt(try d.takeIntBits(u2));
switch (block_type) {
.stored => {
- d.alignBitsDiscarding();
+ d.alignBitsForward();
// everything after this is byte aligned in stored block
const len = try in.takeInt(u16, .little);
const nlen = try in.takeInt(u16, .little);
@@ -333,17 +329,17 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
},
.fixed => continue :sw .fixed_block,
.dynamic => {
- const hlit: u16 = @as(u16, try d.takeBits(u5)) + 257; // number of ll code entries present - 257
- const hdist: u16 = @as(u16, try d.takeBits(u5)) + 1; // number of distance code entries - 1
- const hclen: u8 = @as(u8, try d.takeBits(u4)) + 4; // hclen + 4 code lengths are encoded
+ const hlit: u16 = @as(u16, try d.takeIntBits(u5)) + 257; // number of ll code entries present - 257
+ const hdist: u16 = @as(u16, try d.takeIntBits(u5)) + 1; // number of distance code entries - 1
+ const hclen: u8 = @as(u8, try d.takeIntBits(u4)) + 4; // hclen + 4 code lengths are encoded
if (hlit > 286 or hdist > 30)
return error.InvalidDynamicBlockHeader;
// lengths for code lengths
var cl_lens: [19]u4 = @splat(0);
- for (flate.HuffmanEncoder.codegen_order[0..hclen]) |i| {
- cl_lens[i] = try d.takeBits(u3);
+ for (token.codegen_order[0..hclen]) |i| {
+ cl_lens[i] = try d.takeIntBits(u3);
}
var cl_dec: CodegenDecoder = .{};
try cl_dec.generate(&cl_lens);
@@ -352,9 +348,9 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
var dec_lens: [286 + 30]u4 = @splat(0);
var pos: usize = 0;
while (pos < hlit + hdist) {
- const peeked = @bitReverse(try d.peekBits(u7));
+ const peeked = @bitReverse(try d.peekIntBitsShort(u7));
const sym = try cl_dec.find(peeked);
- try d.tossBits(sym.code_bits);
+ try d.tossBitsShort(sym.code_bits);
pos += try d.dynamicCodeLength(sym.symbol, &dec_lens, pos);
}
if (pos > hlit + hdist) {
@@ -373,9 +369,12 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
}
},
.stored_block => |remaining_len| {
- const out = try w.writableSliceGreedyPreserve(flate.history_len, 1);
+ const out: []u8 = if (remaining != 0)
+ try w.writableSliceGreedyPreserve(flate.history_len, 1)
+ else
+ &.{};
var limited_out: [1][]u8 = .{limit.min(.limited(remaining_len)).slice(out)};
- const n = try d.input.readVec(&limited_out);
+ const n = try in.readVec(&limited_out);
if (remaining_len - n == 0) {
d.state = if (d.final_block) .protocol_footer else .block_header;
} else {
@@ -389,8 +388,14 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
const code = try d.readFixedCode();
switch (code) {
0...255 => {
- try w.writeBytePreserve(flate.history_len, @intCast(code));
- remaining -= 1;
+ if (remaining != 0) {
+ @branchHint(.likely);
+ try w.writeBytePreserve(flate.history_len, @intCast(code));
+ remaining -= 1;
+ } else {
+ d.state = .{ .fixed_block_literal = @intCast(code) };
+ return @intFromEnum(limit) - remaining;
+ }
},
256 => {
d.state = if (d.final_block) .protocol_footer else .block_header;
@@ -400,9 +405,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
// Handles fixed block non literal (length) code.
// Length code is followed by 5 bits of distance code.
const length = try d.decodeLength(@intCast(code - 257));
- const distance = try d.decodeDistance(@bitReverse(try d.takeBits(u5)));
- try writeMatch(w, length, distance);
- remaining -= length;
+ continue :sw .{ .fixed_block_match = length };
},
else => return error.InvalidCode,
}
@@ -410,6 +413,24 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
d.state = .fixed_block;
return @intFromEnum(limit) - remaining;
},
+ .fixed_block_literal => |symbol| {
+ assert(remaining != 0);
+ remaining -= 1;
+ try w.writeBytePreserve(flate.history_len, symbol);
+ continue :sw .fixed_block;
+ },
+ .fixed_block_match => |length| {
+ if (remaining >= length) {
+ @branchHint(.likely);
+ const distance = try d.decodeDistance(@bitReverse(try d.takeIntBits(u5)));
+ try writeMatch(w, length, distance);
+ remaining -= length;
+ continue :sw .fixed_block;
+ } else {
+ d.state = .{ .fixed_block_match = length };
+ return @intFromEnum(limit) - remaining;
+ }
+ },
.dynamic_block => {
// In larger archives most blocks are usually dynamic, so
// decompression performance depends on this logic.
@@ -429,7 +450,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
},
.match => {
// Decode match backreference <length, distance>
- const length = try d.decodeLength(sym.symbol);
+ const length = try d.decodeLength(@intCast(sym.symbol));
continue :sw .{ .dynamic_block_match = length };
},
.end_of_block => {
@@ -449,7 +470,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
@branchHint(.likely);
remaining -= length;
const dsm = try d.decodeSymbol(&d.dst_dec);
- const distance = try d.decodeDistance(dsm.symbol);
+ const distance = try d.decodeDistance(@intCast(dsm.symbol));
try writeMatch(w, length, distance);
continue :sw .dynamic_block;
} else {
@@ -458,23 +479,16 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
}
},
.protocol_footer => {
+ d.alignBitsForward();
switch (d.container_metadata) {
.gzip => |*gzip| {
- d.alignBitsDiscarding();
- gzip.* = .{
- .crc = try in.takeInt(u32, .little),
- .count = try in.takeInt(u32, .little),
- };
+ gzip.crc = try in.takeInt(u32, .little);
+ gzip.count = try in.takeInt(u32, .little);
},
.zlib => |*zlib| {
- d.alignBitsDiscarding();
- zlib.* = .{
- .adler = try in.takeInt(u32, .little),
- };
- },
- .raw => {
- d.alignBitsPreserving();
+ zlib.adler = try in.takeInt(u32, .big);
},
+ .raw => {},
}
d.state = .end;
return @intFromEnum(limit) - remaining;
@@ -487,10 +501,10 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
/// back from current write position, and `length` of bytes.
fn writeMatch(w: *Writer, length: u16, distance: u16) !void {
if (w.end < distance) return error.InvalidMatch;
- if (length < Token.base_length) return error.InvalidMatch;
- if (length > Token.max_length) return error.InvalidMatch;
- if (distance < Token.min_distance) return error.InvalidMatch;
- if (distance > Token.max_distance) return error.InvalidMatch;
+ if (length < token.min_length) return error.InvalidMatch;
+ if (length > token.max_length) return error.InvalidMatch;
+ if (distance < token.min_distance) return error.InvalidMatch;
+ if (distance > token.max_distance) return error.InvalidMatch;
// This is not a @memmove; it intentionally repeats patterns caused by
// iterating one byte at a time.
@@ -500,137 +514,71 @@ fn writeMatch(w: *Writer, length: u16, distance: u16) !void {
for (dest, src) |*d, s| d.* = s;
}
-fn takeBits(d: *Decompress, comptime U: type) !U {
- const remaining_bits = d.remaining_bits;
- const next_bits = d.next_bits;
- if (remaining_bits >= @bitSizeOf(U)) {
- const u: U = @truncate(next_bits);
- d.next_bits = next_bits >> @bitSizeOf(U);
- d.remaining_bits = remaining_bits - @bitSizeOf(U);
- return u;
- }
- const in = d.input;
- const next_int = in.takeInt(Bits, .little) catch |err| switch (err) {
- error.ReadFailed => return error.ReadFailed,
- error.EndOfStream => return takeBitsEnding(d, U),
+fn peekBits(d: *Decompress, n: u4) !u16 {
+ const bits = d.input.peekInt(u32, .little) catch |e| return switch (e) {
+ error.ReadFailed => error.ReadFailed,
+ error.EndOfStream => d.peekBitsEnding(n),
};
- const needed_bits = @bitSizeOf(U) - remaining_bits;
- const u: U = @intCast(((next_int & ((@as(Bits, 1) << needed_bits) - 1)) << remaining_bits) | next_bits);
- d.next_bits = next_int >> needed_bits;
- d.remaining_bits = @intCast(@bitSizeOf(Bits) - @as(usize, needed_bits));
- return u;
+ const mask = @shlExact(@as(u16, 1), n) - 1;
+ return @intCast((bits >> d.consumed_bits) & mask);
}
-fn takeBitsEnding(d: *Decompress, comptime U: type) !U {
- const remaining_bits = d.remaining_bits;
- const next_bits = d.next_bits;
- const in = d.input;
- const n = in.bufferedLen();
- assert(n < @sizeOf(Bits));
- const needed_bits = @bitSizeOf(U) - remaining_bits;
- if (n * 8 < needed_bits) return error.EndOfStream;
- const next_int = in.takeVarInt(Bits, .little, n) catch |err| switch (err) {
- error.ReadFailed => return error.ReadFailed,
- error.EndOfStream => unreachable,
- };
- const u: U = @intCast(((next_int & ((@as(Bits, 1) << needed_bits) - 1)) << remaining_bits) | next_bits);
- d.next_bits = next_int >> needed_bits;
- d.remaining_bits = @intCast(n * 8 - @as(usize, needed_bits));
- return u;
+fn peekBitsEnding(d: *Decompress, n: u4) !u16 {
+ @branchHint(.unlikely);
+
+ const left = d.input.buffered();
+ if (left.len * 8 - d.consumed_bits < n) return error.EndOfStream;
+ const bits = std.mem.readVarInt(u32, left, .little);
+ const mask = @shlExact(@as(u16, 1), n) - 1;
+ return @intCast((bits >> d.consumed_bits) & mask);
}
-fn peekBits(d: *Decompress, comptime U: type) !U {
- const remaining_bits = d.remaining_bits;
- const next_bits = d.next_bits;
- if (remaining_bits >= @bitSizeOf(U)) return @truncate(next_bits);
- const in = d.input;
- const next_int = in.peekInt(Bits, .little) catch |err| switch (err) {
- error.ReadFailed => return error.ReadFailed,
- error.EndOfStream => return peekBitsEnding(d, U),
- };
- const needed_bits = @bitSizeOf(U) - remaining_bits;
- return @intCast(((next_int & ((@as(Bits, 1) << needed_bits) - 1)) << remaining_bits) | next_bits);
+/// Safe only after `peekBits` has been called with a greater or equal `n` value.
+fn tossBits(d: *Decompress, n: u4) void {
+ d.input.toss((@as(u8, n) + d.consumed_bits) / 8);
+ d.consumed_bits +%= @truncate(n);
}
-fn peekBitsEnding(d: *Decompress, comptime U: type) !U {
- const remaining_bits = d.remaining_bits;
- const next_bits = d.next_bits;
- const in = d.input;
- var u: Bits = 0;
- var remaining_needed_bits = @bitSizeOf(U) - remaining_bits;
- var i: usize = 0;
- while (remaining_needed_bits > 0) {
- const peeked = in.peek(i + 1) catch |err| switch (err) {
- error.ReadFailed => return error.ReadFailed,
- error.EndOfStream => break,
- };
- u |= @as(Bits, peeked[i]) << @intCast(i * 8);
- remaining_needed_bits -|= 8;
- i += 1;
- }
- if (remaining_bits == 0 and i == 0) return error.EndOfStream;
- return @truncate((u << remaining_bits) | next_bits);
-}
-
-fn tossBits(d: *Decompress, n: u4) !void {
- const remaining_bits = d.remaining_bits;
- const next_bits = d.next_bits;
- if (remaining_bits >= n) {
- d.next_bits = next_bits >> n;
- d.remaining_bits = remaining_bits - n;
- } else {
- const in = d.input;
- const next_int = in.takeInt(Bits, .little) catch |err| switch (err) {
- error.ReadFailed => return error.ReadFailed,
- error.EndOfStream => return tossBitsEnding(d, n),
- };
- const needed_bits = n - remaining_bits;
- d.next_bits = next_int >> needed_bits;
- d.remaining_bits = @intCast(@bitSizeOf(Bits) - @as(usize, needed_bits));
- }
+fn takeBits(d: *Decompress, n: u4) !u16 {
+ const bits = try d.peekBits(n);
+ d.tossBits(n);
+ return bits;
}
-fn tossBitsEnding(d: *Decompress, n: u4) !void {
- const remaining_bits = d.remaining_bits;
- const in = d.input;
- const buffered_n = in.bufferedLen();
- if (buffered_n == 0) return error.EndOfStream;
- assert(buffered_n < @sizeOf(Bits));
- const needed_bits = n - remaining_bits;
- const next_int = in.takeVarInt(Bits, .little, buffered_n) catch |err| switch (err) {
- error.ReadFailed => return error.ReadFailed,
- error.EndOfStream => unreachable,
+fn alignBitsForward(d: *Decompress) void {
+ d.input.toss(@intFromBool(d.consumed_bits != 0));
+ d.consumed_bits = 0;
+}
+
+fn peekBitsShort(d: *Decompress, n: u4) !u16 {
+ const bits = d.input.peekInt(u32, .little) catch |e| return switch (e) {
+ error.ReadFailed => error.ReadFailed,
+ error.EndOfStream => d.peekBitsShortEnding(n),
};
- d.next_bits = next_int >> needed_bits;
- d.remaining_bits = @intCast(@as(usize, buffered_n) * 8 -| @as(usize, needed_bits));
+ const mask = @shlExact(@as(u16, 1), n) - 1;
+ return @intCast((bits >> d.consumed_bits) & mask);
}
-fn takeBitsRuntime(d: *Decompress, n: u4) !u16 {
- const x = try peekBits(d, u16);
- const mask: u16 = (@as(u16, 1) << n) - 1;
- const u: u16 = @as(u16, @truncate(x)) & mask;
- try tossBits(d, n);
- return u;
+fn peekBitsShortEnding(d: *Decompress, n: u4) !u16 {
+ @branchHint(.unlikely);
+
+ const left = d.input.buffered();
+ const bits = std.mem.readVarInt(u32, left, .little);
+ const mask = @shlExact(@as(u16, 1), n) - 1;
+ return @intCast((bits >> d.consumed_bits) & mask);
}
-fn alignBitsDiscarding(d: *Decompress) void {
- const remaining_bits = d.remaining_bits;
- if (remaining_bits == 0) return;
- const n_bytes = remaining_bits / 8;
- const in = d.input;
- in.seek -= n_bytes;
- d.remaining_bits = 0;
- d.next_bits = 0;
+fn tossBitsShort(d: *Decompress, n: u4) !void {
+ if (d.input.bufferedLen() * 8 + d.consumed_bits < n) return error.EndOfStream;
+ d.tossBits(n);
}
-fn alignBitsPreserving(d: *Decompress) void {
- const remaining_bits: usize = d.remaining_bits;
- if (remaining_bits == 0) return;
- const n_bytes = (remaining_bits + 7) / 8;
- const in = d.input;
- in.seek -= n_bytes;
- d.remaining_bits = 0;
- d.next_bits = 0;
+fn takeIntBits(d: *Decompress, T: type) !T {
+ return @intCast(try d.takeBits(@bitSizeOf(T)));
+}
+
+fn peekIntBitsShort(d: *Decompress, T: type) !T {
+ return @intCast(try d.peekBitsShort(@bitSizeOf(T)));
}
/// Reads first 7 bits, and then maybe 1 or 2 more to get full 7,8 or 9 bit code.
@@ -646,12 +594,12 @@ fn alignBitsPreserving(d: *Decompress) void {
/// 280 - 287 8 11000000 through
/// 11000111
fn readFixedCode(d: *Decompress) !u16 {
- const code7 = @bitReverse(try d.takeBits(u7));
+ const code7 = @bitReverse(try d.takeIntBits(u7));
return switch (code7) {
0...0b0010_111 => @as(u16, code7) + 256,
- 0b0010_111 + 1...0b1011_111 => (@as(u16, code7) << 1) + @as(u16, try d.takeBits(u1)) - 0b0011_0000,
- 0b1011_111 + 1...0b1100_011 => (@as(u16, code7 - 0b1100000) << 1) + try d.takeBits(u1) + 280,
- else => (@as(u16, code7 - 0b1100_100) << 2) + @as(u16, @bitReverse(try d.takeBits(u2))) + 144,
+ 0b0010_111 + 1...0b1011_111 => (@as(u16, code7) << 1) + @as(u16, try d.takeIntBits(u1)) - 0b0011_0000,
+ 0b1011_111 + 1...0b1100_011 => (@as(u16, code7 - 0b1100000) << 1) + try d.takeIntBits(u1) + 280,
+ else => (@as(u16, code7 - 0b1100_100) << 2) + @as(u16, @bitReverse(try d.takeIntBits(u2))) + 144,
};
}
@@ -807,7 +755,7 @@ fn HuffmanDecoder(
return self.findLinked(code, sym.next);
}
- inline fn findLinked(self: *Self, code: u16, start: u16) !Symbol {
+ fn findLinked(self: *Self, code: u16, start: u16) !Symbol {
var pos = start;
while (pos > 0) {
const sym = self.symbols[pos];
@@ -898,57 +846,30 @@ test "init/find" {
}
test "encode/decode literals" {
- var codes: [flate.HuffmanEncoder.max_num_frequencies]flate.HuffmanEncoder.Code = undefined;
- for (1..286) |j| { // for all different number of codes
- var enc: flate.HuffmanEncoder = .{
- .codes = &codes,
- .freq_cache = undefined,
- .bit_count = undefined,
- .lns = undefined,
- .lfs = undefined,
- };
- // create frequencies
- var freq = [_]u16{0} ** 286;
- freq[256] = 1; // ensure we have end of block code
- for (&freq, 1..) |*f, i| {
- if (i % j == 0)
- f.* = @intCast(i);
- }
-
- // encoder from frequencies
- enc.generate(&freq, 15);
-
- // get code_lens from encoder
- var code_lens = [_]u4{0} ** 286;
- for (code_lens, 0..) |_, i| {
- code_lens[i] = @intCast(enc.codes[i].len);
- }
- // generate decoder from code lens
- var dec: LiteralDecoder = .{};
- try dec.generate(&code_lens);
-
- // expect decoder code to match original encoder code
- for (dec.symbols) |s| {
- if (s.code_bits == 0) continue;
- const c_code: u16 = @bitReverse(@as(u15, @intCast(s.code)));
- const symbol: u16 = switch (s.kind) {
- .literal => s.symbol,
- .end_of_block => 256,
- .match => @as(u16, s.symbol) + 257,
- };
-
- const c = enc.codes[symbol];
- try testing.expect(c.code == c_code);
- }
-
- // find each symbol by code
- for (enc.codes) |c| {
- if (c.len == 0) continue;
-
- const s_code: u15 = @bitReverse(@as(u15, @intCast(c.code)));
- const s = try dec.find(s_code);
- try testing.expect(s.code == s_code);
- try testing.expect(s.code_bits == c.len);
+ // Check that the example in RFC 1951 section 3.2.2 works (plus some zeroes)
+ const max_bits = 5;
+ var decoder: HuffmanDecoder(16, max_bits, 3) = .{};
+ try decoder.generate(&.{ 3, 3, 3, 3, 0, 0, 3, 2, 4, 4 });
+
+ inline for (0.., .{
+ @as(u3, 0b010),
+ @as(u3, 0b011),
+ @as(u3, 0b100),
+ @as(u3, 0b101),
+ @as(u0, 0),
+ @as(u0, 0),
+ @as(u3, 0b110),
+ @as(u2, 0b00),
+ @as(u4, 0b1110),
+ @as(u4, 0b1111),
+ }) |i, code| {
+ const bits = @bitSizeOf(@TypeOf(code));
+ if (bits == 0) continue;
+ for (0..1 << (max_bits - bits)) |extra| {
+ const full = (@as(u16, code) << (max_bits - bits)) | @as(u16, @intCast(extra));
+ const symbol = try decoder.find(full);
+ try testing.expectEqual(i, symbol.symbol);
+ try testing.expectEqual(bits, symbol.code_bits);
}
}
}