| author | dweiller <4678790+dweiller@users.noreply.github.com> | 2023-01-22 16:11:47 +1100 |
|---|---|---|
| committer | dweiller <4678790+dweiller@users.noreply.github.com> | 2023-02-20 09:09:05 +1100 |
| commit | 05e63f241edb2199e91ce29c488e104dfb826935 (patch) | |
| tree | 9784750af821ddcc80364e841b58dbaf26136d4d /lib/std | |
| parent | 18091723d5afa8001e0fd71274dc4b74d601d0e1 (diff) | |
| download | zig-05e63f241edb2199e91ce29c488e104dfb826935.tar.gz zig-05e63f241edb2199e91ce29c488e104dfb826935.zip | |
std.compress.zstandard: add functions decoding into ring buffer
This supports decoding frames that do not declare their content size, as
well as decoding in a streaming fashion.
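
The new `RingBuffer` (full file in the diff below) distinguishes a full buffer from an empty one without a separate flag by letting both indices run modulo twice the backing slice's length. A minimal sketch of that invariant, using only the API from the new file; the 4-byte buffer and test values are illustrative, not from the commit:

```zig
const std = @import("std");
const RingBuffer = @import("RingBuffer.zig");

test "full and empty are distinguishable without an extra flag" {
    var storage: [4]u8 = undefined;
    var rb = RingBuffer{ .data = &storage, .read_index = 0, .write_index = 0 };
    try std.testing.expect(rb.isEmpty());

    // Four writes advance write_index to 4 modulo 2 * 4 = 8: both indices
    // now mask to slot 0, yet their difference still records a full buffer.
    try rb.writeSlice(&[_]u8{ 1, 2, 3, 4 });
    try std.testing.expect(rb.isFull());
    try std.testing.expectEqual(@as(usize, 4), rb.len());

    // Draining advances read_index to 4 as well; equal indices mean empty.
    while (rb.read()) |_| {}
    try std.testing.expect(rb.isEmpty());
}
```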
Diffstat (limited to 'lib/std')
| -rw-r--r-- | lib/std/compress/zstandard/RingBuffer.zig | 81 |
| -rw-r--r-- | lib/std/compress/zstandard/decompress.zig | 194 |
2 files changed, 272 insertions, 3 deletions
diff --git a/lib/std/compress/zstandard/RingBuffer.zig b/lib/std/compress/zstandard/RingBuffer.zig
new file mode 100644
index 0000000000..1a369596cb
--- /dev/null
+++ b/lib/std/compress/zstandard/RingBuffer.zig
@@ -0,0 +1,81 @@
+//! This ring buffer stores read and write indices while being able to utilise the full
+//! backing slice by incrementing the indices modulo twice the slice's length and reducing
+//! indices modulo the slice's length on slice access. This means that the bit of information
+//! distinguishing whether the buffer is full or empty in an implementation utilising
+//! an extra flag is stored in the difference of the indices.
+
+const assert = @import("std").debug.assert;
+
+const RingBuffer = @This();
+
+data: []u8,
+read_index: usize,
+write_index: usize,
+
+pub fn mask(self: RingBuffer, index: usize) usize {
+    return index % self.data.len;
+}
+
+pub fn mask2(self: RingBuffer, index: usize) usize {
+    return index % (2 * self.data.len);
+}
+
+pub fn write(self: *RingBuffer, byte: u8) !void {
+    if (self.isFull()) return error.Full;
+    self.writeAssumeCapacity(byte);
+}
+
+pub fn writeAssumeCapacity(self: *RingBuffer, byte: u8) void {
+    self.data[self.mask(self.write_index)] = byte;
+    self.write_index = self.mask2(self.write_index + 1);
+}
+
+pub fn writeSlice(self: *RingBuffer, bytes: []const u8) !void {
+    if (self.len() + bytes.len > self.data.len) return error.Full;
+    self.writeSliceAssumeCapacity(bytes);
+}
+
+pub fn writeSliceAssumeCapacity(self: *RingBuffer, bytes: []const u8) void {
+    for (bytes) |b| self.writeAssumeCapacity(b);
+}
+
+pub fn read(self: *RingBuffer) ?u8 {
+    if (self.isEmpty()) return null;
+    const byte = self.data[self.mask(self.read_index)];
+    self.read_index = self.mask2(self.read_index + 1);
+    return byte;
+}
+
+pub fn isEmpty(self: RingBuffer) bool {
+    return self.write_index == self.read_index;
+}
+
+pub fn isFull(self: RingBuffer) bool {
+    return self.mask2(self.write_index + self.data.len) == self.read_index;
+}
+
+pub fn len(self: RingBuffer) usize {
+    const adjusted_write_index = self.write_index + @boolToInt(self.write_index < self.read_index) * 2 * self.data.len;
+    return adjusted_write_index - self.read_index;
+}
+
+const Slice = struct {
+    first: []u8,
+    second: []u8,
+};
+
+pub fn sliceAt(self: RingBuffer, start_unmasked: usize, length: usize) Slice {
+    assert(length <= self.data.len);
+    const slice1_start = self.mask(start_unmasked);
+    const slice1_end = @min(self.data.len, slice1_start + length);
+    const slice1 = self.data[slice1_start..slice1_end];
+    const slice2 = self.data[0 .. length - slice1.len];
+    return Slice{
+        .first = slice1,
+        .second = slice2,
+    };
+}
+
+pub fn sliceLast(self: RingBuffer, length: usize) Slice {
+    return self.sliceAt(self.write_index + self.data.len - length, length);
+}
diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig
index 9483b4d9d7..6e107c2d7b 100644
--- a/lib/std/compress/zstandard/decompress.zig
+++ b/lib/std/compress/zstandard/decompress.zig
@@ -6,6 +6,7 @@ const frame = types.frame;
 const Literals = types.compressed_block.Literals;
 const Sequences = types.compressed_block.Sequences;
 const Table = types.compressed_block.Table;
+const RingBuffer = @import("RingBuffer.zig");
 
 const readInt = std.mem.readIntLittle;
 const readIntSlice = std.mem.readIntSliceLittle;
@@ -214,7 +215,7 @@ const DecodeState = struct {
     }
 
     fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void {
-        try self.decodeLiteralsInto(dest[write_pos..], literals, sequence.literal_length);
+        try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
 
         // TODO: should we validate offset against max_window_size?
         assert(sequence.offset <= write_pos + sequence.literal_length);
@@ -225,6 +226,15 @@ const DecodeState = struct {
         std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
     }
 
+    fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void {
+        try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
+        // TODO: check that ring buffer window is full enough for match copies
+        const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
+        // TODO: would std.mem.copy and figuring out dest slice be better/faster?
+        for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
+        for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
+    }
+
     fn decodeSequenceSlice(
         self: *DecodeState,
         dest: []u8,
@@ -246,6 +256,31 @@ const DecodeState = struct {
         return sequence.match_length + sequence.literal_length;
     }
 
+    fn decodeSequenceRingBuffer(
+        self: *DecodeState,
+        dest: *RingBuffer,
+        literals: Literals,
+        bit_reader: anytype,
+        last_sequence: bool,
+    ) !usize {
+        const sequence = try self.nextSequence(bit_reader);
+        try self.executeSequenceRingBuffer(dest, literals, sequence);
+        if (std.options.log_level == .debug) {
+            const sequence_length = sequence.literal_length + sequence.match_length;
+            const written_slice = dest.sliceLast(sequence_length);
+            log.debug("sequence decompressed into '{x}{x}'", .{
+                std.fmt.fmtSliceHexUpper(written_slice.first),
+                std.fmt.fmtSliceHexUpper(written_slice.second),
+            });
+        }
+        if (!last_sequence) {
+            try self.updateState(.literal, bit_reader);
+            try self.updateState(.match, bit_reader);
+            try self.updateState(.offset, bit_reader);
+        }
+        return sequence.match_length + sequence.literal_length;
+    }
+
     fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void {
         self.literal_stream_index += 1;
         try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
@@ -258,7 +293,7 @@ const DecodeState = struct {
         while (0 == try self.literal_stream_reader.readBitsNoEof(u1, 1)) {}
     }
 
-    fn decodeLiteralsInto(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
+    fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
         if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
         switch (literals.header.block_type) {
             .raw => {
@@ -327,6 +362,74 @@ const DecodeState = struct {
         }
     }
 
+    fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void {
+        if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+        switch (literals.header.block_type) {
+            .raw => {
+                const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+                dest.writeSliceAssumeCapacity(literal_data);
+                self.literal_written_count += len;
+            },
+            .rle => {
+                var i: usize = 0;
+                while (i < len) : (i += 1) {
+                    dest.writeAssumeCapacity(literals.streams.one[0]);
+                }
+                self.literal_written_count += len;
+            },
+            .compressed, .treeless => {
+                // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
+                const huffman_tree = self.huffman_tree orelse unreachable;
+                const max_bit_count = huffman_tree.max_bit_count;
+                const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
+                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
+                    max_bit_count,
+                );
+                var bits_read: u4 = 0;
+                var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
+                var bit_count_to_read: u4 = starting_bit_count;
+                var i: usize = 0;
+                while (i < len) : (i += 1) {
+                    var prefix: u16 = 0;
+                    while (true) {
+                        const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
+                            switch (err) {
+                            error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
+                                try self.nextLiteralMultiStream(literals);
+                                break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
+                            } else {
+                                return error.UnexpectedEndOfLiteralStream;
+                            },
+                        };
+                        prefix <<= bit_count_to_read;
+                        prefix |= new_bits;
+                        bits_read += bit_count_to_read;
+                        const result = try huffman_tree.query(huffman_tree_index, prefix);
+
+                        switch (result) {
+                            .symbol => |sym| {
+                                dest.writeAssumeCapacity(sym);
+                                bit_count_to_read = starting_bit_count;
+                                bits_read = 0;
+                                huffman_tree_index = huffman_tree.symbol_count_minus_one;
+                                break;
+                            },
+                            .index => |index| {
+                                huffman_tree_index = index;
+                                const bit_count = Literals.HuffmanTree.weightToBitCount(
+                                    huffman_tree.nodes[index].weight,
+                                    max_bit_count,
+                                );
+                                bit_count_to_read = bit_count - bits_read;
+                            },
+                        }
+                    }
+                }
+                self.literal_written_count += len;
+            },
+        }
+    }
+
     fn getCode(self: *DecodeState, comptime choice: DataType) u32 {
         return switch (@field(self, @tagName(choice)).table) {
             .rle => |value| value,
@@ -437,6 +540,14 @@ fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
     return block_size;
 }
 
+fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+    log.debug("writing raw block - size {d}", .{block_size});
+    const data = src[0..block_size];
+    dest.writeSliceAssumeCapacity(data);
+    consumed_count.* += block_size;
+    return block_size;
+}
+
 fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
     log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
     var write_pos: usize = 0;
@@ -447,6 +558,16 @@ fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
     return block_size;
 }
 
+fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+    log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
+    var write_pos: usize = 0;
+    while (write_pos < block_size) : (write_pos += 1) {
+        dest.writeAssumeCapacity(src[0]);
+    }
+    consumed_count.* += 1;
+    return block_size;
+}
+
 fn prepareDecodeState(
     decode_state: *DecodeState,
     src: []const u8,
@@ -545,7 +666,7 @@ pub fn decodeBlock(
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 log.debug("decoding remaining literals", .{});
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                try decode_state.decodeLiteralsInto(dest[written_count + bytes_written ..], literals, len);
+                try decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len);
                 log.debug("remaining decoded literals at {d}: {}", .{
                     written_count,
                     std.fmt.fmtSliceHexUpper(dest[written_count .. written_count + len]),
@@ -562,6 +683,73 @@ pub fn decodeBlock(
     }
 }
 
+pub fn decodeBlockRingBuffer(
+    dest: *RingBuffer,
+    src: []const u8,
+    block_header: frame.ZStandard.Block.Header,
+    decode_state: *DecodeState,
+    consumed_count: *usize,
+    block_size_maximum: usize,
+) !usize {
+    const block_size = block_header.block_size;
+    if (block_size_maximum < block_size) return error.BlockSizeOverMaximum;
+    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
+    switch (block_header.block_type) {
+        .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
+        .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count),
+        .compressed => {
+            var bytes_read: usize = 0;
+            const literals = try decodeLiteralsSection(src, &bytes_read);
+            const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+
+            bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
+
+            var bytes_written: usize = 0;
+            if (sequences_header.sequence_count > 0) {
+                const bit_stream_bytes = src[bytes_read..block_size];
+                var reverse_byte_reader = reversedByteReader(bit_stream_bytes);
+                var bit_stream = reverseBitReader(reverse_byte_reader.reader());
+
+                while (0 == try bit_stream.readBitsNoEof(u1, 1)) {}
+                try decode_state.readInitialState(&bit_stream);
+
+                var i: usize = 0;
+                while (i < sequences_header.sequence_count) : (i += 1) {
+                    log.debug("decoding sequence {d}", .{i});
+                    const decompressed_size = try decode_state.decodeSequenceRingBuffer(
+                        dest,
+                        literals,
+                        &bit_stream,
+                        i == sequences_header.sequence_count - 1,
+                    );
+                    bytes_written += decompressed_size;
+                }
+
+                bytes_read += bit_stream_bytes.len;
+            }
+
+            if (decode_state.literal_written_count < literals.header.regenerated_size) {
+                log.debug("decoding remaining literals", .{});
+                const len = literals.header.regenerated_size - decode_state.literal_written_count;
+                try decode_state.decodeLiteralsRingBuffer(dest, literals, len);
+                const written_slice = dest.sliceLast(len);
+                log.debug("remaining decoded literals at {d}: {}{}", .{
+                    bytes_written,
+                    std.fmt.fmtSliceHexUpper(written_slice.first),
+                    std.fmt.fmtSliceHexUpper(written_slice.second),
+                });
+                bytes_written += len;
+            }
+
+            decode_state.literal_written_count = 0;
+            assert(bytes_read == block_header.block_size);
+            consumed_count.* += bytes_read;
+            return bytes_written;
+        },
+        .reserved => return error.FrameContainsReservedBlock,
+    }
+}
+
 pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
     const magic = readInt(u32, src[0..4]);
     assert(isSkippableMagic(magic));
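
As a usage note, the match-copy pattern in `executeSequenceRingBuffer` above (locating the match with `sliceAt` relative to `write_index`, then replaying it byte by byte) can be exercised directly against the new `RingBuffer`. This is a hypothetical sketch, not code from the commit; the window size, literal bytes, and offset are invented for illustration:

```zig
const std = @import("std");
const RingBuffer = @import("RingBuffer.zig");

test "copy a match from earlier in the decoded window" {
    var storage: [8]u8 = undefined; // stand-in for a window-sized buffer
    var rb = RingBuffer{ .data = &storage, .read_index = 0, .write_index = 0 };

    // "Decode" four literal bytes.
    try rb.writeSlice("abcd");

    // Copy a match of length 3 at offset 4, located the same way
    // executeSequenceRingBuffer finds it via sliceAt.
    const match = rb.sliceAt(rb.write_index + rb.data.len - 4, 3);
    for (match.first) |b| rb.writeAssumeCapacity(b);
    for (match.second) |b| rb.writeAssumeCapacity(b);

    // Drain the buffer: the output decoded so far is "abcdabc".
    var out: [7]u8 = undefined;
    var i: usize = 0;
    while (rb.read()) |byte| : (i += 1) out[i] = byte;
    try std.testing.expectEqualStrings("abcdabc", out[0..i]);
}
```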
