diff options
| author | dweiller <4678790+dweiller@users.noreplay.github.com> | 2023-02-12 04:33:20 +1100 |
|---|---|---|
| committer | dweiller <4678790+dweiller@users.noreplay.github.com> | 2023-02-20 09:09:06 +1100 |
| commit | 373d8ef26edca9d16111ae41f960a44ead6ea2c8 (patch) | |
| tree | 32de189f254ee9b30ef799d48a6e4be1b043e180 /lib/std/compress | |
| parent | 1530e73648cd9687bbaea3e50da9b2e86d66df0c (diff) | |
| download | zig-373d8ef26edca9d16111ae41f960a44ead6ea2c8.tar.gz zig-373d8ef26edca9d16111ae41f960a44ead6ea2c8.zip | |
std.compress.zstandard: check FSE bitstreams are fully consumed
Diffstat (limited to 'lib/std/compress')
| -rw-r--r-- | lib/std/compress/zstandard/decode/block.zig | 36 | ||||
| -rw-r--r-- | lib/std/compress/zstandard/decode/huffman.zig | 4 | ||||
| -rw-r--r-- | lib/std/compress/zstandard/readers.zig | 8 |
3 files changed, 32 insertions, 16 deletions
diff --git a/lib/std/compress/zstandard/decode/block.zig b/lib/std/compress/zstandard/decode/block.zig index 8563f41614..7cc4c146ca 100644 --- a/lib/std/compress/zstandard/decode/block.zig +++ b/lib/std/compress/zstandard/decode/block.zig @@ -391,15 +391,21 @@ pub const DecodeState = struct { try self.literal_stream_reader.init(bytes); } + fn isLiteralStreamEmpty(self: *DecodeState) bool { + switch (self.literal_streams) { + .one => return self.literal_stream_reader.isEmpty(), + .four => return self.literal_stream_index == 3 and self.literal_stream_reader.isEmpty(), + } + } + const LiteralBitsError = error{ BitStreamHasNoStartBit, UnexpectedEndOfLiteralStream, }; fn readLiteralsBits( self: *DecodeState, - comptime T: type, bit_count_to_read: usize, - ) LiteralBitsError!T { + ) LiteralBitsError!u16 { return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: { if (self.literal_streams == .four and self.literal_stream_index < 3) { try self.nextLiteralMultiStream(); @@ -461,7 +467,7 @@ pub const DecodeState = struct { while (i < len) : (i += 1) { var prefix: u16 = 0; while (true) { - const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| { + const new_bits = self.readLiteralsBits(bit_count_to_read) catch |err| { return err; }; prefix <<= bit_count_to_read; @@ -533,7 +539,7 @@ pub const DecodeState = struct { while (i < len) : (i += 1) { var prefix: u16 = 0; while (true) { - const new_bits = try self.readLiteralsBits(u16, bit_count_to_read); + const new_bits = try self.readLiteralsBits(bit_count_to_read); prefix <<= bit_count_to_read; prefix |= new_bits; bits_read += bit_count_to_read; @@ -659,13 +665,10 @@ pub fn decodeBlock( sequence_size_limit -= decompressed_size; } - if (bit_stream.bit_reader.bit_count != 0) { + if (!bit_stream.isEmpty()) { return error.MalformedCompressedBlock; } - - bytes_read += bit_stream_bytes.len; } - if (bytes_read != block_size) return error.MalformedCompressedBlock; if (decode_state.literal_written_count < literals.header.regenerated_size) { const len = literals.header.regenerated_size - decode_state.literal_written_count; @@ -675,7 +678,9 @@ pub fn decodeBlock( bytes_written += len; } - consumed_count.* += bytes_read; + if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock; + + consumed_count.* += block_size; return bytes_written; }, .reserved => return error.ReservedBlock, @@ -749,13 +754,10 @@ pub fn decodeBlockRingBuffer( sequence_size_limit -= decompressed_size; } - if (bit_stream.bit_reader.bit_count != 0) { + if (!bit_stream.isEmpty()) { return error.MalformedCompressedBlock; } - - bytes_read += bit_stream_bytes.len; } - if (bytes_read != block_size) return error.MalformedCompressedBlock; if (decode_state.literal_written_count < literals.header.regenerated_size) { const len = literals.header.regenerated_size - decode_state.literal_written_count; @@ -764,7 +766,9 @@ pub fn decodeBlockRingBuffer( bytes_written += len; } - consumed_count.* += bytes_read; + if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock; + + consumed_count.* += block_size; if (bytes_written > block_size_max) return error.BlockSizeOverMaximum; return bytes_written; }, @@ -837,7 +841,7 @@ pub fn decodeBlockReader( sequence_size_limit -= decompressed_size; bytes_written += decompressed_size; } - if (bit_stream.bit_reader.bit_count != 0) { + if (!bit_stream.isEmpty()) { return error.MalformedCompressedBlock; } } @@ -849,6 +853,8 @@ pub fn decodeBlockReader( bytes_written += len; } + if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock; + if (bytes_written > block_size_max) return error.BlockSizeOverMaximum; if (block_reader_limited.bytes_left != 0) return error.MalformedCompressedBlock; decode_state.literal_written_count = 0; diff --git a/lib/std/compress/zstandard/decode/huffman.zig b/lib/std/compress/zstandard/decode/huffman.zig index f5639e7721..c3bda380dd 100644 --- a/lib/std/compress/zstandard/decode/huffman.zig +++ b/lib/std/compress/zstandard/decode/huffman.zig @@ -86,6 +86,10 @@ fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entr odd_state = odd_data.baseline + odd_bits; } else return error.MalformedHuffmanTree; + if (!huff_bits.isEmpty()) { + return error.MalformedHuffmanTree; + } + return i + 1; // stream contains all but the last symbol } diff --git a/lib/std/compress/zstandard/readers.zig b/lib/std/compress/zstandard/readers.zig index 489f933310..98cac2ed80 100644 --- a/lib/std/compress/zstandard/readers.zig +++ b/lib/std/compress/zstandard/readers.zig @@ -36,7 +36,9 @@ pub const ReverseBitReader = struct { pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void { self.byte_reader = ReversedByteReader.init(bytes); self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader()); - while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {} + var i: usize = 0; + while (i < 8 and 0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) : (i += 1) {} + if (i == 8) return error.BitStreamHasNoStartBit; } pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U { @@ -50,6 +52,10 @@ pub const ReverseBitReader = struct { pub fn alignToByte(self: *@This()) void { self.bit_reader.alignToByte(); } + + pub fn isEmpty(self: ReverseBitReader) bool { + return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0; + } }; pub fn BitReader(comptime Reader: type) type { |
