diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2021-03-28 21:42:56 -0700 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2021-03-28 21:42:56 -0700 |
| commit | b85ef2300fa72f5f4c73b8eb9e14f0218ada592d (patch) | |
| tree | daee8ab81eaefb5433f6ba3750656ba769a311a4 /lib/std | |
| parent | 75080e351af8be45722bca50c1d5fcd503304d77 (diff) | |
| parent | 175adc0bd738c2e3a55bb71c6a53dcc920c203ba (diff) | |
| download | zig-b85ef2300fa72f5f4c73b8eb9e14f0218ada592d.tar.gz zig-b85ef2300fa72f5f4c73b8eb9e14f0218ada592d.zip | |
Merge remote-tracking branch 'origin/master' into llvm12
Diffstat (limited to 'lib/std')
45 files changed, 3169 insertions, 1337 deletions
diff --git a/lib/std/array_hash_map.zig b/lib/std/array_hash_map.zig index 7b0d9ea4dd..83a061dfef 100644 --- a/lib/std/array_hash_map.zig +++ b/lib/std/array_hash_map.zig @@ -687,8 +687,9 @@ pub fn ArrayHashMapUnmanaged( /// Removes the last inserted `Entry` in the hash map and returns it. pub fn pop(self: *Self) Entry { - const top = self.entries.pop(); + const top = self.entries.items[self.entries.items.len - 1]; _ = self.removeWithHash(top.key, top.hash, .index_only); + self.entries.items.len -= 1; return top; } @@ -1258,19 +1259,18 @@ test "pop" { var map = AutoArrayHashMap(i32, i32).init(std.testing.allocator); defer map.deinit(); - testing.expect((try map.fetchPut(1, 11)) == null); - testing.expect((try map.fetchPut(2, 22)) == null); - testing.expect((try map.fetchPut(3, 33)) == null); - testing.expect((try map.fetchPut(4, 44)) == null); + // Insert just enough entries so that the map expands. Afterwards, + // pop all entries out of the map. - const pop1 = map.pop(); - testing.expect(pop1.key == 4 and pop1.value == 44); - const pop2 = map.pop(); - testing.expect(pop2.key == 3 and pop2.value == 33); - const pop3 = map.pop(); - testing.expect(pop3.key == 2 and pop3.value == 22); - const pop4 = map.pop(); - testing.expect(pop4.key == 1 and pop4.value == 11); + var i: i32 = 0; + while (i < 9) : (i += 1) { + testing.expect((try map.fetchPut(i, i)) == null); + } + + while (i > 0) : (i -= 1) { + const pop = map.pop(); + testing.expect(pop.key == i - 1 and pop.value == i - 1); + } } test "reIndex" { diff --git a/lib/std/base64.zig b/lib/std/base64.zig index e6a780c239..4e7c9a696f 100644 --- a/lib/std/base64.zig +++ b/lib/std/base64.zig @@ -8,454 +8,452 @@ const assert = std.debug.assert; const testing = std.testing; const mem = std.mem; -pub const standard_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -pub const standard_pad_char = '='; -pub const standard_encoder = Base64Encoder.init(standard_alphabet_chars, standard_pad_char); +pub const Error = error{ + InvalidCharacter, + InvalidPadding, + NoSpaceLeft, +}; + +/// Base64 codecs +pub const Codecs = struct { + alphabet_chars: [64]u8, + pad_char: ?u8, + decoderWithIgnore: fn (ignore: []const u8) Base64DecoderWithIgnore, + Encoder: Base64Encoder, + Decoder: Base64Decoder, +}; + +pub const standard_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/".*; +fn standardBase64DecoderWithIgnore(ignore: []const u8) Base64DecoderWithIgnore { + return Base64DecoderWithIgnore.init(standard_alphabet_chars, '=', ignore); +} + +/// Standard Base64 codecs, with padding +pub const standard = Codecs{ + .alphabet_chars = standard_alphabet_chars, + .pad_char = '=', + .decoderWithIgnore = standardBase64DecoderWithIgnore, + .Encoder = Base64Encoder.init(standard_alphabet_chars, '='), + .Decoder = Base64Decoder.init(standard_alphabet_chars, '='), +}; + +/// Standard Base64 codecs, without padding +pub const standard_no_pad = Codecs{ + .alphabet_chars = standard_alphabet_chars, + .pad_char = null, + .decoderWithIgnore = standardBase64DecoderWithIgnore, + .Encoder = Base64Encoder.init(standard_alphabet_chars, null), + .Decoder = Base64Decoder.init(standard_alphabet_chars, null), +}; + +pub const url_safe_alphabet_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_".*; +fn urlSafeBase64DecoderWithIgnore(ignore: []const u8) Base64DecoderWithIgnore { + return Base64DecoderWithIgnore.init(url_safe_alphabet_chars, null, ignore); +} + +/// URL-safe Base64 codecs, with padding +pub const url_safe = Codecs{ + .alphabet_chars = url_safe_alphabet_chars, + .pad_char = '=', + .decoderWithIgnore = urlSafeBase64DecoderWithIgnore, + .Encoder = Base64Encoder.init(url_safe_alphabet_chars, '='), + .Decoder = Base64Decoder.init(url_safe_alphabet_chars, '='), +}; + +/// URL-safe Base64 codecs, without padding +pub const url_safe_no_pad = Codecs{ + .alphabet_chars = url_safe_alphabet_chars, + .pad_char = null, + .decoderWithIgnore = urlSafeBase64DecoderWithIgnore, + .Encoder = Base64Encoder.init(url_safe_alphabet_chars, null), + .Decoder = Base64Decoder.init(url_safe_alphabet_chars, null), +}; + +// Backwards compatibility + +/// Deprecated - Use `standard.pad_char` +pub const standard_pad_char = standard.pad_char; +/// Deprecated - Use `standard.Encoder` +pub const standard_encoder = standard.Encoder; +/// Deprecated - Use `standard.Decoder` +pub const standard_decoder = standard.Decoder; pub const Base64Encoder = struct { - alphabet_chars: []const u8, - pad_char: u8, + alphabet_chars: [64]u8, + pad_char: ?u8, - /// a bunch of assertions, then simply pass the data right through. - pub fn init(alphabet_chars: []const u8, pad_char: u8) Base64Encoder { + /// A bunch of assertions, then simply pass the data right through. + pub fn init(alphabet_chars: [64]u8, pad_char: ?u8) Base64Encoder { assert(alphabet_chars.len == 64); var char_in_alphabet = [_]bool{false} ** 256; for (alphabet_chars) |c| { assert(!char_in_alphabet[c]); - assert(c != pad_char); + assert(pad_char == null or c != pad_char.?); char_in_alphabet[c] = true; } - return Base64Encoder{ .alphabet_chars = alphabet_chars, .pad_char = pad_char, }; } - /// ceil(source_len * 4/3) - pub fn calcSize(source_len: usize) usize { - return @divTrunc(source_len + 2, 3) * 4; + /// Compute the encoded length + pub fn calcSize(encoder: *const Base64Encoder, source_len: usize) usize { + if (encoder.pad_char != null) { + return @divTrunc(source_len + 2, 3) * 4; + } else { + const leftover = source_len % 3; + return @divTrunc(source_len, 3) * 4 + @divTrunc(leftover * 4 + 2, 3); + } } - /// dest.len must be what you get from ::calcSize. + /// dest.len must at least be what you get from ::calcSize. pub fn encode(encoder: *const Base64Encoder, dest: []u8, source: []const u8) []const u8 { - assert(dest.len >= Base64Encoder.calcSize(source.len)); - - var i: usize = 0; - var out_index: usize = 0; - while (i + 2 < source.len) : (i += 3) { - dest[out_index] = encoder.alphabet_chars[(source[i] >> 2) & 0x3f]; - out_index += 1; - - dest[out_index] = encoder.alphabet_chars[((source[i] & 0x3) << 4) | ((source[i + 1] & 0xf0) >> 4)]; - out_index += 1; - - dest[out_index] = encoder.alphabet_chars[((source[i + 1] & 0xf) << 2) | ((source[i + 2] & 0xc0) >> 6)]; - out_index += 1; - - dest[out_index] = encoder.alphabet_chars[source[i + 2] & 0x3f]; - out_index += 1; + const out_len = encoder.calcSize(source.len); + assert(dest.len >= out_len); + + const nibbles = source.len / 3; + const leftover = source.len - 3 * nibbles; + + var acc: u12 = 0; + var acc_len: u4 = 0; + var out_idx: usize = 0; + for (source) |v| { + acc = (acc << 8) + v; + acc_len += 8; + while (acc_len >= 6) { + acc_len -= 6; + dest[out_idx] = encoder.alphabet_chars[@truncate(u6, (acc >> acc_len))]; + out_idx += 1; + } } - - if (i < source.len) { - dest[out_index] = encoder.alphabet_chars[(source[i] >> 2) & 0x3f]; - out_index += 1; - - if (i + 1 == source.len) { - dest[out_index] = encoder.alphabet_chars[(source[i] & 0x3) << 4]; - out_index += 1; - - dest[out_index] = encoder.pad_char; - out_index += 1; - } else { - dest[out_index] = encoder.alphabet_chars[((source[i] & 0x3) << 4) | ((source[i + 1] & 0xf0) >> 4)]; - out_index += 1; - - dest[out_index] = encoder.alphabet_chars[(source[i + 1] & 0xf) << 2]; - out_index += 1; + if (acc_len > 0) { + dest[out_idx] = encoder.alphabet_chars[@truncate(u6, (acc << 6 - acc_len))]; + out_idx += 1; + } + if (encoder.pad_char) |pad_char| { + for (dest[out_idx..]) |*pad| { + pad.* = pad_char; } - - dest[out_index] = encoder.pad_char; - out_index += 1; } - return dest[0..out_index]; + return dest[0..out_len]; } }; -pub const standard_decoder = Base64Decoder.init(standard_alphabet_chars, standard_pad_char); - pub const Base64Decoder = struct { + const invalid_char: u8 = 0xff; + /// e.g. 'A' => 0. - /// undefined for any value not in the 64 alphabet chars. + /// `invalid_char` for any value not in the 64 alphabet chars. char_to_index: [256]u8, + pad_char: ?u8, - /// true only for the 64 chars in the alphabet, not the pad char. - char_in_alphabet: [256]bool, - pad_char: u8, - - pub fn init(alphabet_chars: []const u8, pad_char: u8) Base64Decoder { - assert(alphabet_chars.len == 64); - + pub fn init(alphabet_chars: [64]u8, pad_char: ?u8) Base64Decoder { var result = Base64Decoder{ - .char_to_index = undefined, - .char_in_alphabet = [_]bool{false} ** 256, + .char_to_index = [_]u8{invalid_char} ** 256, .pad_char = pad_char, }; + var char_in_alphabet = [_]bool{false} ** 256; for (alphabet_chars) |c, i| { - assert(!result.char_in_alphabet[c]); - assert(c != pad_char); + assert(!char_in_alphabet[c]); + assert(pad_char == null or c != pad_char.?); result.char_to_index[c] = @intCast(u8, i); - result.char_in_alphabet[c] = true; + char_in_alphabet[c] = true; } + return result; + } + /// Return the maximum possible decoded size for a given input length - The actual length may be less if the input includes padding. + /// `InvalidPadding` is returned if the input length is not valid. + pub fn calcSizeUpperBound(decoder: *const Base64Decoder, source_len: usize) Error!usize { + var result = source_len / 4 * 3; + const leftover = source_len % 4; + if (decoder.pad_char != null) { + if (leftover % 4 != 0) return error.InvalidPadding; + } else { + if (leftover % 4 == 1) return error.InvalidPadding; + result += leftover * 3 / 4; + } return result; } - /// If the encoded buffer is detected to be invalid, returns error.InvalidPadding. - pub fn calcSize(decoder: *const Base64Decoder, source: []const u8) !usize { - if (source.len % 4 != 0) return error.InvalidPadding; - return calcDecodedSizeExactUnsafe(source, decoder.pad_char); + /// Return the exact decoded size for a slice. + /// `InvalidPadding` is returned if the input length is not valid. + pub fn calcSizeForSlice(decoder: *const Base64Decoder, source: []const u8) Error!usize { + const source_len = source.len; + var result = try decoder.calcSizeUpperBound(source_len); + if (decoder.pad_char) |pad_char| { + if (source_len >= 1 and source[source_len - 1] == pad_char) result -= 1; + if (source_len >= 2 and source[source_len - 2] == pad_char) result -= 1; + } + return result; } /// dest.len must be what you get from ::calcSize. /// invalid characters result in error.InvalidCharacter. /// invalid padding results in error.InvalidPadding. - pub fn decode(decoder: *const Base64Decoder, dest: []u8, source: []const u8) !void { - assert(dest.len == (decoder.calcSize(source) catch unreachable)); - assert(source.len % 4 == 0); - - var src_cursor: usize = 0; - var dest_cursor: usize = 0; - - while (src_cursor < source.len) : (src_cursor += 4) { - if (!decoder.char_in_alphabet[source[src_cursor + 0]]) return error.InvalidCharacter; - if (!decoder.char_in_alphabet[source[src_cursor + 1]]) return error.InvalidCharacter; - if (src_cursor < source.len - 4 or source[src_cursor + 3] != decoder.pad_char) { - // common case - if (!decoder.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter; - if (!decoder.char_in_alphabet[source[src_cursor + 3]]) return error.InvalidCharacter; - dest[dest_cursor + 0] = decoder.char_to_index[source[src_cursor + 0]] << 2 | decoder.char_to_index[source[src_cursor + 1]] >> 4; - dest[dest_cursor + 1] = decoder.char_to_index[source[src_cursor + 1]] << 4 | decoder.char_to_index[source[src_cursor + 2]] >> 2; - dest[dest_cursor + 2] = decoder.char_to_index[source[src_cursor + 2]] << 6 | decoder.char_to_index[source[src_cursor + 3]]; - dest_cursor += 3; - } else if (source[src_cursor + 2] != decoder.pad_char) { - // one pad char - if (!decoder.char_in_alphabet[source[src_cursor + 2]]) return error.InvalidCharacter; - dest[dest_cursor + 0] = decoder.char_to_index[source[src_cursor + 0]] << 2 | decoder.char_to_index[source[src_cursor + 1]] >> 4; - dest[dest_cursor + 1] = decoder.char_to_index[source[src_cursor + 1]] << 4 | decoder.char_to_index[source[src_cursor + 2]] >> 2; - if (decoder.char_to_index[source[src_cursor + 2]] << 6 != 0) return error.InvalidPadding; - dest_cursor += 2; - } else { - // two pad chars - dest[dest_cursor + 0] = decoder.char_to_index[source[src_cursor + 0]] << 2 | decoder.char_to_index[source[src_cursor + 1]] >> 4; - if (decoder.char_to_index[source[src_cursor + 1]] << 4 != 0) return error.InvalidPadding; - dest_cursor += 1; + pub fn decode(decoder: *const Base64Decoder, dest: []u8, source: []const u8) Error!void { + if (decoder.pad_char != null and source.len % 4 != 0) return error.InvalidPadding; + var acc: u12 = 0; + var acc_len: u4 = 0; + var dest_idx: usize = 0; + var leftover_idx: ?usize = null; + for (source) |c, src_idx| { + const d = decoder.char_to_index[c]; + if (d == invalid_char) { + if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter; + leftover_idx = src_idx; + break; + } + acc = (acc << 6) + d; + acc_len += 6; + if (acc_len >= 8) { + acc_len -= 8; + dest[dest_idx] = @truncate(u8, acc >> acc_len); + dest_idx += 1; } } - - assert(src_cursor == source.len); - assert(dest_cursor == dest.len); + if (acc_len > 4 or (acc & (@as(u12, 1) << acc_len) - 1) != 0) { + return error.InvalidPadding; + } + if (leftover_idx == null) return; + var leftover = source[leftover_idx.?..]; + if (decoder.pad_char) |pad_char| { + const padding_len = acc_len / 2; + var padding_chars: usize = 0; + var i: usize = 0; + for (leftover) |c| { + if (c != pad_char) { + return if (c == Base64Decoder.invalid_char) error.InvalidCharacter else error.InvalidPadding; + } + padding_chars += 1; + } + if (padding_chars != padding_len) return error.InvalidPadding; + } } }; pub const Base64DecoderWithIgnore = struct { decoder: Base64Decoder, char_is_ignored: [256]bool, - pub fn init(alphabet_chars: []const u8, pad_char: u8, ignore_chars: []const u8) Base64DecoderWithIgnore { + + pub fn init(alphabet_chars: [64]u8, pad_char: ?u8, ignore_chars: []const u8) Base64DecoderWithIgnore { var result = Base64DecoderWithIgnore{ .decoder = Base64Decoder.init(alphabet_chars, pad_char), .char_is_ignored = [_]bool{false} ** 256, }; - for (ignore_chars) |c| { - assert(!result.decoder.char_in_alphabet[c]); + assert(result.decoder.char_to_index[c] == Base64Decoder.invalid_char); assert(!result.char_is_ignored[c]); assert(result.decoder.pad_char != c); result.char_is_ignored[c] = true; } - return result; } - /// If no characters end up being ignored or padding, this will be the exact decoded size. - pub fn calcSizeUpperBound(encoded_len: usize) usize { - return @divTrunc(encoded_len, 4) * 3; + /// Return the maximum possible decoded size for a given input length - The actual length may be less if the input includes padding + /// `InvalidPadding` is returned if the input length is not valid. + pub fn calcSizeUpperBound(decoder_with_ignore: *const Base64DecoderWithIgnore, source_len: usize) Error!usize { + var result = source_len / 4 * 3; + if (decoder_with_ignore.decoder.pad_char == null) { + const leftover = source_len % 4; + result += leftover * 3 / 4; + } + return result; } /// Invalid characters that are not ignored result in error.InvalidCharacter. /// Invalid padding results in error.InvalidPadding. - /// Decoding more data than can fit in dest results in error.OutputTooSmall. See also ::calcSizeUpperBound. + /// Decoding more data than can fit in dest results in error.NoSpaceLeft. See also ::calcSizeUpperBound. /// Returns the number of bytes written to dest. - pub fn decode(decoder_with_ignore: *const Base64DecoderWithIgnore, dest: []u8, source: []const u8) !usize { + pub fn decode(decoder_with_ignore: *const Base64DecoderWithIgnore, dest: []u8, source: []const u8) Error!usize { const decoder = &decoder_with_ignore.decoder; - - var src_cursor: usize = 0; - var dest_cursor: usize = 0; - - while (true) { - // get the next 4 chars, if available - var next_4_chars: [4]u8 = undefined; - var available_chars: usize = 0; - var pad_char_count: usize = 0; - while (available_chars < 4 and src_cursor < source.len) { - var c = source[src_cursor]; - src_cursor += 1; - - if (decoder.char_in_alphabet[c]) { - // normal char - next_4_chars[available_chars] = c; - available_chars += 1; - } else if (decoder_with_ignore.char_is_ignored[c]) { - // we're told to skip this one - continue; - } else if (c == decoder.pad_char) { - // the padding has begun. count the pad chars. - pad_char_count += 1; - while (src_cursor < source.len) { - c = source[src_cursor]; - src_cursor += 1; - if (c == decoder.pad_char) { - pad_char_count += 1; - if (pad_char_count > 2) return error.InvalidCharacter; - } else if (decoder_with_ignore.char_is_ignored[c]) { - // we can even ignore chars during the padding - continue; - } else return error.InvalidCharacter; - } - break; - } else return error.InvalidCharacter; + var acc: u12 = 0; + var acc_len: u4 = 0; + var dest_idx: usize = 0; + var leftover_idx: ?usize = null; + for (source) |c, src_idx| { + if (decoder_with_ignore.char_is_ignored[c]) continue; + const d = decoder.char_to_index[c]; + if (d == Base64Decoder.invalid_char) { + if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter; + leftover_idx = src_idx; + break; } - - switch (available_chars) { - 4 => { - // common case - if (dest_cursor + 3 > dest.len) return error.OutputTooSmall; - assert(pad_char_count == 0); - dest[dest_cursor + 0] = decoder.char_to_index[next_4_chars[0]] << 2 | decoder.char_to_index[next_4_chars[1]] >> 4; - dest[dest_cursor + 1] = decoder.char_to_index[next_4_chars[1]] << 4 | decoder.char_to_index[next_4_chars[2]] >> 2; - dest[dest_cursor + 2] = decoder.char_to_index[next_4_chars[2]] << 6 | decoder.char_to_index[next_4_chars[3]]; - dest_cursor += 3; - continue; - }, - 3 => { - if (dest_cursor + 2 > dest.len) return error.OutputTooSmall; - if (pad_char_count != 1) return error.InvalidPadding; - dest[dest_cursor + 0] = decoder.char_to_index[next_4_chars[0]] << 2 | decoder.char_to_index[next_4_chars[1]] >> 4; - dest[dest_cursor + 1] = decoder.char_to_index[next_4_chars[1]] << 4 | decoder.char_to_index[next_4_chars[2]] >> 2; - if (decoder.char_to_index[next_4_chars[2]] << 6 != 0) return error.InvalidPadding; - dest_cursor += 2; - break; - }, - 2 => { - if (dest_cursor + 1 > dest.len) return error.OutputTooSmall; - if (pad_char_count != 2) return error.InvalidPadding; - dest[dest_cursor + 0] = decoder.char_to_index[next_4_chars[0]] << 2 | decoder.char_to_index[next_4_chars[1]] >> 4; - if (decoder.char_to_index[next_4_chars[1]] << 4 != 0) return error.InvalidPadding; - dest_cursor += 1; - break; - }, - 1 => { - return error.InvalidPadding; - }, - 0 => { - if (pad_char_count != 0) return error.InvalidPadding; - break; - }, - else => unreachable, + acc = (acc << 6) + d; + acc_len += 6; + if (acc_len >= 8) { + if (dest_idx == dest.len) return error.NoSpaceLeft; + acc_len -= 8; + dest[dest_idx] = @truncate(u8, acc >> acc_len); + dest_idx += 1; } } - - assert(src_cursor == source.len); - - return dest_cursor; - } -}; - -pub const standard_decoder_unsafe = Base64DecoderUnsafe.init(standard_alphabet_chars, standard_pad_char); - -pub const Base64DecoderUnsafe = struct { - /// e.g. 'A' => 0. - /// undefined for any value not in the 64 alphabet chars. - char_to_index: [256]u8, - pad_char: u8, - - pub fn init(alphabet_chars: []const u8, pad_char: u8) Base64DecoderUnsafe { - assert(alphabet_chars.len == 64); - var result = Base64DecoderUnsafe{ - .char_to_index = undefined, - .pad_char = pad_char, - }; - for (alphabet_chars) |c, i| { - assert(c != pad_char); - result.char_to_index[c] = @intCast(u8, i); + if (acc_len > 4 or (acc & (@as(u12, 1) << acc_len) - 1) != 0) { + return error.InvalidPadding; } - return result; - } - - /// The source buffer must be valid. - pub fn calcSize(decoder: *const Base64DecoderUnsafe, source: []const u8) usize { - return calcDecodedSizeExactUnsafe(source, decoder.pad_char); - } - - /// dest.len must be what you get from ::calcDecodedSizeExactUnsafe. - /// invalid characters or padding will result in undefined values. - pub fn decode(decoder: *const Base64DecoderUnsafe, dest: []u8, source: []const u8) void { - assert(dest.len == decoder.calcSize(source)); - - var src_index: usize = 0; - var dest_index: usize = 0; - var in_buf_len: usize = source.len; - - while (in_buf_len > 0 and source[in_buf_len - 1] == decoder.pad_char) { - in_buf_len -= 1; + const padding_len = acc_len / 2; + if (leftover_idx == null) { + if (decoder.pad_char != null and padding_len != 0) return error.InvalidPadding; + return dest_idx; } - - while (in_buf_len > 4) { - dest[dest_index] = decoder.char_to_index[source[src_index + 0]] << 2 | decoder.char_to_index[source[src_index + 1]] >> 4; - dest_index += 1; - - dest[dest_index] = decoder.char_to_index[source[src_index + 1]] << 4 | decoder.char_to_index[source[src_index + 2]] >> 2; - dest_index += 1; - - dest[dest_index] = decoder.char_to_index[source[src_index + 2]] << 6 | decoder.char_to_index[source[src_index + 3]]; - dest_index += 1; - - src_index += 4; - in_buf_len -= 4; - } - - if (in_buf_len > 1) { - dest[dest_index] = decoder.char_to_index[source[src_index + 0]] << 2 | decoder.char_to_index[source[src_index + 1]] >> 4; - dest_index += 1; - } - if (in_buf_len > 2) { - dest[dest_index] = decoder.char_to_index[source[src_index + 1]] << 4 | decoder.char_to_index[source[src_index + 2]] >> 2; - dest_index += 1; - } - if (in_buf_len > 3) { - dest[dest_index] = decoder.char_to_index[source[src_index + 2]] << 6 | decoder.char_to_index[source[src_index + 3]]; - dest_index += 1; + var leftover = source[leftover_idx.?..]; + if (decoder.pad_char) |pad_char| { + var padding_chars: usize = 0; + var i: usize = 0; + for (leftover) |c| { + if (decoder_with_ignore.char_is_ignored[c]) continue; + if (c != pad_char) { + return if (c == Base64Decoder.invalid_char) error.InvalidCharacter else error.InvalidPadding; + } + padding_chars += 1; + } + if (padding_chars != padding_len) return error.InvalidPadding; } + return dest_idx; } }; -fn calcDecodedSizeExactUnsafe(source: []const u8, pad_char: u8) usize { - if (source.len == 0) return 0; - var result = @divExact(source.len, 4) * 3; - if (source[source.len - 1] == pad_char) { - result -= 1; - if (source[source.len - 2] == pad_char) { - result -= 1; - } - } - return result; -} - test "base64" { @setEvalBranchQuota(8000); testBase64() catch unreachable; - comptime (testBase64() catch unreachable); + comptime testAllApis(standard, "comptime", "Y29tcHRpbWU=") catch unreachable; +} + +test "base64 url_safe_no_pad" { + @setEvalBranchQuota(8000); + testBase64UrlSafeNoPad() catch unreachable; + comptime testAllApis(url_safe_no_pad, "comptime", "Y29tcHRpbWU") catch unreachable; } fn testBase64() !void { - try testAllApis("", ""); - try testAllApis("f", "Zg=="); - try testAllApis("fo", "Zm8="); - try testAllApis("foo", "Zm9v"); - try testAllApis("foob", "Zm9vYg=="); - try testAllApis("fooba", "Zm9vYmE="); - try testAllApis("foobar", "Zm9vYmFy"); - - try testDecodeIgnoreSpace("", " "); - try testDecodeIgnoreSpace("f", "Z g= ="); - try testDecodeIgnoreSpace("fo", " Zm8="); - try testDecodeIgnoreSpace("foo", "Zm9v "); - try testDecodeIgnoreSpace("foob", "Zm9vYg = = "); - try testDecodeIgnoreSpace("fooba", "Zm9v YmE="); - try testDecodeIgnoreSpace("foobar", " Z m 9 v Y m F y "); + const codecs = standard; + + try testAllApis(codecs, "", ""); + try testAllApis(codecs, "f", "Zg=="); + try testAllApis(codecs, "fo", "Zm8="); + try testAllApis(codecs, "foo", "Zm9v"); + try testAllApis(codecs, "foob", "Zm9vYg=="); + try testAllApis(codecs, "fooba", "Zm9vYmE="); + try testAllApis(codecs, "foobar", "Zm9vYmFy"); + + try testDecodeIgnoreSpace(codecs, "", " "); + try testDecodeIgnoreSpace(codecs, "f", "Z g= ="); + try testDecodeIgnoreSpace(codecs, "fo", " Zm8="); + try testDecodeIgnoreSpace(codecs, "foo", "Zm9v "); + try testDecodeIgnoreSpace(codecs, "foob", "Zm9vYg = = "); + try testDecodeIgnoreSpace(codecs, "fooba", "Zm9v YmE="); + try testDecodeIgnoreSpace(codecs, "foobar", " Z m 9 v Y m F y "); + + // test getting some api errors + try testError(codecs, "A", error.InvalidPadding); + try testError(codecs, "AA", error.InvalidPadding); + try testError(codecs, "AAA", error.InvalidPadding); + try testError(codecs, "A..A", error.InvalidCharacter); + try testError(codecs, "AA=A", error.InvalidPadding); + try testError(codecs, "AA/=", error.InvalidPadding); + try testError(codecs, "A/==", error.InvalidPadding); + try testError(codecs, "A===", error.InvalidPadding); + try testError(codecs, "====", error.InvalidPadding); + + try testNoSpaceLeftError(codecs, "AA=="); + try testNoSpaceLeftError(codecs, "AAA="); + try testNoSpaceLeftError(codecs, "AAAA"); + try testNoSpaceLeftError(codecs, "AAAAAA=="); +} + +fn testBase64UrlSafeNoPad() !void { + const codecs = url_safe_no_pad; + + try testAllApis(codecs, "", ""); + try testAllApis(codecs, "f", "Zg"); + try testAllApis(codecs, "fo", "Zm8"); + try testAllApis(codecs, "foo", "Zm9v"); + try testAllApis(codecs, "foob", "Zm9vYg"); + try testAllApis(codecs, "fooba", "Zm9vYmE"); + try testAllApis(codecs, "foobar", "Zm9vYmFy"); + + try testDecodeIgnoreSpace(codecs, "", " "); + try testDecodeIgnoreSpace(codecs, "f", "Z g "); + try testDecodeIgnoreSpace(codecs, "fo", " Zm8"); + try testDecodeIgnoreSpace(codecs, "foo", "Zm9v "); + try testDecodeIgnoreSpace(codecs, "foob", "Zm9vYg "); + try testDecodeIgnoreSpace(codecs, "fooba", "Zm9v YmE"); + try testDecodeIgnoreSpace(codecs, "foobar", " Z m 9 v Y m F y "); // test getting some api errors - try testError("A", error.InvalidPadding); - try testError("AA", error.InvalidPadding); - try testError("AAA", error.InvalidPadding); - try testError("A..A", error.InvalidCharacter); - try testError("AA=A", error.InvalidCharacter); - try testError("AA/=", error.InvalidPadding); - try testError("A/==", error.InvalidPadding); - try testError("A===", error.InvalidCharacter); - try testError("====", error.InvalidCharacter); - - try testOutputTooSmallError("AA=="); - try testOutputTooSmallError("AAA="); - try testOutputTooSmallError("AAAA"); - try testOutputTooSmallError("AAAAAA=="); + try testError(codecs, "A", error.InvalidPadding); + try testError(codecs, "AAA=", error.InvalidCharacter); + try testError(codecs, "A..A", error.InvalidCharacter); + try testError(codecs, "AA=A", error.InvalidCharacter); + try testError(codecs, "AA/=", error.InvalidCharacter); + try testError(codecs, "A/==", error.InvalidCharacter); + try testError(codecs, "A===", error.InvalidCharacter); + try testError(codecs, "====", error.InvalidCharacter); + + try testNoSpaceLeftError(codecs, "AA"); + try testNoSpaceLeftError(codecs, "AAA"); + try testNoSpaceLeftError(codecs, "AAAA"); + try testNoSpaceLeftError(codecs, "AAAAAA"); } -fn testAllApis(expected_decoded: []const u8, expected_encoded: []const u8) !void { +fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: []const u8) !void { // Base64Encoder { var buffer: [0x100]u8 = undefined; - const encoded = standard_encoder.encode(&buffer, expected_decoded); + const encoded = codecs.Encoder.encode(&buffer, expected_decoded); testing.expectEqualSlices(u8, expected_encoded, encoded); } // Base64Decoder { var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..try standard_decoder.calcSize(expected_encoded)]; - try standard_decoder.decode(decoded, expected_encoded); + var decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)]; + try codecs.Decoder.decode(decoded, expected_encoded); testing.expectEqualSlices(u8, expected_decoded, decoded); } // Base64DecoderWithIgnore { - const standard_decoder_ignore_nothing = Base64DecoderWithIgnore.init(standard_alphabet_chars, standard_pad_char, ""); + const decoder_ignore_nothing = codecs.decoderWithIgnore(""); var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..Base64DecoderWithIgnore.calcSizeUpperBound(expected_encoded.len)]; - var written = try standard_decoder_ignore_nothing.decode(decoded, expected_encoded); + var decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)]; + var written = try decoder_ignore_nothing.decode(decoded, expected_encoded); testing.expect(written <= decoded.len); testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]); } - - // Base64DecoderUnsafe - { - var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..standard_decoder_unsafe.calcSize(expected_encoded)]; - standard_decoder_unsafe.decode(decoded, expected_encoded); - testing.expectEqualSlices(u8, expected_decoded, decoded); - } } -fn testDecodeIgnoreSpace(expected_decoded: []const u8, encoded: []const u8) !void { - const standard_decoder_ignore_space = Base64DecoderWithIgnore.init(standard_alphabet_chars, standard_pad_char, " "); +fn testDecodeIgnoreSpace(codecs: Codecs, expected_decoded: []const u8, encoded: []const u8) !void { + const decoder_ignore_space = codecs.decoderWithIgnore(" "); var buffer: [0x100]u8 = undefined; - var decoded = buffer[0..Base64DecoderWithIgnore.calcSizeUpperBound(encoded.len)]; - var written = try standard_decoder_ignore_space.decode(decoded, encoded); + var decoded = buffer[0..try decoder_ignore_space.calcSizeUpperBound(encoded.len)]; + var written = try decoder_ignore_space.decode(decoded, encoded); testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]); } -fn testError(encoded: []const u8, expected_err: anyerror) !void { - const standard_decoder_ignore_space = Base64DecoderWithIgnore.init(standard_alphabet_chars, standard_pad_char, " "); +fn testError(codecs: Codecs, encoded: []const u8, expected_err: anyerror) !void { + const decoder_ignore_space = codecs.decoderWithIgnore(" "); var buffer: [0x100]u8 = undefined; - if (standard_decoder.calcSize(encoded)) |decoded_size| { + if (codecs.Decoder.calcSizeForSlice(encoded)) |decoded_size| { var decoded = buffer[0..decoded_size]; - if (standard_decoder.decode(decoded, encoded)) |_| { + if (codecs.Decoder.decode(decoded, encoded)) |_| { return error.ExpectedError; } else |err| if (err != expected_err) return err; } else |err| if (err != expected_err) return err; - if (standard_decoder_ignore_space.decode(buffer[0..], encoded)) |_| { + if (decoder_ignore_space.decode(buffer[0..], encoded)) |_| { return error.ExpectedError; } else |err| if (err != expected_err) return err; } -fn testOutputTooSmallError(encoded: []const u8) !void { - const standard_decoder_ignore_space = Base64DecoderWithIgnore.init(standard_alphabet_chars, standard_pad_char, " "); +fn testNoSpaceLeftError(codecs: Codecs, encoded: []const u8) !void { + const decoder_ignore_space = codecs.decoderWithIgnore(" "); var buffer: [0x100]u8 = undefined; - var decoded = buffer[0 .. calcDecodedSizeExactUnsafe(encoded, standard_pad_char) - 1]; - if (standard_decoder_ignore_space.decode(decoded, encoded)) |_| { + var decoded = buffer[0 .. (try codecs.Decoder.calcSizeForSlice(encoded)) - 1]; + if (decoder_ignore_space.decode(decoded, encoded)) |_| { return error.ExpectedError; - } else |err| if (err != error.OutputTooSmall) return err; + } else |err| if (err != error.NoSpaceLeft) return err; } diff --git a/lib/std/bit_set.zig b/lib/std/bit_set.zig index 29ad0d7963..80cdd5c79c 100644 --- a/lib/std/bit_set.zig +++ b/lib/std/bit_set.zig @@ -176,7 +176,7 @@ pub fn IntegerBitSet(comptime size: u16) type { /// The default options (.{}) will iterate indices of set bits in /// ascending order. Modifications to the underlying bit set may /// or may not be observed by the iterator. - pub fn iterator(self: *const Self, comptime options: IteratorOptions) Iterator(options.direction) { + pub fn iterator(self: *const Self, comptime options: IteratorOptions) Iterator(options) { return .{ .bits_remain = switch (options.kind) { .set => self.mask, @@ -185,7 +185,11 @@ pub fn IntegerBitSet(comptime size: u16) type { }; } - fn Iterator(comptime direction: IteratorOptions.Direction) type { + pub fn Iterator(comptime options: IteratorOptions) type { + return SingleWordIterator(options.direction); + } + + fn SingleWordIterator(comptime direction: IteratorOptions.Direction) type { return struct { const IterSelf = @This(); // all bits which have not yet been iterated over @@ -425,8 +429,12 @@ pub fn ArrayBitSet(comptime MaskIntType: type, comptime size: usize) type { /// The default options (.{}) will iterate indices of set bits in /// ascending order. Modifications to the underlying bit set may /// or may not be observed by the iterator. - pub fn iterator(self: *const Self, comptime options: IteratorOptions) BitSetIterator(MaskInt, options) { - return BitSetIterator(MaskInt, options).init(&self.masks, last_item_mask); + pub fn iterator(self: *const Self, comptime options: IteratorOptions) Iterator(options) { + return Iterator(options).init(&self.masks, last_item_mask); + } + + pub fn Iterator(comptime options: IteratorOptions) type { + return BitSetIterator(MaskInt, options); } fn maskBit(index: usize) MaskInt { @@ -700,11 +708,15 @@ pub const DynamicBitSetUnmanaged = struct { /// ascending order. Modifications to the underlying bit set may /// or may not be observed by the iterator. Resizing the underlying /// bit set invalidates the iterator. - pub fn iterator(self: *const Self, comptime options: IteratorOptions) BitSetIterator(MaskInt, options) { + pub fn iterator(self: *const Self, comptime options: IteratorOptions) Iterator(options) { const num_masks = numMasks(self.bit_length); const padding_bits = num_masks * @bitSizeOf(MaskInt) - self.bit_length; const last_item_mask = (~@as(MaskInt, 0)) >> @intCast(ShiftInt, padding_bits); - return BitSetIterator(MaskInt, options).init(self.masks[0..num_masks], last_item_mask); + return Iterator(options).init(self.masks[0..num_masks], last_item_mask); + } + + pub fn Iterator(comptime options: IteratorOptions) type { + return BitSetIterator(MaskInt, options); } fn maskBit(index: usize) MaskInt { @@ -858,9 +870,11 @@ pub const DynamicBitSet = struct { /// ascending order. Modifications to the underlying bit set may /// or may not be observed by the iterator. Resizing the underlying /// bit set invalidates the iterator. - pub fn iterator(self: *Self, comptime options: IteratorOptions) BitSetIterator(MaskInt, options) { + pub fn iterator(self: *Self, comptime options: IteratorOptions) Iterator(options) { return self.unmanaged.iterator(options); } + + pub const Iterator = DynamicBitSetUnmanaged.Iterator; }; /// Options for configuring an iterator over a bit set diff --git a/lib/std/build.zig b/lib/std/build.zig index efeea4adb7..825312755f 100644 --- a/lib/std/build.zig +++ b/lib/std/build.zig @@ -51,7 +51,7 @@ pub const Builder = struct { default_step: *Step, env_map: *BufMap, top_level_steps: ArrayList(*TopLevelStep), - install_prefix: ?[]const u8, + install_prefix: []const u8, dest_dir: ?[]const u8, lib_dir: []const u8, exe_dir: []const u8, @@ -156,7 +156,7 @@ pub const Builder = struct { .default_step = undefined, .env_map = env_map, .search_prefixes = ArrayList([]const u8).init(allocator), - .install_prefix = null, + .install_prefix = undefined, .lib_dir = undefined, .exe_dir = undefined, .h_dir = undefined, @@ -190,22 +190,13 @@ pub const Builder = struct { } /// This function is intended to be called by std/special/build_runner.zig, not a build.zig file. - pub fn setInstallPrefix(self: *Builder, optional_prefix: ?[]const u8) void { - self.install_prefix = optional_prefix; - } - - /// This function is intended to be called by std/special/build_runner.zig, not a build.zig file. - pub fn resolveInstallPrefix(self: *Builder) void { + pub fn resolveInstallPrefix(self: *Builder, install_prefix: ?[]const u8) void { if (self.dest_dir) |dest_dir| { - const install_prefix = self.install_prefix orelse "/usr"; - self.install_path = fs.path.join(self.allocator, &[_][]const u8{ dest_dir, install_prefix }) catch unreachable; + self.install_prefix = install_prefix orelse "/usr"; + self.install_path = fs.path.join(self.allocator, &[_][]const u8{ dest_dir, self.install_prefix }) catch unreachable; } else { - const install_prefix = self.install_prefix orelse blk: { - const p = self.cache_root; - self.install_prefix = p; - break :blk p; - }; - self.install_path = install_prefix; + self.install_prefix = install_prefix orelse self.cache_root; + self.install_path = self.install_prefix; } self.lib_dir = fs.path.join(self.allocator, &[_][]const u8{ self.install_path, "lib" }) catch unreachable; self.exe_dir = fs.path.join(self.allocator, &[_][]const u8{ self.install_path, "bin" }) catch unreachable; diff --git a/lib/std/c.zig b/lib/std/c.zig index bd0ce04f75..01247ffc00 100644 --- a/lib/std/c.zig +++ b/lib/std/c.zig @@ -295,9 +295,9 @@ pub extern "c" fn kevent( ) c_int; pub extern "c" fn getaddrinfo( - noalias node: [*:0]const u8, - noalias service: [*:0]const u8, - noalias hints: *const addrinfo, + noalias node: ?[*:0]const u8, + noalias service: ?[*:0]const u8, + noalias hints: ?*const addrinfo, noalias res: **addrinfo, ) EAI; diff --git a/lib/std/c/builtins.zig b/lib/std/c/builtins.zig index 2c03c1ceac..99721a150c 100644 --- a/lib/std/c/builtins.zig +++ b/lib/std/c/builtins.zig @@ -140,7 +140,7 @@ pub fn __builtin_object_size(ptr: ?*const c_void, ty: c_int) callconv(.Inline) u // If it is not possible to determine which objects ptr points to at compile time, // __builtin_object_size should return (size_t) -1 for type 0 or 1 and (size_t) 0 // for type 2 or 3. - if (ty == 0 or ty == 1) return @bitCast(usize, -@as(c_long, 1)); + if (ty == 0 or ty == 1) return @bitCast(usize, -@as(isize, 1)); if (ty == 2 or ty == 3) return 0; unreachable; } @@ -188,3 +188,9 @@ pub fn __builtin_memcpy( pub fn __builtin_expect(expr: c_long, c: c_long) callconv(.Inline) c_long { return expr; } + +// __builtin_alloca_with_align is not currently implemented. +// It is used in a run-translated-c test and a test-translate-c test to ensure that non-implemented +// builtins are correctly demoted. If you implement __builtin_alloca_with_align, please update the +// run-translated-c test and the test-translate-c test to use a different non-implemented builtin. +// pub fn __builtin_alloca_with_align(size: usize, alignment: usize) callconv(.Inline) *c_void {} diff --git a/lib/std/crypto.zig b/lib/std/crypto.zig index 24ca549479..457b9130d9 100644 --- a/lib/std/crypto.zig +++ b/lib/std/crypto.zig @@ -24,8 +24,12 @@ pub const aead = struct { pub const Gimli = @import("crypto/gimli.zig").Aead; pub const chacha_poly = struct { - pub const ChaCha20Poly1305 = @import("crypto/chacha20.zig").Chacha20Poly1305; - pub const XChaCha20Poly1305 = @import("crypto/chacha20.zig").XChacha20Poly1305; + pub const ChaCha20Poly1305 = @import("crypto/chacha20.zig").ChaCha20Poly1305; + pub const ChaCha12Poly1305 = @import("crypto/chacha20.zig").ChaCha12Poly1305; + pub const ChaCha8Poly1305 = @import("crypto/chacha20.zig").ChaCha8Poly1305; + pub const XChaCha20Poly1305 = @import("crypto/chacha20.zig").XChaCha20Poly1305; + pub const XChaCha12Poly1305 = @import("crypto/chacha20.zig").XChaCha12Poly1305; + pub const XChaCha8Poly1305 = @import("crypto/chacha20.zig").XChaCha8Poly1305; }; pub const isap = @import("crypto/isap.zig"); @@ -119,8 +123,14 @@ pub const sign = struct { pub const stream = struct { pub const chacha = struct { pub const ChaCha20IETF = @import("crypto/chacha20.zig").ChaCha20IETF; + pub const ChaCha12IETF = @import("crypto/chacha20.zig").ChaCha12IETF; + pub const ChaCha8IETF = @import("crypto/chacha20.zig").ChaCha8IETF; pub const ChaCha20With64BitNonce = @import("crypto/chacha20.zig").ChaCha20With64BitNonce; + pub const ChaCha12With64BitNonce = @import("crypto/chacha20.zig").ChaCha12With64BitNonce; + pub const ChaCha8With64BitNonce = @import("crypto/chacha20.zig").ChaCha8With64BitNonce; pub const XChaCha20IETF = @import("crypto/chacha20.zig").XChaCha20IETF; + pub const XChaCha12IETF = @import("crypto/chacha20.zig").XChaCha12IETF; + pub const XChaCha8IETF = @import("crypto/chacha20.zig").XChaCha8IETF; }; pub const salsa = struct { @@ -144,6 +154,8 @@ pub const random = &@import("crypto/tlcsprng.zig").interface; const std = @import("std.zig"); +pub const Error = @import("crypto/error.zig").Error; + test "crypto" { const please_windows_dont_oom = std.Target.current.os.tag == .windows; if (please_windows_dont_oom) return error.SkipZigTest; @@ -151,7 +163,9 @@ test "crypto" { inline for (std.meta.declarations(@This())) |decl| { switch (decl.data) { .Type => |t| { - std.testing.refAllDecls(t); + if (@typeInfo(t) != .ErrorSet) { + std.testing.refAllDecls(t); + } }, .Var => |v| { _ = v; diff --git a/lib/std/crypto/25519/curve25519.zig b/lib/std/crypto/25519/curve25519.zig index e01b024360..d3e51ad0e0 100644 --- a/lib/std/crypto/25519/curve25519.zig +++ b/lib/std/crypto/25519/curve25519.zig @@ -4,6 +4,7 @@ // The MIT license requires this copyright notice to be included in all copies // and substantial portions of the software. const std = @import("std"); +const Error = std.crypto.Error; /// Group operations over Curve25519. pub const Curve25519 = struct { @@ -28,12 +29,12 @@ pub const Curve25519 = struct { pub const basePoint = Curve25519{ .x = Fe.curve25519BasePoint }; /// Check that the encoding of a Curve25519 point is canonical. - pub fn rejectNonCanonical(s: [32]u8) !void { + pub fn rejectNonCanonical(s: [32]u8) Error!void { return Fe.rejectNonCanonical(s, false); } /// Reject the neutral element. - pub fn rejectIdentity(p: Curve25519) !void { + pub fn rejectIdentity(p: Curve25519) Error!void { if (p.x.isZero()) { return error.IdentityElement; } @@ -44,7 +45,7 @@ pub const Curve25519 = struct { return p.dbl().dbl().dbl(); } - fn ladder(p: Curve25519, s: [32]u8, comptime bits: usize) !Curve25519 { + fn ladder(p: Curve25519, s: [32]u8, comptime bits: usize) Error!Curve25519 { var x1 = p.x; var x2 = Fe.one; var z2 = Fe.zero; @@ -85,7 +86,7 @@ pub const Curve25519 = struct { /// way to use Curve25519 for a DH operation. /// Return error.IdentityElement if the resulting point is /// the identity element. - pub fn clampedMul(p: Curve25519, s: [32]u8) !Curve25519 { + pub fn clampedMul(p: Curve25519, s: [32]u8) Error!Curve25519 { var t: [32]u8 = s; scalar.clamp(&t); return try ladder(p, t, 255); @@ -95,14 +96,14 @@ pub const Curve25519 = struct { /// Return error.IdentityElement if the resulting point is /// the identity element or error.WeakPublicKey if the public /// key is a low-order point. - pub fn mul(p: Curve25519, s: [32]u8) !Curve25519 { + pub fn mul(p: Curve25519, s: [32]u8) Error!Curve25519 { const cofactor = [_]u8{8} ++ [_]u8{0} ** 31; _ = ladder(p, cofactor, 4) catch |_| return error.WeakPublicKey; return try ladder(p, s, 256); } /// Compute the Curve25519 equivalent to an Edwards25519 point. - pub fn fromEdwards25519(p: std.crypto.ecc.Edwards25519) !Curve25519 { + pub fn fromEdwards25519(p: std.crypto.ecc.Edwards25519) Error!Curve25519 { try p.clearCofactor().rejectIdentity(); const one = std.crypto.ecc.Edwards25519.Fe.one; const x = one.add(p.y).mul(one.sub(p.y).invert()); // xMont=(1+yEd)/(1-yEd) diff --git a/lib/std/crypto/25519/ed25519.zig b/lib/std/crypto/25519/ed25519.zig index 06a4826f58..e385e34f12 100644 --- a/lib/std/crypto/25519/ed25519.zig +++ b/lib/std/crypto/25519/ed25519.zig @@ -8,7 +8,8 @@ const crypto = std.crypto; const debug = std.debug; const fmt = std.fmt; const mem = std.mem; -const Sha512 = std.crypto.hash.sha2.Sha512; +const Sha512 = crypto.hash.sha2.Sha512; +const Error = crypto.Error; /// Ed25519 (EdDSA) signatures. pub const Ed25519 = struct { @@ -40,7 +41,7 @@ pub const Ed25519 = struct { /// /// For this reason, an EdDSA secret key is commonly called a seed, /// from which the actual secret is derived. - pub fn create(seed: ?[seed_length]u8) !KeyPair { + pub fn create(seed: ?[seed_length]u8) Error!KeyPair { const ss = seed orelse ss: { var random_seed: [seed_length]u8 = undefined; crypto.random.bytes(&random_seed); @@ -71,7 +72,7 @@ pub const Ed25519 = struct { /// Sign a message using a key pair, and optional random noise. /// Having noise creates non-standard, non-deterministic signatures, /// but has been proven to increase resilience against fault attacks. - pub fn sign(msg: []const u8, key_pair: KeyPair, noise: ?[noise_length]u8) ![signature_length]u8 { + pub fn sign(msg: []const u8, key_pair: KeyPair, noise: ?[noise_length]u8) Error![signature_length]u8 { const seed = key_pair.secret_key[0..seed_length]; const public_key = key_pair.secret_key[seed_length..]; if (!mem.eql(u8, public_key, &key_pair.public_key)) { @@ -111,8 +112,8 @@ pub const Ed25519 = struct { } /// Verify an Ed25519 signature given a message and a public key. - /// Returns error.InvalidSignature is the signature verification failed. - pub fn verify(sig: [signature_length]u8, msg: []const u8, public_key: [public_length]u8) !void { + /// Returns error.SignatureVerificationFailed is the signature verification failed. + pub fn verify(sig: [signature_length]u8, msg: []const u8, public_key: [public_length]u8) Error!void { const r = sig[0..32]; const s = sig[32..64]; try Curve.scalar.rejectNonCanonical(s.*); @@ -133,7 +134,7 @@ pub const Ed25519 = struct { const ah = try a.neg().mulPublic(hram); const sb_ah = (try Curve.basePoint.mulPublic(s.*)).add(ah); if (expected_r.sub(sb_ah).clearCofactor().rejectIdentity()) |_| { - return error.InvalidSignature; + return error.SignatureVerificationFailed; } else |_| {} } @@ -145,7 +146,7 @@ pub const Ed25519 = struct { }; /// Verify several signatures in a single operation, much faster than verifying signatures one-by-one - pub fn verifyBatch(comptime count: usize, signature_batch: [count]BatchElement) !void { + pub fn verifyBatch(comptime count: usize, signature_batch: [count]BatchElement) Error!void { var r_batch: [count][32]u8 = undefined; var s_batch: [count][32]u8 = undefined; var a_batch: [count]Curve = undefined; @@ -200,7 +201,7 @@ pub const Ed25519 = struct { const zsb = try Curve.basePoint.mulPublic(zs_sum); if (zr.add(zah).sub(zsb).rejectIdentity()) |_| { - return error.InvalidSignature; + return error.SignatureVerificationFailed; } else |_| {} } }; @@ -223,7 +224,7 @@ test "ed25519 signature" { var buf: [128]u8 = undefined; std.testing.expectEqualStrings(try std.fmt.bufPrint(&buf, "{s}", .{std.fmt.fmtSliceHexUpper(&sig)}), "10A442B4A80CC4225B154F43BEF28D2472CA80221951262EB8E0DF9091575E2687CC486E77263C3418C757522D54F84B0359236ABBBD4ACD20DC297FDCA66808"); try Ed25519.verify(sig, "test", key_pair.public_key); - std.testing.expectError(error.InvalidSignature, Ed25519.verify(sig, "TEST", key_pair.public_key)); + std.testing.expectError(error.SignatureVerificationFailed, Ed25519.verify(sig, "TEST", key_pair.public_key)); } test "ed25519 batch verification" { @@ -251,7 +252,7 @@ test "ed25519 batch verification" { try Ed25519.verifyBatch(2, signature_batch); signature_batch[1].sig = sig1; - std.testing.expectError(error.InvalidSignature, Ed25519.verifyBatch(signature_batch.len, signature_batch)); + std.testing.expectError(error.SignatureVerificationFailed, Ed25519.verifyBatch(signature_batch.len, signature_batch)); } } @@ -316,7 +317,7 @@ test "ed25519 test vectors" { .msg_hex = "9bedc267423725d473888631ebf45988bad3db83851ee85c85e241a07d148b41", .public_key_hex = "f7badec5b8abeaf699583992219b7b223f1df3fbbea919844e3f7c554a43dd43", .sig_hex = "ecffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff03be9678ac102edcd92b0210bb34d7428d12ffc5df5f37e359941266a4e35f0f", - .expected = error.InvalidSignature, // 8 - non-canonical R + .expected = error.SignatureVerificationFailed, // 8 - non-canonical R }, Vec{ .msg_hex = "9bedc267423725d473888631ebf45988bad3db83851ee85c85e241a07d148b41", diff --git a/lib/std/crypto/25519/edwards25519.zig b/lib/std/crypto/25519/edwards25519.zig index 8d9922d80c..89b7b9b9f3 100644 --- a/lib/std/crypto/25519/edwards25519.zig +++ b/lib/std/crypto/25519/edwards25519.zig @@ -7,6 +7,7 @@ const std = @import("std"); const debug = std.debug; const fmt = std.fmt; const mem = std.mem; +const Error = std.crypto.Error; /// Group operations over Edwards25519. pub const Edwards25519 = struct { @@ -25,7 +26,7 @@ pub const Edwards25519 = struct { is_base: bool = false, /// Decode an Edwards25519 point from its compressed (Y+sign) coordinates. - pub fn fromBytes(s: [encoded_length]u8) !Edwards25519 { + pub fn fromBytes(s: [encoded_length]u8) Error!Edwards25519 { const z = Fe.one; const y = Fe.fromBytes(s); var u = y.sq(); @@ -55,7 +56,7 @@ pub const Edwards25519 = struct { } /// Check that the encoding of a point is canonical. - pub fn rejectNonCanonical(s: [32]u8) !void { + pub fn rejectNonCanonical(s: [32]u8) Error!void { return Fe.rejectNonCanonical(s, true); } @@ -80,7 +81,7 @@ pub const Edwards25519 = struct { const identityElement = Edwards25519{ .x = Fe.zero, .y = Fe.one, .z = Fe.one, .t = Fe.zero }; /// Reject the neutral element. - pub fn rejectIdentity(p: Edwards25519) !void { + pub fn rejectIdentity(p: Edwards25519) Error!void { if (p.x.isZero()) { return error.IdentityElement; } @@ -176,7 +177,7 @@ pub const Edwards25519 = struct { // Based on real-world benchmarks, we only use this for multi-scalar multiplication. // NAF could be useful to half the size of precomputation tables, but we intentionally // avoid these to keep the standard library lightweight. - fn pcMul(pc: [9]Edwards25519, s: [32]u8, comptime vartime: bool) !Edwards25519 { + fn pcMul(pc: [9]Edwards25519, s: [32]u8, comptime vartime: bool) Error!Edwards25519 { std.debug.assert(vartime); const e = nonAdjacentForm(s); var q = Edwards25519.identityElement; @@ -196,7 +197,7 @@ pub const Edwards25519 = struct { } // Scalar multiplication with a 4-bit window and the first 15 multiples. - fn pcMul16(pc: [16]Edwards25519, s: [32]u8, comptime vartime: bool) !Edwards25519 { + fn pcMul16(pc: [16]Edwards25519, s: [32]u8, comptime vartime: bool) Error!Edwards25519 { var q = Edwards25519.identityElement; var pos: usize = 252; while (true) : (pos -= 4) { @@ -234,7 +235,7 @@ pub const Edwards25519 = struct { /// Multiply an Edwards25519 point by a scalar without clamping it. /// Return error.WeakPublicKey if the resulting point is /// the identity element. - pub fn mul(p: Edwards25519, s: [32]u8) !Edwards25519 { + pub fn mul(p: Edwards25519, s: [32]u8) Error!Edwards25519 { const pc = if (p.is_base) basePointPc else pc: { const xpc = precompute(p, 15); xpc[4].rejectIdentity() catch |_| return error.WeakPublicKey; @@ -245,7 +246,7 @@ pub const Edwards25519 = struct { /// Multiply an Edwards25519 point by a *PUBLIC* scalar *IN VARIABLE TIME* /// This can be used for signature verification. - pub fn mulPublic(p: Edwards25519, s: [32]u8) !Edwards25519 { + pub fn mulPublic(p: Edwards25519, s: [32]u8) Error!Edwards25519 { if (p.is_base) { return pcMul16(basePointPc, s, true); } else { @@ -257,7 +258,7 @@ pub const Edwards25519 = struct { /// Multiscalar multiplication *IN VARIABLE TIME* for public data /// Computes ps0*ss0 + ps1*ss1 + ps2*ss2... faster than doing many of these operations individually - pub fn mulMulti(comptime count: usize, ps: [count]Edwards25519, ss: [count][32]u8) !Edwards25519 { + pub fn mulMulti(comptime count: usize, ps: [count]Edwards25519, ss: [count][32]u8) Error!Edwards25519 { var pcs: [count][9]Edwards25519 = undefined; for (ps) |p, i| { if (p.is_base) { @@ -296,14 +297,14 @@ pub const Edwards25519 = struct { /// This is strongly recommended for DH operations. /// Return error.WeakPublicKey if the resulting point is /// the identity element. - pub fn clampedMul(p: Edwards25519, s: [32]u8) !Edwards25519 { + pub fn clampedMul(p: Edwards25519, s: [32]u8) Error!Edwards25519 { var t: [32]u8 = s; scalar.clamp(&t); return mul(p, t); } // montgomery -- recover y = sqrt(x^3 + A*x^2 + x) - fn xmontToYmont(x: Fe) !Fe { + fn xmontToYmont(x: Fe) Error!Fe { var x2 = x.sq(); const x3 = x.mul(x2); x2 = x2.mul32(Fe.edwards25519a_32); diff --git a/lib/std/crypto/25519/field.zig b/lib/std/crypto/25519/field.zig index 320cb1bb51..b570e2d06b 100644 --- a/lib/std/crypto/25519/field.zig +++ b/lib/std/crypto/25519/field.zig @@ -6,6 +6,7 @@ const std = @import("std"); const readIntLittle = std.mem.readIntLittle; const writeIntLittle = std.mem.writeIntLittle; +const Error = std.crypto.Error; pub const Fe = struct { limbs: [5]u64, @@ -112,7 +113,7 @@ pub const Fe = struct { } /// Reject non-canonical encodings of an element, possibly ignoring the top bit - pub fn rejectNonCanonical(s: [32]u8, comptime ignore_extra_bit: bool) !void { + pub fn rejectNonCanonical(s: [32]u8, comptime ignore_extra_bit: bool) Error!void { var c: u16 = (s[31] & 0x7f) ^ 0x7f; comptime var i = 30; inline while (i > 0) : (i -= 1) { @@ -412,7 +413,7 @@ pub const Fe = struct { } /// Compute the square root of `x2`, returning `error.NotSquare` if `x2` was not a square - pub fn sqrt(x2: Fe) !Fe { + pub fn sqrt(x2: Fe) Error!Fe { var x2_copy = x2; const x = x2.uncheckedSqrt(); const check = x.sq().sub(x2_copy); diff --git a/lib/std/crypto/25519/ristretto255.zig b/lib/std/crypto/25519/ristretto255.zig index 46bb9697e2..4644b7622e 100644 --- a/lib/std/crypto/25519/ristretto255.zig +++ b/lib/std/crypto/25519/ristretto255.zig @@ -5,6 +5,7 @@ // and substantial portions of the software. const std = @import("std"); const fmt = std.fmt; +const Error = std.crypto.Error; /// Group operations over Edwards25519. pub const Ristretto255 = struct { @@ -34,7 +35,7 @@ pub const Ristretto255 = struct { return .{ .ratio_is_square = @boolToInt(has_m_root) | @boolToInt(has_p_root), .root = x.abs() }; } - fn rejectNonCanonical(s: [encoded_length]u8) !void { + fn rejectNonCanonical(s: [encoded_length]u8) Error!void { if ((s[0] & 1) != 0) { return error.NonCanonical; } @@ -42,7 +43,7 @@ pub const Ristretto255 = struct { } /// Reject the neutral element. - pub fn rejectIdentity(p: Ristretto255) callconv(.Inline) !void { + pub fn rejectIdentity(p: Ristretto255) callconv(.Inline) Error!void { return p.p.rejectIdentity(); } @@ -50,7 +51,7 @@ pub const Ristretto255 = struct { pub const basePoint = Ristretto255{ .p = Curve.basePoint }; /// Decode a Ristretto255 representative. - pub fn fromBytes(s: [encoded_length]u8) !Ristretto255 { + pub fn fromBytes(s: [encoded_length]u8) Error!Ristretto255 { try rejectNonCanonical(s); const s_ = Fe.fromBytes(s); const ss = s_.sq(); // s^2 @@ -153,7 +154,7 @@ pub const Ristretto255 = struct { /// Multiply a Ristretto255 element with a scalar. /// Return error.WeakPublicKey if the resulting element is /// the identity element. - pub fn mul(p: Ristretto255, s: [encoded_length]u8) callconv(.Inline) !Ristretto255 { + pub fn mul(p: Ristretto255, s: [encoded_length]u8) callconv(.Inline) Error!Ristretto255 { return Ristretto255{ .p = try p.p.mul(s) }; } diff --git a/lib/std/crypto/25519/scalar.zig b/lib/std/crypto/25519/scalar.zig index e4fb277807..a4bf5aafcf 100644 --- a/lib/std/crypto/25519/scalar.zig +++ b/lib/std/crypto/25519/scalar.zig @@ -5,6 +5,7 @@ // and substantial portions of the software. const std = @import("std"); const mem = std.mem; +const Error = std.crypto.Error; /// 2^252 + 27742317777372353535851937790883648493 pub const field_size = [32]u8{ @@ -18,7 +19,7 @@ pub const CompressedScalar = [32]u8; pub const zero = [_]u8{0} ** 32; /// Reject a scalar whose encoding is not canonical. -pub fn rejectNonCanonical(s: [32]u8) !void { +pub fn rejectNonCanonical(s: [32]u8) Error!void { var c: u8 = 0; var n: u8 = 1; var i: usize = 31; diff --git a/lib/std/crypto/25519/x25519.zig b/lib/std/crypto/25519/x25519.zig index 5d0479bd4d..2d53124056 100644 --- a/lib/std/crypto/25519/x25519.zig +++ b/lib/std/crypto/25519/x25519.zig @@ -9,6 +9,7 @@ const mem = std.mem; const fmt = std.fmt; const Sha512 = crypto.hash.sha2.Sha512; +const Error = crypto.Error; /// X25519 DH function. pub const X25519 = struct { @@ -31,7 +32,7 @@ pub const X25519 = struct { secret_key: [secret_length]u8, /// Create a new key pair using an optional seed. - pub fn create(seed: ?[seed_length]u8) !KeyPair { + pub fn create(seed: ?[seed_length]u8) Error!KeyPair { const sk = seed orelse sk: { var random_seed: [seed_length]u8 = undefined; crypto.random.bytes(&random_seed); @@ -44,7 +45,7 @@ pub const X25519 = struct { } /// Create a key pair from an Ed25519 key pair - pub fn fromEd25519(ed25519_key_pair: crypto.sign.Ed25519.KeyPair) !KeyPair { + pub fn fromEd25519(ed25519_key_pair: crypto.sign.Ed25519.KeyPair) Error!KeyPair { const seed = ed25519_key_pair.secret_key[0..32]; var az: [Sha512.digest_length]u8 = undefined; Sha512.hash(seed, &az, .{}); @@ -59,13 +60,13 @@ pub const X25519 = struct { }; /// Compute the public key for a given private key. - pub fn recoverPublicKey(secret_key: [secret_length]u8) ![public_length]u8 { + pub fn recoverPublicKey(secret_key: [secret_length]u8) Error![public_length]u8 { const q = try Curve.basePoint.clampedMul(secret_key); return q.toBytes(); } /// Compute the X25519 equivalent to an Ed25519 public eky. - pub fn publicKeyFromEd25519(ed25519_public_key: [crypto.sign.Ed25519.public_length]u8) ![public_length]u8 { + pub fn publicKeyFromEd25519(ed25519_public_key: [crypto.sign.Ed25519.public_length]u8) Error![public_length]u8 { const pk_ed = try crypto.ecc.Edwards25519.fromBytes(ed25519_public_key); const pk = try Curve.fromEdwards25519(pk_ed); return pk.toBytes(); @@ -74,7 +75,7 @@ pub const X25519 = struct { /// Compute the scalar product of a public key and a secret scalar. /// Note that the output should not be used as a shared secret without /// hashing it first. - pub fn scalarmult(secret_key: [secret_length]u8, public_key: [public_length]u8) ![shared_length]u8 { + pub fn scalarmult(secret_key: [secret_length]u8, public_key: [public_length]u8) Error![shared_length]u8 { const q = try Curve.fromBytes(public_key).clampedMul(secret_key); return q.toBytes(); } diff --git a/lib/std/crypto/aegis.zig b/lib/std/crypto/aegis.zig index 2983f68ce8..3969d59e10 100644 --- a/lib/std/crypto/aegis.zig +++ b/lib/std/crypto/aegis.zig @@ -8,6 +8,7 @@ const std = @import("std"); const mem = std.mem; const assert = std.debug.assert; const AesBlock = std.crypto.core.aes.Block; +const Error = std.crypto.Error; const State128L = struct { blocks: [8]AesBlock, @@ -136,7 +137,7 @@ pub const Aegis128L = struct { /// ad: Associated Data /// npub: public nonce /// k: private key - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) Error!void { assert(c.len == m.len); var state = State128L.init(key, npub); var src: [32]u8 align(16) = undefined; @@ -298,7 +299,7 @@ pub const Aegis256 = struct { /// ad: Associated Data /// npub: public nonce /// k: private key - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) Error!void { assert(c.len == m.len); var state = State256.init(key, npub); var src: [16]u8 align(16) = undefined; diff --git a/lib/std/crypto/aes_gcm.zig b/lib/std/crypto/aes_gcm.zig index 5ef3f93963..bcb1b4c5fa 100644 --- a/lib/std/crypto/aes_gcm.zig +++ b/lib/std/crypto/aes_gcm.zig @@ -12,6 +12,7 @@ const debug = std.debug; const Ghash = std.crypto.onetimeauth.Ghash; const mem = std.mem; const modes = crypto.core.modes; +const Error = crypto.Error; pub const Aes128Gcm = AesGcm(crypto.core.aes.Aes128); pub const Aes256Gcm = AesGcm(crypto.core.aes.Aes256); @@ -59,7 +60,7 @@ fn AesGcm(comptime Aes: anytype) type { } } - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) Error!void { assert(c.len == m.len); const aes = Aes.initEnc(key); diff --git a/lib/std/crypto/aes_ocb.zig b/lib/std/crypto/aes_ocb.zig index ab0138f181..9eb0561d9f 100644 --- a/lib/std/crypto/aes_ocb.zig +++ b/lib/std/crypto/aes_ocb.zig @@ -10,6 +10,7 @@ const aes = crypto.core.aes; const assert = std.debug.assert; const math = std.math; const mem = std.mem; +const Error = crypto.Error; pub const Aes128Ocb = AesOcb(aes.Aes128); pub const Aes256Ocb = AesOcb(aes.Aes256); @@ -178,7 +179,7 @@ fn AesOcb(comptime Aes: anytype) type { /// ad: Associated Data /// npub: public nonce /// k: secret key - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) Error!void { assert(c.len == m.len); const aes_enc_ctx = Aes.initEnc(key); diff --git a/lib/std/crypto/bcrypt.zig b/lib/std/crypto/bcrypt.zig index caceb6d7b9..d00108b9c4 100644 --- a/lib/std/crypto/bcrypt.zig +++ b/lib/std/crypto/bcrypt.zig @@ -11,7 +11,8 @@ const math = std.math; const mem = std.mem; const debug = std.debug; const testing = std.testing; -const utils = std.crypto.utils; +const utils = crypto.utils; +const Error = crypto.Error; const salt_length: usize = 16; const salt_str_length: usize = 22; @@ -21,13 +22,6 @@ const ct_length: usize = 24; /// Length (in bytes) of a password hash pub const hash_length: usize = 60; -pub const BcryptError = error{ - /// The hashed password cannot be decoded. - InvalidEncoding, - /// The hash is not valid for the given password. - InvalidPassword, -}; - const State = struct { sboxes: [4][256]u32 = [4][256]u32{ .{ 0xd1310ba6, 0x98dfb5ac, 0x2ffd72db, 0xd01adfb7, 0xb8e1afed, 0x6a267e96, 0xba7c9045, 0xf12c7f99, 0x24a19947, 0xb3916cf7, 0x0801f2e2, 0x858efc16, 0x636920d8, 0x71574e69, 0xa458fea3, 0xf4933d7e, 0x0d95748f, 0x728eb658, 0x718bcd58, 0x82154aee, 0x7b54a41d, 0xc25a59b5, 0x9c30d539, 0x2af26013, 0xc5d1b023, 0x286085f0, 0xca417918, 0xb8db38ef, 0x8e79dcb0, 0x603a180e, 0x6c9e0e8b, 0xb01e8a3e, 0xd71577c1, 0xbd314b27, 0x78af2fda, 0x55605c60, 0xe65525f3, 0xaa55ab94, 0x57489862, 0x63e81440, 0x55ca396a, 0x2aab10b6, 0xb4cc5c34, 0x1141e8ce, 0xa15486af, 0x7c72e993, 0xb3ee1411, 0x636fbc2a, 0x2ba9c55d, 0x741831f6, 0xce5c3e16, 0x9b87931e, 0xafd6ba33, 0x6c24cf5c, 0x7a325381, 0x28958677, 0x3b8f4898, 0x6b4bb9af, 0xc4bfe81b, 0x66282193, 0x61d809cc, 0xfb21a991, 0x487cac60, 0x5dec8032, 0xef845d5d, 0xe98575b1, 0xdc262302, 0xeb651b88, 0x23893e81, 0xd396acc5, 0x0f6d6ff3, 0x83f44239, 0x2e0b4482, 0xa4842004, 0x69c8f04a, 0x9e1f9b5e, 0x21c66842, 0xf6e96c9a, 0x670c9c61, 0xabd388f0, 0x6a51a0d2, 0xd8542f68, 0x960fa728, 0xab5133a3, 0x6eef0b6c, 0x137a3be4, 0xba3bf050, 0x7efb2a98, 0xa1f1651d, 0x39af0176, 0x66ca593e, 0x82430e88, 0x8cee8619, 0x456f9fb4, 0x7d84a5c3, 0x3b8b5ebe, 0xe06f75d8, 0x85c12073, 0x401a449f, 0x56c16aa6, 0x4ed3aa62, 0x363f7706, 0x1bfedf72, 0x429b023d, 0x37d0d724, 0xd00a1248, 0xdb0fead3, 0x49f1c09b, 0x075372c9, 0x80991b7b, 0x25d479d8, 0xf6e8def7, 0xe3fe501a, 0xb6794c3b, 0x976ce0bd, 0x04c006ba, 0xc1a94fb6, 0x409f60c4, 0x5e5c9ec2, 0x196a2463, 0x68fb6faf, 0x3e6c53b5, 0x1339b2eb, 0x3b52ec6f, 0x6dfc511f, 0x9b30952c, 0xcc814544, 0xaf5ebd09, 0xbee3d004, 0xde334afd, 0x660f2807, 0x192e4bb3, 0xc0cba857, 0x45c8740f, 0xd20b5f39, 0xb9d3fbdb, 0x5579c0bd, 0x1a60320a, 0xd6a100c6, 0x402c7279, 0x679f25fe, 0xfb1fa3cc, 0x8ea5e9f8, 0xdb3222f8, 0x3c7516df, 0xfd616b15, 0x2f501ec8, 0xad0552ab, 0x323db5fa, 0xfd238760, 0x53317b48, 0x3e00df82, 0x9e5c57bb, 0xca6f8ca0, 0x1a87562e, 0xdf1769db, 0xd542a8f6, 0x287effc3, 0xac6732c6, 0x8c4f5573, 0x695b27b0, 0xbbca58c8, 0xe1ffa35d, 0xb8f011a0, 0x10fa3d98, 0xfd2183b8, 0x4afcb56c, 0x2dd1d35b, 0x9a53e479, 0xb6f84565, 0xd28e49bc, 0x4bfb9790, 0xe1ddf2da, 0xa4cb7e33, 0x62fb1341, 0xcee4c6e8, 0xef20cada, 0x36774c01, 0xd07e9efe, 0x2bf11fb4, 0x95dbda4d, 0xae909198, 0xeaad8e71, 0x6b93d5a0, 0xd08ed1d0, 0xafc725e0, 0x8e3c5b2f, 0x8e7594b7, 0x8ff6e2fb, 0xf2122b64, 0x8888b812, 0x900df01c, 0x4fad5ea0, 0x688fc31c, 0xd1cff191, 0xb3a8c1ad, 0x2f2f2218, 0xbe0e1777, 0xea752dfe, 0x8b021fa1, 0xe5a0cc0f, 0xb56f74e8, 0x18acf3d6, 0xce89e299, 0xb4a84fe0, 0xfd13e0b7, 0x7cc43b81, 0xd2ada8d9, 0x165fa266, 0x80957705, 0x93cc7314, 0x211a1477, 0xe6ad2065, 0x77b5fa86, 0xc75442f5, 0xfb9d35cf, 0xebcdaf0c, 0x7b3e89a0, 0xd6411bd3, 0xae1e7e49, 0x00250e2d, 0x2071b35e, 0x226800bb, 0x57b8e0af, 0x2464369b, 0xf009b91e, 0x5563911d, 0x59dfa6aa, 0x78c14389, 0xd95a537f, 0x207d5ba2, 0x02e5b9c5, 0x83260376, 0x6295cfa9, 0x11c81968, 0x4e734a41, 0xb3472dca, 0x7b14a94a, 0x1b510052, 0x9a532915, 0xd60f573f, 0xbc9bc6e4, 0x2b60a476, 0x81e67400, 0x08ba6fb5, 0x571be91f, 0xf296ec6b, 0x2a0dd915, 0xb6636521, 0xe7b9f9b6, 0xff34052e, 0xc5855664, 0x53b02d5d, 0xa99f8fa1, 0x08ba4799, 0x6e85076a }, @@ -185,7 +179,7 @@ const Codec = struct { debug.assert(j == b64.len); } - fn decode(bin: []u8, b64: []const u8) BcryptError!void { + fn decode(bin: []u8, b64: []const u8) Error!void { var i: usize = 0; var j: usize = 0; while (j < bin.len) { @@ -210,7 +204,7 @@ const Codec = struct { } }; -fn strHashInternal(password: []const u8, rounds_log: u6, salt: [salt_length]u8) BcryptError![hash_length]u8 { +fn strHashInternal(password: []const u8, rounds_log: u6, salt: [salt_length]u8) Error![hash_length]u8 { var state = State{}; var password_buf: [73]u8 = undefined; const trimmed_len = math.min(password.len, password_buf.len - 1); @@ -258,14 +252,14 @@ fn strHashInternal(password: []const u8, rounds_log: u6, salt: [salt_length]u8) /// IMPORTANT: by design, bcrypt silently truncates passwords to 72 bytes. /// If this is an issue for your application, hash the password first using a function such as SHA-512, /// and then use the resulting hash as the password parameter for bcrypt. -pub fn strHash(password: []const u8, rounds_log: u6) ![hash_length]u8 { +pub fn strHash(password: []const u8, rounds_log: u6) Error![hash_length]u8 { var salt: [salt_length]u8 = undefined; crypto.random.bytes(&salt); return strHashInternal(password, rounds_log, salt); } /// Verify that a previously computed hash is valid for a given password. -pub fn strVerify(h: [hash_length]u8, password: []const u8) BcryptError!void { +pub fn strVerify(h: [hash_length]u8, password: []const u8) Error!void { if (!mem.eql(u8, "$2", h[0..2])) return error.InvalidEncoding; if (h[3] != '$' or h[6] != '$') return error.InvalidEncoding; const rounds_log_str = h[4..][0..2]; @@ -275,7 +269,7 @@ pub fn strVerify(h: [hash_length]u8, password: []const u8) BcryptError!void { const rounds_log = fmt.parseInt(u6, rounds_log_str[0..], 10) catch return error.InvalidEncoding; const wanted_s = try strHashInternal(password, rounds_log, salt); if (!mem.eql(u8, wanted_s[0..], h[0..])) { - return error.InvalidPassword; + return error.PasswordVerificationFailed; } } @@ -292,7 +286,7 @@ test "bcrypt codec" { test "bcrypt" { const s = try strHash("password", 5); try strVerify(s, "password"); - testing.expectError(error.InvalidPassword, strVerify(s, "invalid password")); + testing.expectError(error.PasswordVerificationFailed, strVerify(s, "invalid password")); const long_s = try strHash("password" ** 100, 5); try strVerify(long_s, "password" ** 100); diff --git a/lib/std/crypto/benchmark.zig b/lib/std/crypto/benchmark.zig index e3ffa62ed1..49d5b15820 100644 --- a/lib/std/crypto/benchmark.zig +++ b/lib/std/crypto/benchmark.zig @@ -202,6 +202,7 @@ pub fn benchmarkBatchSignatureVerification(comptime Signature: anytype, comptime const aeads = [_]Crypto{ Crypto{ .ty = crypto.aead.chacha_poly.ChaCha20Poly1305, .name = "chacha20Poly1305" }, Crypto{ .ty = crypto.aead.chacha_poly.XChaCha20Poly1305, .name = "xchacha20Poly1305" }, + Crypto{ .ty = crypto.aead.chacha_poly.XChaCha8Poly1305, .name = "xchacha8Poly1305" }, Crypto{ .ty = crypto.aead.salsa_poly.XSalsa20Poly1305, .name = "xsalsa20Poly1305" }, Crypto{ .ty = crypto.aead.Gimli, .name = "gimli-aead" }, Crypto{ .ty = crypto.aead.aegis.Aegis128L, .name = "aegis-128l" }, diff --git a/lib/std/crypto/chacha20.zig b/lib/std/crypto/chacha20.zig index e01888e793..e1fe3e232d 100644 --- a/lib/std/crypto/chacha20.zig +++ b/lib/std/crypto/chacha20.zig @@ -13,287 +13,359 @@ const testing = std.testing; const maxInt = math.maxInt; const Vector = std.meta.Vector; const Poly1305 = std.crypto.onetimeauth.Poly1305; +const Error = std.crypto.Error; + +/// IETF-variant of the ChaCha20 stream cipher, as designed for TLS. +pub const ChaCha20IETF = ChaChaIETF(20); + +/// IETF-variant of the ChaCha20 stream cipher, reduced to 12 rounds. +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const ChaCha12IETF = ChaChaIETF(12); + +/// IETF-variant of the ChaCha20 stream cipher, reduced to 8 rounds. +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const ChaCha8IETF = ChaChaIETF(8); + +/// Original ChaCha20 stream cipher. +pub const ChaCha20With64BitNonce = ChaChaWith64BitNonce(20); + +/// Original ChaCha20 stream cipher, reduced to 12 rounds. +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const ChaCha12With64BitNonce = ChaChaWith64BitNonce(12); + +/// Original ChaCha20 stream cipher, reduced to 8 rounds. +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const ChaCha8With64BitNonce = ChaChaWith64BitNonce(8); + +/// XChaCha20 (nonce-extended version of the IETF ChaCha20 variant) stream cipher +pub const XChaCha20IETF = XChaChaIETF(20); + +/// XChaCha20 (nonce-extended version of the IETF ChaCha20 variant) stream cipher, reduced to 12 rounds +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const XChaCha12IETF = XChaChaIETF(12); + +/// XChaCha20 (nonce-extended version of the IETF ChaCha20 variant) stream cipher, reduced to 8 rounds +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const XChaCha8IETF = XChaChaIETF(8); + +/// ChaCha20-Poly1305 authenticated cipher, as designed for TLS +pub const ChaCha20Poly1305 = ChaChaPoly1305(20); + +/// ChaCha20-Poly1305 authenticated cipher, reduced to 12 rounds +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const ChaCha12Poly1305 = ChaChaPoly1305(12); + +/// ChaCha20-Poly1305 authenticated cipher, reduced to 8 rounds +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const ChaCha8Poly1305 = ChaChaPoly1305(8); + +/// XChaCha20-Poly1305 authenticated cipher +pub const XChaCha20Poly1305 = XChaChaPoly1305(20); + +/// XChaCha20-Poly1305 authenticated cipher +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const XChaCha12Poly1305 = XChaChaPoly1305(12); + +/// XChaCha20-Poly1305 authenticated cipher +/// Reduced-rounds versions are faster than the full-round version, but have a lower security margin. +/// However, ChaCha is still believed to have a comfortable security even with only with 8 rounds. +pub const XChaCha8Poly1305 = XChaChaPoly1305(8); // Vectorized implementation of the core function -const ChaCha20VecImpl = struct { - const Lane = Vector(4, u32); - const BlockVec = [4]Lane; - - fn initContext(key: [8]u32, d: [4]u32) BlockVec { - const c = "expand 32-byte k"; - const constant_le = comptime Lane{ - mem.readIntLittle(u32, c[0..4]), - mem.readIntLittle(u32, c[4..8]), - mem.readIntLittle(u32, c[8..12]), - mem.readIntLittle(u32, c[12..16]), - }; - return BlockVec{ - constant_le, - Lane{ key[0], key[1], key[2], key[3] }, - Lane{ key[4], key[5], key[6], key[7] }, - Lane{ d[0], d[1], d[2], d[3] }, - }; - } +fn ChaChaVecImpl(comptime rounds_nb: usize) type { + return struct { + const Lane = Vector(4, u32); + const BlockVec = [4]Lane; + + fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime Lane{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + constant_le, + Lane{ key[0], key[1], key[2], key[3] }, + Lane{ key[4], key[5], key[6], key[7] }, + Lane{ d[0], d[1], d[2], d[3] }, + }; + } - fn chacha20Core(x: *BlockVec, input: BlockVec) callconv(.Inline) void { - x.* = input; - - var r: usize = 0; - while (r < 20) : (r += 2) { - x[0] +%= x[1]; - x[3] ^= x[0]; - x[3] = math.rotl(Lane, x[3], 16); - - x[2] +%= x[3]; - x[1] ^= x[2]; - x[1] = math.rotl(Lane, x[1], 12); - - x[0] +%= x[1]; - x[3] ^= x[0]; - x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 3, 0, 1, 2 }); - x[3] = math.rotl(Lane, x[3], 8); - - x[2] +%= x[3]; - x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 }); - x[1] ^= x[2]; - x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 1, 2, 3, 0 }); - x[1] = math.rotl(Lane, x[1], 7); - - x[0] +%= x[1]; - x[3] ^= x[0]; - x[3] = math.rotl(Lane, x[3], 16); - - x[2] +%= x[3]; - x[1] ^= x[2]; - x[1] = math.rotl(Lane, x[1], 12); - - x[0] +%= x[1]; - x[3] ^= x[0]; - x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 1, 2, 3, 0 }); - x[3] = math.rotl(Lane, x[3], 8); - - x[2] +%= x[3]; - x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 }); - x[1] ^= x[2]; - x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 3, 0, 1, 2 }); - x[1] = math.rotl(Lane, x[1], 7); + fn chacha20Core(x: *BlockVec, input: BlockVec) callconv(.Inline) void { + x.* = input; + + var r: usize = 0; + while (r < rounds_nb) : (r += 2) { + x[0] +%= x[1]; + x[3] ^= x[0]; + x[3] = math.rotl(Lane, x[3], 16); + + x[2] +%= x[3]; + x[1] ^= x[2]; + x[1] = math.rotl(Lane, x[1], 12); + + x[0] +%= x[1]; + x[3] ^= x[0]; + x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 3, 0, 1, 2 }); + x[3] = math.rotl(Lane, x[3], 8); + + x[2] +%= x[3]; + x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 }); + x[1] ^= x[2]; + x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 1, 2, 3, 0 }); + x[1] = math.rotl(Lane, x[1], 7); + + x[0] +%= x[1]; + x[3] ^= x[0]; + x[3] = math.rotl(Lane, x[3], 16); + + x[2] +%= x[3]; + x[1] ^= x[2]; + x[1] = math.rotl(Lane, x[1], 12); + + x[0] +%= x[1]; + x[3] ^= x[0]; + x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 1, 2, 3, 0 }); + x[3] = math.rotl(Lane, x[3], 8); + + x[2] +%= x[3]; + x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 }); + x[1] ^= x[2]; + x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 3, 0, 1, 2 }); + x[1] = math.rotl(Lane, x[1], 7); + } } - } - fn hashToBytes(out: *[64]u8, x: BlockVec) callconv(.Inline) void { - var i: usize = 0; - while (i < 4) : (i += 1) { - mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]); - mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]); - mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]); - mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]); + fn hashToBytes(out: *[64]u8, x: BlockVec) callconv(.Inline) void { + var i: usize = 0; + while (i < 4) : (i += 1) { + mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]); + mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]); + mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]); + mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]); + } } - } - fn contextFeedback(x: *BlockVec, ctx: BlockVec) callconv(.Inline) void { - x[0] +%= ctx[0]; - x[1] +%= ctx[1]; - x[2] +%= ctx[2]; - x[3] +%= ctx[3]; - } + fn contextFeedback(x: *BlockVec, ctx: BlockVec) callconv(.Inline) void { + x[0] +%= ctx[0]; + x[1] +%= ctx[1]; + x[2] +%= ctx[2]; + x[3] +%= ctx[3]; + } - fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { - var ctx = initContext(key, counter); - var x: BlockVec = undefined; - var buf: [64]u8 = undefined; - var i: usize = 0; - while (i + 64 <= in.len) : (i += 64) { - chacha20Core(x[0..], ctx); - contextFeedback(&x, ctx); - hashToBytes(buf[0..], x); - - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < 64) : (j += 1) { - xout[j] = xin[j]; - } - j = 0; - while (j < 64) : (j += 1) { - xout[j] ^= buf[j]; + fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { + var ctx = initContext(key, counter); + var x: BlockVec = undefined; + var buf: [64]u8 = undefined; + var i: usize = 0; + while (i + 64 <= in.len) : (i += 64) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < 64) : (j += 1) { + xout[j] = xin[j]; + } + j = 0; + while (j < 64) : (j += 1) { + xout[j] ^= buf[j]; + } + ctx[3][0] += 1; } - ctx[3][0] += 1; - } - if (i < in.len) { - chacha20Core(x[0..], ctx); - contextFeedback(&x, ctx); - hashToBytes(buf[0..], x); - - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < in.len % 64) : (j += 1) { - xout[j] = xin[j] ^ buf[j]; + if (i < in.len) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < in.len % 64) : (j += 1) { + xout[j] = xin[j] ^ buf[j]; + } } } - } - fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { - var c: [4]u32 = undefined; - for (c) |_, i| { - c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { + var c: [4]u32 = undefined; + for (c) |_, i| { + c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + } + const ctx = initContext(keyToWords(key), c); + var x: BlockVec = undefined; + chacha20Core(x[0..], ctx); + var out: [32]u8 = undefined; + mem.writeIntLittle(u32, out[0..4], x[0][0]); + mem.writeIntLittle(u32, out[4..8], x[0][1]); + mem.writeIntLittle(u32, out[8..12], x[0][2]); + mem.writeIntLittle(u32, out[12..16], x[0][3]); + mem.writeIntLittle(u32, out[16..20], x[3][0]); + mem.writeIntLittle(u32, out[20..24], x[3][1]); + mem.writeIntLittle(u32, out[24..28], x[3][2]); + mem.writeIntLittle(u32, out[28..32], x[3][3]); + return out; } - const ctx = initContext(keyToWords(key), c); - var x: BlockVec = undefined; - chacha20Core(x[0..], ctx); - var out: [32]u8 = undefined; - mem.writeIntLittle(u32, out[0..4], x[0][0]); - mem.writeIntLittle(u32, out[4..8], x[0][1]); - mem.writeIntLittle(u32, out[8..12], x[0][2]); - mem.writeIntLittle(u32, out[12..16], x[0][3]); - mem.writeIntLittle(u32, out[16..20], x[3][0]); - mem.writeIntLittle(u32, out[20..24], x[3][1]); - mem.writeIntLittle(u32, out[24..28], x[3][2]); - mem.writeIntLittle(u32, out[28..32], x[3][3]); - return out; - } -}; + }; +} // Non-vectorized implementation of the core function -const ChaCha20NonVecImpl = struct { - const BlockVec = [16]u32; - - fn initContext(key: [8]u32, d: [4]u32) BlockVec { - const c = "expand 32-byte k"; - const constant_le = comptime [4]u32{ - mem.readIntLittle(u32, c[0..4]), - mem.readIntLittle(u32, c[4..8]), - mem.readIntLittle(u32, c[8..12]), - mem.readIntLittle(u32, c[12..16]), - }; - return BlockVec{ - constant_le[0], constant_le[1], constant_le[2], constant_le[3], - key[0], key[1], key[2], key[3], - key[4], key[5], key[6], key[7], - d[0], d[1], d[2], d[3], - }; - } - - const QuarterRound = struct { - a: usize, - b: usize, - c: usize, - d: usize, - }; +fn ChaChaNonVecImpl(comptime rounds_nb: usize) type { + return struct { + const BlockVec = [16]u32; + + fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime [4]u32{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + constant_le[0], constant_le[1], constant_le[2], constant_le[3], + key[0], key[1], key[2], key[3], + key[4], key[5], key[6], key[7], + d[0], d[1], d[2], d[3], + }; + } - fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound { - return QuarterRound{ - .a = a, - .b = b, - .c = c, - .d = d, + const QuarterRound = struct { + a: usize, + b: usize, + c: usize, + d: usize, }; - } - fn chacha20Core(x: *BlockVec, input: BlockVec) callconv(.Inline) void { - x.* = input; - - const rounds = comptime [_]QuarterRound{ - Rp(0, 4, 8, 12), - Rp(1, 5, 9, 13), - Rp(2, 6, 10, 14), - Rp(3, 7, 11, 15), - Rp(0, 5, 10, 15), - Rp(1, 6, 11, 12), - Rp(2, 7, 8, 13), - Rp(3, 4, 9, 14), - }; + fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound { + return QuarterRound{ + .a = a, + .b = b, + .c = c, + .d = d, + }; + } - comptime var j: usize = 0; - inline while (j < 20) : (j += 2) { - inline for (rounds) |r| { - x[r.a] +%= x[r.b]; - x[r.d] = math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16)); - x[r.c] +%= x[r.d]; - x[r.b] = math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12)); - x[r.a] +%= x[r.b]; - x[r.d] = math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8)); - x[r.c] +%= x[r.d]; - x[r.b] = math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7)); + fn chacha20Core(x: *BlockVec, input: BlockVec) callconv(.Inline) void { + x.* = input; + + const rounds = comptime [_]QuarterRound{ + Rp(0, 4, 8, 12), + Rp(1, 5, 9, 13), + Rp(2, 6, 10, 14), + Rp(3, 7, 11, 15), + Rp(0, 5, 10, 15), + Rp(1, 6, 11, 12), + Rp(2, 7, 8, 13), + Rp(3, 4, 9, 14), + }; + + comptime var j: usize = 0; + inline while (j < rounds_nb) : (j += 2) { + inline for (rounds) |r| { + x[r.a] +%= x[r.b]; + x[r.d] = math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16)); + x[r.c] +%= x[r.d]; + x[r.b] = math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12)); + x[r.a] +%= x[r.b]; + x[r.d] = math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8)); + x[r.c] +%= x[r.d]; + x[r.b] = math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7)); + } } } - } - fn hashToBytes(out: *[64]u8, x: BlockVec) callconv(.Inline) void { - var i: usize = 0; - while (i < 4) : (i += 1) { - mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]); - mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]); - mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]); - mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i * 4 + 3]); + fn hashToBytes(out: *[64]u8, x: BlockVec) callconv(.Inline) void { + var i: usize = 0; + while (i < 4) : (i += 1) { + mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]); + mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]); + mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]); + mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i * 4 + 3]); + } } - } - fn contextFeedback(x: *BlockVec, ctx: BlockVec) callconv(.Inline) void { - var i: usize = 0; - while (i < 16) : (i += 1) { - x[i] +%= ctx[i]; + fn contextFeedback(x: *BlockVec, ctx: BlockVec) callconv(.Inline) void { + var i: usize = 0; + while (i < 16) : (i += 1) { + x[i] +%= ctx[i]; + } } - } - fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { - var ctx = initContext(key, counter); - var x: BlockVec = undefined; - var buf: [64]u8 = undefined; - var i: usize = 0; - while (i + 64 <= in.len) : (i += 64) { - chacha20Core(x[0..], ctx); - contextFeedback(&x, ctx); - hashToBytes(buf[0..], x); - - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < 64) : (j += 1) { - xout[j] = xin[j]; - } - j = 0; - while (j < 64) : (j += 1) { - xout[j] ^= buf[j]; + fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { + var ctx = initContext(key, counter); + var x: BlockVec = undefined; + var buf: [64]u8 = undefined; + var i: usize = 0; + while (i + 64 <= in.len) : (i += 64) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < 64) : (j += 1) { + xout[j] = xin[j]; + } + j = 0; + while (j < 64) : (j += 1) { + xout[j] ^= buf[j]; + } + ctx[12] += 1; } - ctx[12] += 1; - } - if (i < in.len) { - chacha20Core(x[0..], ctx); - contextFeedback(&x, ctx); - hashToBytes(buf[0..], x); - - var xout = out[i..]; - const xin = in[i..]; - var j: usize = 0; - while (j < in.len % 64) : (j += 1) { - xout[j] = xin[j] ^ buf[j]; + if (i < in.len) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < in.len % 64) : (j += 1) { + xout[j] = xin[j] ^ buf[j]; + } } } - } - fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { - var c: [4]u32 = undefined; - for (c) |_, i| { - c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { + var c: [4]u32 = undefined; + for (c) |_, i| { + c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + } + const ctx = initContext(keyToWords(key), c); + var x: BlockVec = undefined; + chacha20Core(x[0..], ctx); + var out: [32]u8 = undefined; + mem.writeIntLittle(u32, out[0..4], x[0]); + mem.writeIntLittle(u32, out[4..8], x[1]); + mem.writeIntLittle(u32, out[8..12], x[2]); + mem.writeIntLittle(u32, out[12..16], x[3]); + mem.writeIntLittle(u32, out[16..20], x[12]); + mem.writeIntLittle(u32, out[20..24], x[13]); + mem.writeIntLittle(u32, out[24..28], x[14]); + mem.writeIntLittle(u32, out[28..32], x[15]); + return out; } - const ctx = initContext(keyToWords(key), c); - var x: BlockVec = undefined; - chacha20Core(x[0..], ctx); - var out: [32]u8 = undefined; - mem.writeIntLittle(u32, out[0..4], x[0]); - mem.writeIntLittle(u32, out[4..8], x[1]); - mem.writeIntLittle(u32, out[8..12], x[2]); - mem.writeIntLittle(u32, out[12..16], x[3]); - mem.writeIntLittle(u32, out[16..20], x[12]); - mem.writeIntLittle(u32, out[20..24], x[13]); - mem.writeIntLittle(u32, out[24..28], x[14]); - mem.writeIntLittle(u32, out[28..32], x[15]); - return out; - } -}; + }; +} -const ChaCha20Impl = if (std.Target.current.cpu.arch == .x86_64) ChaCha20VecImpl else ChaCha20NonVecImpl; +fn ChaChaImpl(comptime rounds_nb: usize) type { + return if (std.Target.current.cpu.arch == .x86_64) ChaChaVecImpl(rounds_nb) else ChaChaNonVecImpl(rounds_nb); +} fn keyToWords(key: [32]u8) [8]u32 { var k: [8]u32 = undefined; @@ -304,68 +376,239 @@ fn keyToWords(key: [32]u8) [8]u32 { return k; } -/// ChaCha20 avoids the possibility of timing attacks, as there are no branches -/// on secret key data. -/// -/// in and out should be the same length. -/// counter should generally be 0 or 1 -/// -/// ChaCha20 is self-reversing. To decrypt just run the cipher with the same -/// counter, nonce, and key. -pub const ChaCha20IETF = struct { - pub fn xor(out: []u8, in: []const u8, counter: u32, key: [32]u8, nonce: [12]u8) void { - assert(in.len == out.len); - assert((in.len >> 6) + counter <= maxInt(u32)); - - var c: [4]u32 = undefined; - c[0] = counter; - c[1] = mem.readIntLittle(u32, nonce[0..4]); - c[2] = mem.readIntLittle(u32, nonce[4..8]); - c[3] = mem.readIntLittle(u32, nonce[8..12]); - ChaCha20Impl.chacha20Xor(out, in, keyToWords(key), c); - } -}; - -/// This is the original ChaCha20 before RFC 7539, which recommends using the -/// orgininal version on applications such as disk or file encryption that might -/// exceed the 256 GiB limit of the 96-bit nonce version. -pub const ChaCha20With64BitNonce = struct { - pub fn xor(out: []u8, in: []const u8, counter: u64, key: [32]u8, nonce: [8]u8) void { - assert(in.len == out.len); - assert(counter +% (in.len >> 6) >= counter); - - var cursor: usize = 0; - const k = keyToWords(key); - var c: [4]u32 = undefined; - c[0] = @truncate(u32, counter); - c[1] = @truncate(u32, counter >> 32); - c[2] = mem.readIntLittle(u32, nonce[0..4]); - c[3] = mem.readIntLittle(u32, nonce[4..8]); - - const block_length = (1 << 6); - // The full block size is greater than the address space on a 32bit machine - const big_block = if (@sizeOf(usize) > 4) (block_length << 32) else maxInt(usize); - - // first partial big block - if (((@intCast(u64, maxInt(u32) - @truncate(u32, counter)) + 1) << 6) < in.len) { - ChaCha20Impl.chacha20Xor(out[cursor..big_block], in[cursor..big_block], k, c); - cursor = big_block - cursor; - c[1] += 1; - if (comptime @sizeOf(usize) > 4) { - // A big block is giant: 256 GiB, but we can avoid this limitation - var remaining_blocks: u32 = @intCast(u32, (in.len / big_block)); - var i: u32 = 0; - while (remaining_blocks > 0) : (remaining_blocks -= 1) { - ChaCha20Impl.chacha20Xor(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c); - c[1] += 1; // upper 32-bit of counter, generic chacha20Xor() doesn't know about this. - cursor += big_block; +fn extend(key: [32]u8, nonce: [24]u8, comptime rounds_nb: usize) struct { key: [32]u8, nonce: [12]u8 } { + var subnonce: [12]u8 = undefined; + mem.set(u8, subnonce[0..4], 0); + mem.copy(u8, subnonce[4..], nonce[16..24]); + return .{ + .key = ChaChaImpl(rounds_nb).hchacha20(nonce[0..16].*, key), + .nonce = subnonce, + }; +} + +fn ChaChaIETF(comptime rounds_nb: usize) type { + return struct { + /// Nonce length in bytes. + pub const nonce_length = 12; + /// Key length in bytes. + pub const key_length = 32; + + /// Add the output of the ChaCha20 stream cipher to `in` and stores the result into `out`. + /// WARNING: This function doesn't provide authenticated encryption. + /// Using the AEAD or one of the `box` versions is usually preferred. + pub fn xor(out: []u8, in: []const u8, counter: u32, key: [key_length]u8, nonce: [nonce_length]u8) void { + assert(in.len == out.len); + assert(in.len / 64 <= (1 << 32 - 1) - counter); + + var d: [4]u32 = undefined; + d[0] = counter; + d[1] = mem.readIntLittle(u32, nonce[0..4]); + d[2] = mem.readIntLittle(u32, nonce[4..8]); + d[3] = mem.readIntLittle(u32, nonce[8..12]); + ChaChaImpl(rounds_nb).chacha20Xor(out, in, keyToWords(key), d); + } + }; +} + +fn ChaChaWith64BitNonce(comptime rounds_nb: usize) type { + return struct { + /// Nonce length in bytes. + pub const nonce_length = 8; + /// Key length in bytes. + pub const key_length = 32; + + /// Add the output of the ChaCha20 stream cipher to `in` and stores the result into `out`. + /// WARNING: This function doesn't provide authenticated encryption. + /// Using the AEAD or one of the `box` versions is usually preferred. + pub fn xor(out: []u8, in: []const u8, counter: u64, key: [key_length]u8, nonce: [nonce_length]u8) void { + assert(in.len == out.len); + assert(in.len / 64 <= (1 << 64 - 1) - counter); + + var cursor: usize = 0; + const k = keyToWords(key); + var c: [4]u32 = undefined; + c[0] = @truncate(u32, counter); + c[1] = @truncate(u32, counter >> 32); + c[2] = mem.readIntLittle(u32, nonce[0..4]); + c[3] = mem.readIntLittle(u32, nonce[4..8]); + + const block_length = (1 << 6); + // The full block size is greater than the address space on a 32bit machine + const big_block = if (@sizeOf(usize) > 4) (block_length << 32) else maxInt(usize); + + // first partial big block + if (((@intCast(u64, maxInt(u32) - @truncate(u32, counter)) + 1) << 6) < in.len) { + ChaChaImpl(rounds_nb).chacha20Xor(out[cursor..big_block], in[cursor..big_block], k, c); + cursor = big_block - cursor; + c[1] += 1; + if (comptime @sizeOf(usize) > 4) { + // A big block is giant: 256 GiB, but we can avoid this limitation + var remaining_blocks: u32 = @intCast(u32, (in.len / big_block)); + var i: u32 = 0; + while (remaining_blocks > 0) : (remaining_blocks -= 1) { + ChaChaImpl(rounds_nb).chacha20Xor(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c); + c[1] += 1; // upper 32-bit of counter, generic chacha20Xor() doesn't know about this. + cursor += big_block; + } } } + ChaChaImpl(rounds_nb).chacha20Xor(out[cursor..], in[cursor..], k, c); + } + }; +} + +fn XChaChaIETF(comptime rounds_nb: usize) type { + return struct { + /// Nonce length in bytes. + pub const nonce_length = 24; + /// Key length in bytes. + pub const key_length = 32; + + /// Add the output of the XChaCha20 stream cipher to `in` and stores the result into `out`. + /// WARNING: This function doesn't provide authenticated encryption. + /// Using the AEAD or one of the `box` versions is usually preferred. + pub fn xor(out: []u8, in: []const u8, counter: u32, key: [key_length]u8, nonce: [nonce_length]u8) void { + const extended = extend(key, nonce, rounds_nb); + ChaChaIETF(rounds_nb).xor(out, in, counter, extended.key, extended.nonce); + } + }; +} + +fn ChaChaPoly1305(comptime rounds_nb: usize) type { + return struct { + pub const tag_length = 16; + pub const nonce_length = 12; + pub const key_length = 32; + + /// c: ciphertext: output buffer should be of size m.len + /// tag: authentication tag: output MAC + /// m: message + /// ad: Associated Data + /// npub: public nonce + /// k: private key + pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void { + assert(c.len == m.len); + + var polyKey = [_]u8{0} ** 32; + ChaChaIETF(rounds_nb).xor(polyKey[0..], polyKey[0..], 0, k, npub); + + ChaChaIETF(rounds_nb).xor(c[0..m.len], m, 1, k, npub); + + var mac = Poly1305.init(polyKey[0..]); + mac.update(ad); + if (ad.len % 16 != 0) { + const zeros = [_]u8{0} ** 16; + const padding = 16 - (ad.len % 16); + mac.update(zeros[0..padding]); + } + mac.update(c[0..m.len]); + if (m.len % 16 != 0) { + const zeros = [_]u8{0} ** 16; + const padding = 16 - (m.len % 16); + mac.update(zeros[0..padding]); + } + var lens: [16]u8 = undefined; + mem.writeIntLittle(u64, lens[0..8], ad.len); + mem.writeIntLittle(u64, lens[8..16], m.len); + mac.update(lens[0..]); + mac.final(tag); + } + + /// m: message: output buffer should be of size c.len + /// c: ciphertext + /// tag: authentication tag + /// ad: Associated Data + /// npub: public nonce + /// k: private key + /// NOTE: the check of the authentication tag is currently not done in constant time + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) Error!void { + assert(c.len == m.len); + + var polyKey = [_]u8{0} ** 32; + ChaChaIETF(rounds_nb).xor(polyKey[0..], polyKey[0..], 0, k, npub); + + var mac = Poly1305.init(polyKey[0..]); + + mac.update(ad); + if (ad.len % 16 != 0) { + const zeros = [_]u8{0} ** 16; + const padding = 16 - (ad.len % 16); + mac.update(zeros[0..padding]); + } + mac.update(c); + if (c.len % 16 != 0) { + const zeros = [_]u8{0} ** 16; + const padding = 16 - (c.len % 16); + mac.update(zeros[0..padding]); + } + var lens: [16]u8 = undefined; + mem.writeIntLittle(u64, lens[0..8], ad.len); + mem.writeIntLittle(u64, lens[8..16], c.len); + mac.update(lens[0..]); + var computedTag: [16]u8 = undefined; + mac.final(computedTag[0..]); + + var acc: u8 = 0; + for (computedTag) |_, i| { + acc |= computedTag[i] ^ tag[i]; + } + if (acc != 0) { + return error.AuthenticationFailed; + } + ChaChaIETF(rounds_nb).xor(m[0..c.len], c, 1, k, npub); + } + }; +} + +fn XChaChaPoly1305(comptime rounds_nb: usize) type { + return struct { + pub const tag_length = 16; + pub const nonce_length = 24; + pub const key_length = 32; + + /// c: ciphertext: output buffer should be of size m.len + /// tag: authentication tag: output MAC + /// m: message + /// ad: Associated Data + /// npub: public nonce + /// k: private key + pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void { + const extended = extend(k, npub, rounds_nb); + return ChaChaPoly1305(rounds_nb).encrypt(c, tag, m, ad, extended.nonce, extended.key); } - ChaCha20Impl.chacha20Xor(out[cursor..], in[cursor..], k, c); + /// m: message: output buffer should be of size c.len + /// c: ciphertext + /// tag: authentication tag + /// ad: Associated Data + /// npub: public nonce + /// k: private key + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) Error!void { + const extended = extend(k, npub, rounds_nb); + return ChaChaPoly1305(rounds_nb).decrypt(m, c, tag, ad, extended.nonce, extended.key); + } + }; +} + +test "chacha20 AEAD API" { + const aeads = [_]type{ ChaCha20Poly1305, XChaCha20Poly1305 }; + const m = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; + const ad = "Additional data"; + + inline for (aeads) |aead| { + const key = [_]u8{69} ** aead.key_length; + const nonce = [_]u8{42} ** aead.nonce_length; + var c: [m.len]u8 = undefined; + var tag: [aead.tag_length]u8 = undefined; + var out: [m.len]u8 = undefined; + + aead.encrypt(c[0..], tag[0..], m, ad, nonce, key); + try aead.decrypt(out[0..], c[0..], tag, ad[0..], nonce, key); + testing.expectEqualSlices(u8, out[0..], m); + c[0] += 1; + testing.expectError(error.AuthenticationFailed, aead.decrypt(out[0..], c[0..], tag, ad[0..], nonce, key)); } -}; +} // https://tools.ietf.org/html/rfc7539#section-2.4.2 test "crypto.chacha20 test vector sunscreen" { @@ -386,7 +629,7 @@ test "crypto.chacha20 test vector sunscreen" { 0xb4, 0x0b, 0x8e, 0xed, 0xf2, 0x78, 0x5e, 0x42, 0x87, 0x4d, }; - const input = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; + const m = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; var result: [114]u8 = undefined; const key = [_]u8{ 0, 1, 2, 3, 4, 5, 6, 7, @@ -400,13 +643,12 @@ test "crypto.chacha20 test vector sunscreen" { 0, 0, 0, 0, }; - ChaCha20IETF.xor(result[0..], input[0..], 1, key, nonce); + ChaCha20IETF.xor(result[0..], m[0..], 1, key, nonce); testing.expectEqualSlices(u8, &expected_result, &result); - // Chacha20 is self-reversing. - var plaintext: [114]u8 = undefined; - ChaCha20IETF.xor(plaintext[0..], result[0..], 1, key, nonce); - testing.expect(mem.order(u8, input, &plaintext) == .eq); + var m2: [114]u8 = undefined; + ChaCha20IETF.xor(m2[0..], result[0..], 1, key, nonce); + testing.expect(mem.order(u8, m, &m2) == .eq); } // https://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04#section-7 @@ -421,7 +663,7 @@ test "crypto.chacha20 test vector 1" { 0x6a, 0x43, 0xb8, 0xf4, 0x15, 0x18, 0xa1, 0x1c, 0xc3, 0x87, 0xb6, 0x69, 0xb2, 0xee, 0x65, 0x86, }; - const input = [_]u8{ + const m = [_]u8{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -440,7 +682,7 @@ test "crypto.chacha20 test vector 1" { }; const nonce = [_]u8{ 0, 0, 0, 0, 0, 0, 0, 0 }; - ChaCha20With64BitNonce.xor(result[0..], input[0..], 0, key, nonce); + ChaCha20With64BitNonce.xor(result[0..], m[0..], 0, key, nonce); testing.expectEqualSlices(u8, &expected_result, &result); } @@ -455,7 +697,7 @@ test "crypto.chacha20 test vector 2" { 0x53, 0xd7, 0x92, 0xb1, 0xc4, 0x3f, 0xea, 0x81, 0x7e, 0x9a, 0xd2, 0x75, 0xae, 0x54, 0x69, 0x63, }; - const input = [_]u8{ + const m = [_]u8{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -474,7 +716,7 @@ test "crypto.chacha20 test vector 2" { }; const nonce = [_]u8{ 0, 0, 0, 0, 0, 0, 0, 0 }; - ChaCha20With64BitNonce.xor(result[0..], input[0..], 0, key, nonce); + ChaCha20With64BitNonce.xor(result[0..], m[0..], 0, key, nonce); testing.expectEqualSlices(u8, &expected_result, &result); } @@ -489,7 +731,7 @@ test "crypto.chacha20 test vector 3" { 0x52, 0x77, 0x06, 0x2e, 0xb7, 0xa0, 0x43, 0x3e, 0x44, 0x5f, 0x41, 0xe3, }; - const input = [_]u8{ + const m = [_]u8{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -508,7 +750,7 @@ test "crypto.chacha20 test vector 3" { }; const nonce = [_]u8{ 0, 0, 0, 0, 0, 0, 0, 1 }; - ChaCha20With64BitNonce.xor(result[0..], input[0..], 0, key, nonce); + ChaCha20With64BitNonce.xor(result[0..], m[0..], 0, key, nonce); testing.expectEqualSlices(u8, &expected_result, &result); } @@ -523,7 +765,7 @@ test "crypto.chacha20 test vector 4" { 0x5d, 0xdc, 0x49, 0x7a, 0x0b, 0x46, 0x6e, 0x7d, 0x6b, 0xbd, 0xb0, 0x04, 0x1b, 0x2f, 0x58, 0x6b, }; - const input = [_]u8{ + const m = [_]u8{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -542,7 +784,7 @@ test "crypto.chacha20 test vector 4" { }; const nonce = [_]u8{ 1, 0, 0, 0, 0, 0, 0, 0 }; - ChaCha20With64BitNonce.xor(result[0..], input[0..], 0, key, nonce); + ChaCha20With64BitNonce.xor(result[0..], m[0..], 0, key, nonce); testing.expectEqualSlices(u8, &expected_result, &result); } @@ -584,7 +826,7 @@ test "crypto.chacha20 test vector 5" { 0x87, 0x46, 0xd4, 0x52, 0x4d, 0x38, 0x40, 0x7a, 0x6d, 0xeb, 0x3a, 0xb7, 0x8f, 0xab, 0x78, 0xc9, }; - const input = [_]u8{ + const m = [_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -614,147 +856,14 @@ test "crypto.chacha20 test vector 5" { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, }; - ChaCha20With64BitNonce.xor(result[0..], input[0..], 0, key, nonce); + ChaCha20With64BitNonce.xor(result[0..], m[0..], 0, key, nonce); testing.expectEqualSlices(u8, &expected_result, &result); } -pub const chacha20poly1305_tag_length = 16; - -fn chacha20poly1305SealDetached(ciphertext: []u8, tag: *[chacha20poly1305_tag_length]u8, plaintext: []const u8, data: []const u8, key: [32]u8, nonce: [12]u8) void { - assert(ciphertext.len == plaintext.len); - - // derive poly1305 key - var polyKey = [_]u8{0} ** 32; - ChaCha20IETF.xor(polyKey[0..], polyKey[0..], 0, key, nonce); - - // encrypt plaintext - ChaCha20IETF.xor(ciphertext[0..plaintext.len], plaintext, 1, key, nonce); - - // construct mac - var mac = Poly1305.init(polyKey[0..]); - mac.update(data); - if (data.len % 16 != 0) { - const zeros = [_]u8{0} ** 16; - const padding = 16 - (data.len % 16); - mac.update(zeros[0..padding]); - } - mac.update(ciphertext[0..plaintext.len]); - if (plaintext.len % 16 != 0) { - const zeros = [_]u8{0} ** 16; - const padding = 16 - (plaintext.len % 16); - mac.update(zeros[0..padding]); - } - var lens: [16]u8 = undefined; - mem.writeIntLittle(u64, lens[0..8], data.len); - mem.writeIntLittle(u64, lens[8..16], plaintext.len); - mac.update(lens[0..]); - mac.final(tag); -} - -fn chacha20poly1305Seal(ciphertextAndTag: []u8, plaintext: []const u8, data: []const u8, key: [32]u8, nonce: [12]u8) void { - return chacha20poly1305SealDetached(ciphertextAndTag[0..plaintext.len], ciphertextAndTag[plaintext.len..][0..chacha20poly1305_tag_length], plaintext, data, key, nonce); -} - -/// Verifies and decrypts an authenticated message produced by chacha20poly1305SealDetached. -fn chacha20poly1305OpenDetached(dst: []u8, ciphertext: []const u8, tag: *const [chacha20poly1305_tag_length]u8, data: []const u8, key: [32]u8, nonce: [12]u8) !void { - // split ciphertext and tag - assert(dst.len == ciphertext.len); - - // derive poly1305 key - var polyKey = [_]u8{0} ** 32; - ChaCha20IETF.xor(polyKey[0..], polyKey[0..], 0, key, nonce); - - // construct mac - var mac = Poly1305.init(polyKey[0..]); - - mac.update(data); - if (data.len % 16 != 0) { - const zeros = [_]u8{0} ** 16; - const padding = 16 - (data.len % 16); - mac.update(zeros[0..padding]); - } - mac.update(ciphertext); - if (ciphertext.len % 16 != 0) { - const zeros = [_]u8{0} ** 16; - const padding = 16 - (ciphertext.len % 16); - mac.update(zeros[0..padding]); - } - var lens: [16]u8 = undefined; - mem.writeIntLittle(u64, lens[0..8], data.len); - mem.writeIntLittle(u64, lens[8..16], ciphertext.len); - mac.update(lens[0..]); - var computedTag: [16]u8 = undefined; - mac.final(computedTag[0..]); - - // verify mac in constant time - // TODO: we can't currently guarantee that this will run in constant time. - // See https://github.com/ziglang/zig/issues/1776 - var acc: u8 = 0; - for (computedTag) |_, i| { - acc |= computedTag[i] ^ tag[i]; - } - if (acc != 0) { - return error.AuthenticationFailed; - } - - // decrypt ciphertext - ChaCha20IETF.xor(dst[0..ciphertext.len], ciphertext, 1, key, nonce); -} - -/// Verifies and decrypts an authenticated message produced by chacha20poly1305Seal. -fn chacha20poly1305Open(dst: []u8, ciphertextAndTag: []const u8, data: []const u8, key: [32]u8, nonce: [12]u8) !void { - if (ciphertextAndTag.len < chacha20poly1305_tag_length) { - return error.InvalidMessage; - } - const ciphertextLen = ciphertextAndTag.len - chacha20poly1305_tag_length; - return try chacha20poly1305OpenDetached(dst, ciphertextAndTag[0..ciphertextLen], ciphertextAndTag[ciphertextLen..][0..chacha20poly1305_tag_length], data, key, nonce); -} - -fn extend(key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [12]u8 } { - var subnonce: [12]u8 = undefined; - mem.set(u8, subnonce[0..4], 0); - mem.copy(u8, subnonce[4..], nonce[16..24]); - return .{ - .key = ChaCha20Impl.hchacha20(nonce[0..16].*, key), - .nonce = subnonce, - }; -} - -pub const XChaCha20IETF = struct { - pub fn xor(out: []u8, in: []const u8, counter: u32, key: [32]u8, nonce: [24]u8) void { - const extended = extend(key, nonce); - ChaCha20IETF.xor(out, in, counter, extended.key, extended.nonce); - } -}; - -pub const xchacha20poly1305_tag_length = 16; - -fn xchacha20poly1305SealDetached(ciphertext: []u8, tag: *[chacha20poly1305_tag_length]u8, plaintext: []const u8, data: []const u8, key: [32]u8, nonce: [24]u8) void { - const extended = extend(key, nonce); - return chacha20poly1305SealDetached(ciphertext, tag, plaintext, data, extended.key, extended.nonce); -} - -fn xchacha20poly1305Seal(ciphertextAndTag: []u8, plaintext: []const u8, data: []const u8, key: [32]u8, nonce: [24]u8) void { - const extended = extend(key, nonce); - return chacha20poly1305Seal(ciphertextAndTag, plaintext, data, extended.key, extended.nonce); -} - -/// Verifies and decrypts an authenticated message produced by xchacha20poly1305SealDetached. -fn xchacha20poly1305OpenDetached(plaintext: []u8, ciphertext: []const u8, tag: *const [chacha20poly1305_tag_length]u8, data: []const u8, key: [32]u8, nonce: [24]u8) !void { - const extended = extend(key, nonce); - return try chacha20poly1305OpenDetached(plaintext, ciphertext, tag, data, extended.key, extended.nonce); -} - -/// Verifies and decrypts an authenticated message produced by xchacha20poly1305Seal. -fn xchacha20poly1305Open(ciphertextAndTag: []u8, msgAndTag: []const u8, data: []const u8, key: [32]u8, nonce: [24]u8) !void { - const extended = extend(key, nonce); - return try chacha20poly1305Open(ciphertextAndTag, msgAndTag, data, extended.key, extended.nonce); -} - test "seal" { { - const plaintext = ""; - const data = ""; + const m = ""; + const ad = ""; const key = [_]u8{ 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, @@ -763,11 +872,11 @@ test "seal" { const exp_out = [_]u8{ 0xa0, 0x78, 0x4d, 0x7a, 0x47, 0x16, 0xf3, 0xfe, 0xb4, 0xf6, 0x4e, 0x7f, 0x4b, 0x39, 0xbf, 0x4 }; var out: [exp_out.len]u8 = undefined; - chacha20poly1305Seal(out[0..], plaintext, data, key, nonce); + ChaCha20Poly1305.encrypt(out[0..m.len], out[m.len..], m, ad, nonce, key); testing.expectEqualSlices(u8, exp_out[0..], out[0..]); } { - const plaintext = [_]u8{ + const m = [_]u8{ 0x4c, 0x61, 0x64, 0x69, 0x65, 0x73, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x47, 0x65, 0x6e, 0x74, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x20, 0x6f, 0x66, 0x20, 0x74, 0x68, 0x65, 0x20, 0x63, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x6f, 0x66, 0x20, 0x27, 0x39, 0x39, 0x3a, 0x20, 0x49, 0x66, 0x20, 0x49, 0x20, 0x63, @@ -777,7 +886,7 @@ test "seal" { 0x63, 0x72, 0x65, 0x65, 0x6e, 0x20, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, 0x62, 0x65, 0x20, 0x69, 0x74, 0x2e, }; - const data = [_]u8{ 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7 }; + const ad = [_]u8{ 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7 }; const key = [_]u8{ 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, @@ -796,15 +905,15 @@ test "seal" { }; var out: [exp_out.len]u8 = undefined; - chacha20poly1305Seal(out[0..], plaintext[0..], data[0..], key, nonce); + ChaCha20Poly1305.encrypt(out[0..m.len], out[m.len..], m[0..], ad[0..], nonce, key); testing.expectEqualSlices(u8, exp_out[0..], out[0..]); } } test "open" { { - const ciphertext = [_]u8{ 0xa0, 0x78, 0x4d, 0x7a, 0x47, 0x16, 0xf3, 0xfe, 0xb4, 0xf6, 0x4e, 0x7f, 0x4b, 0x39, 0xbf, 0x4 }; - const data = ""; + const c = [_]u8{ 0xa0, 0x78, 0x4d, 0x7a, 0x47, 0x16, 0xf3, 0xfe, 0xb4, 0xf6, 0x4e, 0x7f, 0x4b, 0x39, 0xbf, 0x4 }; + const ad = ""; const key = [_]u8{ 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, @@ -813,11 +922,11 @@ test "open" { const exp_out = ""; var out: [exp_out.len]u8 = undefined; - try chacha20poly1305Open(out[0..], ciphertext[0..], data, key, nonce); + try ChaCha20Poly1305.decrypt(out[0..], c[0..exp_out.len], c[exp_out.len..].*, ad[0..], nonce, key); testing.expectEqualSlices(u8, exp_out[0..], out[0..]); } { - const ciphertext = [_]u8{ + const c = [_]u8{ 0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef, 0x7e, 0xc2, 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x8, 0xfe, 0xa9, 0xe2, 0xb5, 0xa7, 0x36, 0xee, 0x62, 0xd6, 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, 0x82, 0xfa, 0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, @@ -828,7 +937,7 @@ test "open" { 0x61, 0x16, 0x1a, 0xe1, 0xb, 0x59, 0x4f, 0x9, 0xe2, 0x6a, 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60, 0x6, 0x91, }; - const data = [_]u8{ 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7 }; + const ad = [_]u8{ 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7 }; const key = [_]u8{ 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, @@ -846,126 +955,45 @@ test "open" { }; var out: [exp_out.len]u8 = undefined; - try chacha20poly1305Open(out[0..], ciphertext[0..], data[0..], key, nonce); + try ChaCha20Poly1305.decrypt(out[0..], c[0..exp_out.len], c[exp_out.len..].*, ad[0..], nonce, key); testing.expectEqualSlices(u8, exp_out[0..], out[0..]); // corrupting the ciphertext, data, key, or nonce should cause a failure - var bad_ciphertext = ciphertext; - bad_ciphertext[0] ^= 1; - testing.expectError(error.AuthenticationFailed, chacha20poly1305Open(out[0..], bad_ciphertext[0..], data[0..], key, nonce)); - var bad_data = data; - bad_data[0] ^= 1; - testing.expectError(error.AuthenticationFailed, chacha20poly1305Open(out[0..], ciphertext[0..], bad_data[0..], key, nonce)); + var bad_c = c; + bad_c[0] ^= 1; + testing.expectError(error.AuthenticationFailed, ChaCha20Poly1305.decrypt(out[0..], bad_c[0..out.len], bad_c[out.len..].*, ad[0..], nonce, key)); + var bad_ad = ad; + bad_ad[0] ^= 1; + testing.expectError(error.AuthenticationFailed, ChaCha20Poly1305.decrypt(out[0..], c[0..out.len], c[out.len..].*, bad_ad[0..], nonce, key)); var bad_key = key; bad_key[0] ^= 1; - testing.expectError(error.AuthenticationFailed, chacha20poly1305Open(out[0..], ciphertext[0..], data[0..], bad_key, nonce)); + testing.expectError(error.AuthenticationFailed, ChaCha20Poly1305.decrypt(out[0..], c[0..out.len], c[out.len..].*, ad[0..], nonce, bad_key)); var bad_nonce = nonce; bad_nonce[0] ^= 1; - testing.expectError(error.AuthenticationFailed, chacha20poly1305Open(out[0..], ciphertext[0..], data[0..], key, bad_nonce)); - - // a short ciphertext should result in a different error - testing.expectError(error.InvalidMessage, chacha20poly1305Open(out[0..], "", data[0..], key, bad_nonce)); + testing.expectError(error.AuthenticationFailed, ChaCha20Poly1305.decrypt(out[0..], c[0..out.len], c[out.len..].*, ad[0..], bad_nonce, key)); } } test "crypto.xchacha20" { const key = [_]u8{69} ** 32; const nonce = [_]u8{42} ** 24; - const input = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; + const m = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; { - var ciphertext: [input.len]u8 = undefined; - XChaCha20IETF.xor(ciphertext[0..], input[0..], 0, key, nonce); - var buf: [2 * ciphertext.len]u8 = undefined; - testing.expectEqualStrings(try std.fmt.bufPrint(&buf, "{s}", .{std.fmt.fmtSliceHexUpper(&ciphertext)}), "E0A1BCF939654AFDBDC1746EC49832647C19D891F0D1A81FC0C1703B4514BDEA584B512F6908C2C5E9DD18D5CBC1805DE5803FE3B9CA5F193FB8359E91FAB0C3BB40309A292EB1CF49685C65C4A3ADF4F11DB0CD2B6B67FBC174BC2E860E8F769FD3565BBFAD1C845E05A0FED9BE167C240D"); + var c: [m.len]u8 = undefined; + XChaCha20IETF.xor(c[0..], m[0..], 0, key, nonce); + var buf: [2 * c.len]u8 = undefined; + testing.expectEqualStrings(try std.fmt.bufPrint(&buf, "{s}", .{std.fmt.fmtSliceHexUpper(&c)}), "E0A1BCF939654AFDBDC1746EC49832647C19D891F0D1A81FC0C1703B4514BDEA584B512F6908C2C5E9DD18D5CBC1805DE5803FE3B9CA5F193FB8359E91FAB0C3BB40309A292EB1CF49685C65C4A3ADF4F11DB0CD2B6B67FBC174BC2E860E8F769FD3565BBFAD1C845E05A0FED9BE167C240D"); } { - const data = "Additional data"; - var ciphertext: [input.len + xchacha20poly1305_tag_length]u8 = undefined; - xchacha20poly1305Seal(ciphertext[0..], input, data, key, nonce); - var out: [input.len]u8 = undefined; - try xchacha20poly1305Open(out[0..], ciphertext[0..], data, key, nonce); - var buf: [2 * ciphertext.len]u8 = undefined; - testing.expectEqualStrings(try std.fmt.bufPrint(&buf, "{s}", .{std.fmt.fmtSliceHexUpper(&ciphertext)}), "994D2DD32333F48E53650C02C7A2ABB8E018B0836D7175AEC779F52E961780768F815C58F1AA52D211498DB89B9216763F569C9433A6BBFCEFB4D4A49387A4C5207FBB3B5A92B5941294DF30588C6740D39DC16FA1F0E634F7246CF7CDCB978E44347D89381B7A74EB7084F754B90BDE9AAF5A94B8F2A85EFD0B50692AE2D425E234"); - testing.expectEqualSlices(u8, out[0..], input); - ciphertext[0] += 1; - testing.expectError(error.AuthenticationFailed, xchacha20poly1305Open(out[0..], ciphertext[0..], data, key, nonce)); - } -} - -pub const Chacha20Poly1305 = struct { - pub const tag_length = 16; - pub const nonce_length = 12; - pub const key_length = 32; - - /// c: ciphertext: output buffer should be of size m.len - /// tag: authentication tag: output MAC - /// m: message - /// ad: Associated Data - /// npub: public nonce - /// k: private key - pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void { - assert(c.len == m.len); - return chacha20poly1305SealDetached(c, tag, m, ad, k, npub); - } - - /// m: message: output buffer should be of size c.len - /// c: ciphertext - /// tag: authentication tag - /// ad: Associated Data - /// npub: public nonce - /// k: private key - /// NOTE: the check of the authentication tag is currently not done in constant time - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) !void { - assert(c.len == m.len); - return try chacha20poly1305OpenDetached(m, c, tag[0..], ad, k, npub); - } -}; - -pub const XChacha20Poly1305 = struct { - pub const tag_length = 16; - pub const nonce_length = 24; - pub const key_length = 32; - - /// c: ciphertext: output buffer should be of size m.len - /// tag: authentication tag: output MAC - /// m: message - /// ad: Associated Data - /// npub: public nonce - /// k: private key - pub fn encrypt(c: []u8, tag: *[tag_length]u8, m: []const u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) void { - assert(c.len == m.len); - return xchacha20poly1305SealDetached(c, tag, m, ad, k, npub); - } - - /// m: message: output buffer should be of size c.len - /// c: ciphertext - /// tag: authentication tag - /// ad: Associated Data - /// npub: public nonce - /// k: private key - /// NOTE: the check of the authentication tag is currently not done in constant time - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) !void { - assert(c.len == m.len); - return try xchacha20poly1305OpenDetached(m, c, tag[0..], ad, k, npub); - } -}; - -test "chacha20 AEAD API" { - const aeads = [_]type{ Chacha20Poly1305, XChacha20Poly1305 }; - const input = "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it."; - const data = "Additional data"; - - inline for (aeads) |aead| { - const key = [_]u8{69} ** aead.key_length; - const nonce = [_]u8{42} ** aead.nonce_length; - var ciphertext: [input.len]u8 = undefined; - var tag: [aead.tag_length]u8 = undefined; - var out: [input.len]u8 = undefined; - - aead.encrypt(ciphertext[0..], tag[0..], input, data, nonce, key); - try aead.decrypt(out[0..], ciphertext[0..], tag, data[0..], nonce, key); - testing.expectEqualSlices(u8, out[0..], input); - ciphertext[0] += 1; - testing.expectError(error.AuthenticationFailed, aead.decrypt(out[0..], ciphertext[0..], tag, data[0..], nonce, key)); + const ad = "Additional data"; + var c: [m.len + XChaCha20Poly1305.tag_length]u8 = undefined; + XChaCha20Poly1305.encrypt(c[0..m.len], c[m.len..], m, ad, nonce, key); + var out: [m.len]u8 = undefined; + try XChaCha20Poly1305.decrypt(out[0..], c[0..m.len], c[m.len..].*, ad, nonce, key); + var buf: [2 * c.len]u8 = undefined; + testing.expectEqualStrings(try std.fmt.bufPrint(&buf, "{s}", .{std.fmt.fmtSliceHexUpper(&c)}), "994D2DD32333F48E53650C02C7A2ABB8E018B0836D7175AEC779F52E961780768F815C58F1AA52D211498DB89B9216763F569C9433A6BBFCEFB4D4A49387A4C5207FBB3B5A92B5941294DF30588C6740D39DC16FA1F0E634F7246CF7CDCB978E44347D89381B7A74EB7084F754B90BDE9AAF5A94B8F2A85EFD0B50692AE2D425E234"); + testing.expectEqualSlices(u8, out[0..], m); + c[0] += 1; + testing.expectError(error.AuthenticationFailed, XChaCha20Poly1305.decrypt(out[0..], c[0..m.len], c[m.len..].*, ad, nonce, key)); } } diff --git a/lib/std/crypto/error.zig b/lib/std/crypto/error.zig new file mode 100644 index 0000000000..4cb12bb8f7 --- /dev/null +++ b/lib/std/crypto/error.zig @@ -0,0 +1,34 @@ +pub const Error = error{ + /// MAC verification failed - The tag doesn't verify for the given ciphertext and secret key + AuthenticationFailed, + + /// The requested output length is too long for the chosen algorithm + OutputTooLong, + + /// Finite field operation returned the identity element + IdentityElement, + + /// Encoded input cannot be decoded + InvalidEncoding, + + /// The signature does't verify for the given message and public key + SignatureVerificationFailed, + + /// Both a public and secret key have been provided, but they are incompatible + KeyMismatch, + + /// Encoded input is not in canonical form + NonCanonical, + + /// Square root has no solutions + NotSquare, + + /// Verification string doesn't match the provided password and parameters + PasswordVerificationFailed, + + /// Parameters would be insecure to use + WeakParameters, + + /// Public key would be insecure to use + WeakPublicKey, +}; diff --git a/lib/std/crypto/gimli.zig b/lib/std/crypto/gimli.zig index 1c1d6c79db..111e0c5274 100644 --- a/lib/std/crypto/gimli.zig +++ b/lib/std/crypto/gimli.zig @@ -20,6 +20,7 @@ const assert = std.debug.assert; const testing = std.testing; const htest = @import("test.zig"); const Vector = std.meta.Vector; +const Error = std.crypto.Error; pub const State = struct { pub const BLOCKBYTES = 48; @@ -392,7 +393,7 @@ pub const Aead = struct { /// npub: public nonce /// k: private key /// NOTE: the check of the authentication tag is currently not done in constant time - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) Error!void { assert(c.len == m.len); var state = Aead.init(ad, npub, k); @@ -429,7 +430,7 @@ pub const Aead = struct { // TODO: use a constant-time equality check here, see https://github.com/ziglang/zig/issues/1776 if (!mem.eql(u8, buf[0..State.RATE], &tag)) { @memset(m.ptr, undefined, m.len); - return error.InvalidMessage; + return error.AuthenticationFailed; } } }; diff --git a/lib/std/crypto/isap.zig b/lib/std/crypto/isap.zig index 990a0e7450..5219742d85 100644 --- a/lib/std/crypto/isap.zig +++ b/lib/std/crypto/isap.zig @@ -3,6 +3,7 @@ const debug = std.debug; const mem = std.mem; const math = std.math; const testing = std.testing; +const Error = std.crypto.Error; /// ISAPv2 is an authenticated encryption system hardened against side channels and fault attacks. /// https://csrc.nist.gov/CSRC/media/Projects/lightweight-cryptography/documents/round-2/spec-doc-rnd2/isap-spec-round2.pdf @@ -217,7 +218,7 @@ pub const IsapA128A = struct { tag.* = mac(c, ad, npub, key); } - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, key: [key_length]u8) Error!void { var computed_tag = mac(c, ad, npub, key); var acc: u8 = 0; for (computed_tag) |_, j| { diff --git a/lib/std/crypto/pbkdf2.zig b/lib/std/crypto/pbkdf2.zig index 25df1ba440..575fb83006 100644 --- a/lib/std/crypto/pbkdf2.zig +++ b/lib/std/crypto/pbkdf2.zig @@ -7,6 +7,7 @@ const std = @import("std"); const mem = std.mem; const maxInt = std.math.maxInt; +const Error = std.crypto.Error; // RFC 2898 Section 5.2 // @@ -19,36 +20,28 @@ const maxInt = std.math.maxInt; // pseudorandom function. See Appendix B.1 for further discussion.) // PBKDF2 is recommended for new applications. // -// PBKDF2 (P, S, c, dkLen) +// PBKDF2 (P, S, c, dk_len) // -// Options: PRF underlying pseudorandom function (hLen +// Options: PRF underlying pseudorandom function (h_len // denotes the length in octets of the // pseudorandom function output) // // Input: P password, an octet string // S salt, an octet string // c iteration count, a positive integer -// dkLen intended length in octets of the derived +// dk_len intended length in octets of the derived // key, a positive integer, at most -// (2^32 - 1) * hLen +// (2^32 - 1) * h_len // -// Output: DK derived key, a dkLen-octet string +// Output: DK derived key, a dk_len-octet string // Based on Apple's CommonKeyDerivation, based originally on code by Damien Bergamini. -pub const Pbkdf2Error = error{ - /// At least one round is required - TooFewRounds, - - /// Maximum length of the derived key is `maxInt(u32) * Prf.mac_length` - DerivedKeyTooLong, -}; - /// Apply PBKDF2 to generate a key from a password. /// /// PBKDF2 is defined in RFC 2898, and is a recommendation of NIST SP 800-132. /// -/// derivedKey: Slice of appropriate size for generated key. Generally 16 or 32 bytes in length. +/// dk: Slice of appropriate size for generated key. Generally 16 or 32 bytes in length. /// May be uninitialized. All bytes will be overwritten. /// Maximum size is `maxInt(u32) * Hash.digest_length` /// It is a programming error to pass buffer longer than the maximum size. @@ -59,43 +52,38 @@ pub const Pbkdf2Error = error{ /// /// rounds: Iteration count. Must be greater than 0. Common values range from 1,000 to 100,000. /// Larger iteration counts improve security by increasing the time required to compute -/// the derivedKey. It is common to tune this parameter to achieve approximately 100ms. +/// the dk. It is common to tune this parameter to achieve approximately 100ms. /// /// Prf: Pseudo-random function to use. A common choice is `std.crypto.auth.hmac.HmacSha256`. -pub fn pbkdf2(derivedKey: []u8, password: []const u8, salt: []const u8, rounds: u32, comptime Prf: type) Pbkdf2Error!void { - if (rounds < 1) return error.TooFewRounds; +pub fn pbkdf2(dk: []u8, password: []const u8, salt: []const u8, rounds: u32, comptime Prf: type) Error!void { + if (rounds < 1) return error.WeakParameters; - const dkLen = derivedKey.len; - const hLen = Prf.mac_length; - comptime std.debug.assert(hLen >= 1); + const dk_len = dk.len; + const h_len = Prf.mac_length; + comptime std.debug.assert(h_len >= 1); // FromSpec: // - // 1. If dkLen > maxInt(u32) * hLen, output "derived key too long" and + // 1. If dk_len > maxInt(u32) * h_len, output "derived key too long" and // stop. // - if (comptime (maxInt(usize) > maxInt(u32) * hLen) and (dkLen > @as(usize, maxInt(u32) * hLen))) { - // If maxInt(usize) is less than `maxInt(u32) * hLen` then dkLen is always inbounds - return error.DerivedKeyTooLong; + if (dk_len / h_len >= maxInt(u32)) { + // Counter starts at 1 and is 32 bit, so if we have to return more blocks, we would overflow + return error.OutputTooLong; } // FromSpec: // - // 2. Let l be the number of hLen-long blocks of bytes in the derived key, + // 2. Let l be the number of h_len-long blocks of bytes in the derived key, // rounding up, and let r be the number of bytes in the last // block // - // l will not overflow, proof: - // let `L(dkLen, hLen) = (dkLen + hLen - 1) / hLen` - // then `L^-1(l, hLen) = l*hLen - hLen + 1` - // 1) L^-1(maxInt(u32), hLen) <= maxInt(u32)*hLen - // 2) maxInt(u32)*hLen - hLen + 1 <= maxInt(u32)*hLen // subtract maxInt(u32)*hLen + 1 - // 3) -hLen <= -1 // multiply by -1 - // 4) hLen >= 1 - const r_ = dkLen % hLen; - const l = @intCast(u32, (dkLen / hLen) + @as(u1, if (r_ == 0) 0 else 1)); // original: (dkLen + hLen - 1) / hLen - const r = if (r_ == 0) hLen else r_; + const blocks_count = @intCast(u32, std.math.divCeil(usize, dk_len, h_len) catch unreachable); + var r = dk_len % h_len; + if (r == 0) { + r = h_len; + } // FromSpec: // @@ -125,37 +113,38 @@ pub fn pbkdf2(derivedKey: []u8, password: []const u8, salt: []const u8, rounds: // Here, INT (i) is a four-octet encoding of the integer i, most // significant octet first. // - // 4. Concatenate the blocks and extract the first dkLen octets to + // 4. Concatenate the blocks and extract the first dk_len octets to // produce a derived key DK: // // DK = T_1 || T_2 || ... || T_l<0..r-1> - var block: u32 = 0; // Spec limits to u32 - while (block < l) : (block += 1) { - var prevBlock: [hLen]u8 = undefined; - var newBlock: [hLen]u8 = undefined; + + var block: u32 = 0; + while (block < blocks_count) : (block += 1) { + var prev_block: [h_len]u8 = undefined; + var new_block: [h_len]u8 = undefined; // U_1 = PRF (P, S || INT (i)) - const blockIndex = mem.toBytes(mem.nativeToBig(u32, block + 1)); // Block index starts at 0001 + const block_index = mem.toBytes(mem.nativeToBig(u32, block + 1)); // Block index starts at 0001 var ctx = Prf.init(password); ctx.update(salt); - ctx.update(blockIndex[0..]); - ctx.final(prevBlock[0..]); + ctx.update(block_index[0..]); + ctx.final(prev_block[0..]); // Choose portion of DK to write into (T_n) and initialize - const offset = block * hLen; - const blockLen = if (block != l - 1) hLen else r; - const dkBlock: []u8 = derivedKey[offset..][0..blockLen]; - mem.copy(u8, dkBlock, prevBlock[0..dkBlock.len]); + const offset = block * h_len; + const block_len = if (block != blocks_count - 1) h_len else r; + const dk_block: []u8 = dk[offset..][0..block_len]; + mem.copy(u8, dk_block, prev_block[0..dk_block.len]); var i: u32 = 1; while (i < rounds) : (i += 1) { // U_c = PRF (P, U_{c-1}) - Prf.create(&newBlock, prevBlock[0..], password); - mem.copy(u8, prevBlock[0..], newBlock[0..]); + Prf.create(&new_block, prev_block[0..], password); + mem.copy(u8, prev_block[0..], new_block[0..]); // F (P, S, c, i) = U_1 \xor U_2 \xor ... \xor U_c - for (dkBlock) |_, j| { - dkBlock[j] ^= newBlock[j]; + for (dk_block) |_, j| { + dk_block[j] ^= new_block[j]; } } } @@ -165,49 +154,50 @@ const htest = @import("test.zig"); const HmacSha1 = std.crypto.auth.hmac.HmacSha1; // RFC 6070 PBKDF2 HMAC-SHA1 Test Vectors + test "RFC 6070 one iteration" { const p = "password"; const s = "salt"; const c = 1; - const dkLen = 20; + const dk_len = 20; - var derivedKey: [dkLen]u8 = undefined; + var dk: [dk_len]u8 = undefined; - try pbkdf2(&derivedKey, p, s, c, HmacSha1); + try pbkdf2(&dk, p, s, c, HmacSha1); const expected = "0c60c80f961f0e71f3a9b524af6012062fe037a6"; - htest.assertEqual(expected, derivedKey[0..]); + htest.assertEqual(expected, dk[0..]); } test "RFC 6070 two iterations" { const p = "password"; const s = "salt"; const c = 2; - const dkLen = 20; + const dk_len = 20; - var derivedKey: [dkLen]u8 = undefined; + var dk: [dk_len]u8 = undefined; - try pbkdf2(&derivedKey, p, s, c, HmacSha1); + try pbkdf2(&dk, p, s, c, HmacSha1); const expected = "ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957"; - htest.assertEqual(expected, derivedKey[0..]); + htest.assertEqual(expected, dk[0..]); } test "RFC 6070 4096 iterations" { const p = "password"; const s = "salt"; const c = 4096; - const dkLen = 20; + const dk_len = 20; - var derivedKey: [dkLen]u8 = undefined; + var dk: [dk_len]u8 = undefined; - try pbkdf2(&derivedKey, p, s, c, HmacSha1); + try pbkdf2(&dk, p, s, c, HmacSha1); const expected = "4b007901b765489abead49d926f721d065a429c1"; - htest.assertEqual(expected, derivedKey[0..]); + htest.assertEqual(expected, dk[0..]); } test "RFC 6070 16,777,216 iterations" { @@ -219,48 +209,48 @@ test "RFC 6070 16,777,216 iterations" { const p = "password"; const s = "salt"; const c = 16777216; - const dkLen = 20; + const dk_len = 20; - var derivedKey = [_]u8{0} ** dkLen; + var dk = [_]u8{0} ** dk_len; - try pbkdf2(&derivedKey, p, s, c, HmacSha1); + try pbkdf2(&dk, p, s, c, HmacSha1); const expected = "eefe3d61cd4da4e4e9945b3d6ba2158c2634e984"; - htest.assertEqual(expected, derivedKey[0..]); + htest.assertEqual(expected, dk[0..]); } test "RFC 6070 multi-block salt and password" { const p = "passwordPASSWORDpassword"; const s = "saltSALTsaltSALTsaltSALTsaltSALTsalt"; const c = 4096; - const dkLen = 25; + const dk_len = 25; - var derivedKey: [dkLen]u8 = undefined; + var dk: [dk_len]u8 = undefined; - try pbkdf2(&derivedKey, p, s, c, HmacSha1); + try pbkdf2(&dk, p, s, c, HmacSha1); const expected = "3d2eec4fe41c849b80c8d83662c0e44a8b291a964cf2f07038"; - htest.assertEqual(expected, derivedKey[0..]); + htest.assertEqual(expected, dk[0..]); } test "RFC 6070 embedded NUL" { const p = "pass\x00word"; const s = "sa\x00lt"; const c = 4096; - const dkLen = 16; + const dk_len = 16; - var derivedKey: [dkLen]u8 = undefined; + var dk: [dk_len]u8 = undefined; - try pbkdf2(&derivedKey, p, s, c, HmacSha1); + try pbkdf2(&dk, p, s, c, HmacSha1); const expected = "56fa6aa75548099dcc37d7f03425e0c3"; - htest.assertEqual(expected, derivedKey[0..]); + htest.assertEqual(expected, dk[0..]); } -test "Very large dkLen" { +test "Very large dk_len" { // This test allocates 8GB of memory and is expected to take several hours to run. if (true) { return error.SkipZigTest; @@ -268,13 +258,13 @@ test "Very large dkLen" { const p = "password"; const s = "salt"; const c = 1; - const dkLen = 1 << 33; + const dk_len = 1 << 33; - var derivedKey = try std.testing.allocator.alloc(u8, dkLen); + var dk = try std.testing.allocator.alloc(u8, dk_len); defer { - std.testing.allocator.free(derivedKey); + std.testing.allocator.free(dk); } - try pbkdf2(derivedKey, p, s, c, HmacSha1); // Just verify this doesn't crash with an overflow + try pbkdf2(dk, p, s, c, HmacSha1); } diff --git a/lib/std/crypto/salsa20.zig b/lib/std/crypto/salsa20.zig index e22668f998..006767c93f 100644 --- a/lib/std/crypto/salsa20.zig +++ b/lib/std/crypto/salsa20.zig @@ -15,6 +15,7 @@ const Vector = std.meta.Vector; const Poly1305 = crypto.onetimeauth.Poly1305; const Blake2b = crypto.hash.blake2.Blake2b; const X25519 = crypto.dh.X25519; +const Error = crypto.Error; const Salsa20VecImpl = struct { const Lane = Vector(4, u32); @@ -398,7 +399,7 @@ pub const XSalsa20Poly1305 = struct { /// ad: Associated Data /// npub: public nonce /// k: private key - pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) !void { + pub fn decrypt(m: []u8, c: []const u8, tag: [tag_length]u8, ad: []const u8, npub: [nonce_length]u8, k: [key_length]u8) Error!void { debug.assert(c.len == m.len); const extended = extend(k, npub); var block0 = [_]u8{0} ** 64; @@ -446,7 +447,7 @@ pub const SecretBox = struct { /// Verify and decrypt `c` using a nonce `npub` and a key `k`. /// `m` must be exactly `tag_length` smaller than `c`, as `c` includes an authentication tag in addition to the encrypted message. - pub fn open(m: []u8, c: []const u8, npub: [nonce_length]u8, k: [key_length]u8) !void { + pub fn open(m: []u8, c: []const u8, npub: [nonce_length]u8, k: [key_length]u8) Error!void { if (c.len < tag_length) { return error.AuthenticationFailed; } @@ -481,20 +482,20 @@ pub const Box = struct { pub const KeyPair = X25519.KeyPair; /// Compute a secret suitable for `secretbox` given a recipent's public key and a sender's secret key. - pub fn createSharedSecret(public_key: [public_length]u8, secret_key: [secret_length]u8) ![shared_length]u8 { + pub fn createSharedSecret(public_key: [public_length]u8, secret_key: [secret_length]u8) Error![shared_length]u8 { const p = try X25519.scalarmult(secret_key, public_key); const zero = [_]u8{0} ** 16; return Salsa20Impl.hsalsa20(zero, p); } /// Encrypt and authenticate a message using a recipient's public key `public_key` and a sender's `secret_key`. - pub fn seal(c: []u8, m: []const u8, npub: [nonce_length]u8, public_key: [public_length]u8, secret_key: [secret_length]u8) !void { + pub fn seal(c: []u8, m: []const u8, npub: [nonce_length]u8, public_key: [public_length]u8, secret_key: [secret_length]u8) Error!void { const shared_key = try createSharedSecret(public_key, secret_key); return SecretBox.seal(c, m, npub, shared_key); } /// Verify and decrypt a message using a recipient's secret key `public_key` and a sender's `public_key`. - pub fn open(m: []u8, c: []const u8, npub: [nonce_length]u8, public_key: [public_length]u8, secret_key: [secret_length]u8) !void { + pub fn open(m: []u8, c: []const u8, npub: [nonce_length]u8, public_key: [public_length]u8, secret_key: [secret_length]u8) Error!void { const shared_key = try createSharedSecret(public_key, secret_key); return SecretBox.open(m, c, npub, shared_key); } @@ -527,7 +528,7 @@ pub const SealedBox = struct { /// Encrypt a message `m` for a recipient whose public key is `public_key`. /// `c` must be `seal_length` bytes larger than `m`, so that the required metadata can be added. - pub fn seal(c: []u8, m: []const u8, public_key: [public_length]u8) !void { + pub fn seal(c: []u8, m: []const u8, public_key: [public_length]u8) Error!void { debug.assert(c.len == m.len + seal_length); var ekp = try KeyPair.create(null); const nonce = createNonce(ekp.public_key, public_key); @@ -538,7 +539,7 @@ pub const SealedBox = struct { /// Decrypt a message using a key pair. /// `m` must be exactly `seal_length` bytes smaller than `c`, as `c` also includes metadata. - pub fn open(m: []u8, c: []const u8, keypair: KeyPair) !void { + pub fn open(m: []u8, c: []const u8, keypair: KeyPair) Error!void { if (c.len < seal_length) { return error.AuthenticationFailed; } diff --git a/lib/std/debug.zig b/lib/std/debug.zig index 74fb95ffa8..a7badf7ed1 100644 --- a/lib/std/debug.zig +++ b/lib/std/debug.zig @@ -250,24 +250,6 @@ pub fn panicExtra(trace: ?*const builtin.StackTrace, first_trace_addr: ?usize, c resetSegfaultHandler(); } - if (comptime std.Target.current.isDarwin() and std.Target.current.cpu.arch == .aarch64) - nosuspend { - // As a workaround for not having threadlocal variable support in LLD for this target, - // we have a simpler panic implementation that does not use threadlocal variables. - // TODO https://github.com/ziglang/zig/issues/7527 - const stderr = io.getStdErr().writer(); - if (@atomicRmw(u8, &panicking, .Add, 1, .SeqCst) == 0) { - stderr.print("panic: " ++ format ++ "\n", args) catch os.abort(); - if (trace) |t| { - dumpStackTrace(t.*); - } - dumpCurrentStackTrace(first_trace_addr); - } else { - stderr.print("Panicked during a panic. Aborting.\n", .{}) catch os.abort(); - } - os.abort(); - }; - nosuspend switch (panic_stage) { 0 => { panic_stage = 1; diff --git a/lib/std/enums.zig b/lib/std/enums.zig new file mode 100644 index 0000000000..bddda38c9f --- /dev/null +++ b/lib/std/enums.zig @@ -0,0 +1,1281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2015-2021 Zig Contributors +// This file is part of [zig](https://ziglang.org/), which is MIT licensed. +// The MIT license requires this copyright notice to be included in all copies +// and substantial portions of the software. + +//! This module contains utilities and data structures for working with enums. + +const std = @import("std.zig"); +const assert = std.debug.assert; +const testing = std.testing; +const EnumField = std.builtin.TypeInfo.EnumField; + +/// Returns a struct with a field matching each unique named enum element. +/// If the enum is extern and has multiple names for the same value, only +/// the first name is used. Each field is of type Data and has the provided +/// default, which may be undefined. +pub fn EnumFieldStruct(comptime E: type, comptime Data: type, comptime field_default: ?Data) type { + const StructField = std.builtin.TypeInfo.StructField; + var fields: []const StructField = &[_]StructField{}; + for (uniqueFields(E)) |field, i| { + fields = fields ++ &[_]StructField{.{ + .name = field.name, + .field_type = Data, + .default_value = field_default, + .is_comptime = false, + .alignment = if (@sizeOf(Data) > 0) @alignOf(Data) else 0, + }}; + } + return @Type(.{ .Struct = .{ + .layout = .Auto, + .fields = fields, + .decls = &[_]std.builtin.TypeInfo.Declaration{}, + .is_tuple = false, + }}); +} + +/// Looks up the supplied fields in the given enum type. +/// Uses only the field names, field values are ignored. +/// The result array is in the same order as the input. +pub fn valuesFromFields(comptime E: type, comptime fields: []const EnumField) []const E { + comptime { + var result: [fields.len]E = undefined; + for (fields) |f, i| { + result[i] = @field(E, f.name); + } + return &result; + } +} + +test "std.enums.valuesFromFields" { + const E = extern enum { a, b, c, d = 0 }; + const fields = valuesFromFields(E, &[_]EnumField{ + .{ .name = "b", .value = undefined }, + .{ .name = "a", .value = undefined }, + .{ .name = "a", .value = undefined }, + .{ .name = "d", .value = undefined }, + }); + testing.expectEqual(E.b, fields[0]); + testing.expectEqual(E.a, fields[1]); + testing.expectEqual(E.d, fields[2]); // a == d + testing.expectEqual(E.d, fields[3]); +} + +/// Returns the set of all named values in the given enum, in +/// declaration order. +pub fn values(comptime E: type) []const E { + return comptime valuesFromFields(E, @typeInfo(E).Enum.fields); +} + +test "std.enum.values" { + const E = extern enum { a, b, c, d = 0 }; + testing.expectEqualSlices(E, &.{.a, .b, .c, .d}, values(E)); +} + +/// Returns the set of all unique named values in the given enum, in +/// declaration order. For repeated values in extern enums, only the +/// first name for each value is included. +pub fn uniqueValues(comptime E: type) []const E { + return comptime valuesFromFields(E, uniqueFields(E)); +} + +test "std.enum.uniqueValues" { + const E = extern enum { a, b, c, d = 0, e, f = 3 }; + testing.expectEqualSlices(E, &.{.a, .b, .c, .f}, uniqueValues(E)); + + const F = enum { a, b, c }; + testing.expectEqualSlices(F, &.{.a, .b, .c}, uniqueValues(F)); +} + +/// Returns the set of all unique field values in the given enum, in +/// declaration order. For repeated values in extern enums, only the +/// first name for each value is included. +pub fn uniqueFields(comptime E: type) []const EnumField { + comptime { + const info = @typeInfo(E).Enum; + const raw_fields = info.fields; + // Only extern enums can contain duplicates, + // so fast path other types. + if (info.layout != .Extern) { + return raw_fields; + } + + var unique_fields: []const EnumField = &[_]EnumField{}; + outer: + for (raw_fields) |candidate| { + for (unique_fields) |u| { + if (u.value == candidate.value) + continue :outer; + } + unique_fields = unique_fields ++ &[_]EnumField{candidate}; + } + + return unique_fields; + } +} + +/// Determines the length of a direct-mapped enum array, indexed by +/// @intCast(usize, @enumToInt(enum_value)). The enum must be exhaustive. +/// If the enum contains any fields with values that cannot be represented +/// by usize, a compile error is issued. The max_unused_slots parameter limits +/// the total number of items which have no matching enum key (holes in the enum +/// numbering). So for example, if an enum has values 1, 2, 5, and 6, max_unused_slots +/// must be at least 3, to allow unused slots 0, 3, and 4. +fn directEnumArrayLen(comptime E: type, comptime max_unused_slots: comptime_int) comptime_int { + const info = @typeInfo(E).Enum; + if (!info.is_exhaustive) { + @compileError("Cannot create direct array of non-exhaustive enum "++@typeName(E)); + } + + var max_value: comptime_int = -1; + const max_usize: comptime_int = ~@as(usize, 0); + const fields = uniqueFields(E); + for (fields) |f| { + if (f.value < 0) { + @compileError("Cannot create a direct enum array for "++@typeName(E)++", field ."++f.name++" has a negative value."); + } + if (f.value > max_value) { + if (f.value > max_usize) { + @compileError("Cannot create a direct enum array for "++@typeName(E)++", field ."++f.name++" is larger than the max value of usize."); + } + max_value = f.value; + } + } + + const unused_slots = max_value + 1 - fields.len; + if (unused_slots > max_unused_slots) { + const unused_str = std.fmt.comptimePrint("{d}", .{unused_slots}); + const allowed_str = std.fmt.comptimePrint("{d}", .{max_unused_slots}); + @compileError("Cannot create a direct enum array for "++@typeName(E)++". It would have "++unused_str++" unused slots, but only "++allowed_str++" are allowed."); + } + + return max_value + 1; +} + +/// Initializes an array of Data which can be indexed by +/// @intCast(usize, @enumToInt(enum_value)). The enum must be exhaustive. +/// If the enum contains any fields with values that cannot be represented +/// by usize, a compile error is issued. The max_unused_slots parameter limits +/// the total number of items which have no matching enum key (holes in the enum +/// numbering). So for example, if an enum has values 1, 2, 5, and 6, max_unused_slots +/// must be at least 3, to allow unused slots 0, 3, and 4. +/// The init_values parameter must be a struct with field names that match the enum values. +/// If the enum has multiple fields with the same value, the name of the first one must +/// be used. +pub fn directEnumArray( + comptime E: type, + comptime Data: type, + comptime max_unused_slots: comptime_int, + init_values: EnumFieldStruct(E, Data, null), +) [directEnumArrayLen(E, max_unused_slots)]Data { + return directEnumArrayDefault(E, Data, null, max_unused_slots, init_values); +} + +test "std.enums.directEnumArray" { + const E = enum(i4) { a = 4, b = 6, c = 2 }; + var runtime_false: bool = false; + const array = directEnumArray(E, bool, 4, .{ + .a = true, + .b = runtime_false, + .c = true, + }); + + testing.expectEqual([7]bool, @TypeOf(array)); + testing.expectEqual(true, array[4]); + testing.expectEqual(false, array[6]); + testing.expectEqual(true, array[2]); +} + +/// Initializes an array of Data which can be indexed by +/// @intCast(usize, @enumToInt(enum_value)). The enum must be exhaustive. +/// If the enum contains any fields with values that cannot be represented +/// by usize, a compile error is issued. The max_unused_slots parameter limits +/// the total number of items which have no matching enum key (holes in the enum +/// numbering). So for example, if an enum has values 1, 2, 5, and 6, max_unused_slots +/// must be at least 3, to allow unused slots 0, 3, and 4. +/// The init_values parameter must be a struct with field names that match the enum values. +/// If the enum has multiple fields with the same value, the name of the first one must +/// be used. +pub fn directEnumArrayDefault( + comptime E: type, + comptime Data: type, + comptime default: ?Data, + comptime max_unused_slots: comptime_int, + init_values: EnumFieldStruct(E, Data, default), +) [directEnumArrayLen(E, max_unused_slots)]Data { + const len = comptime directEnumArrayLen(E, max_unused_slots); + var result: [len]Data = if (default) |d| [_]Data{d} ** len else undefined; + inline for (@typeInfo(@TypeOf(init_values)).Struct.fields) |f, i| { + const enum_value = @field(E, f.name); + const index = @intCast(usize, @enumToInt(enum_value)); + result[index] = @field(init_values, f.name); + } + return result; +} + +test "std.enums.directEnumArrayDefault" { + const E = enum(i4) { a = 4, b = 6, c = 2 }; + var runtime_false: bool = false; + const array = directEnumArrayDefault(E, bool, false, 4, .{ + .a = true, + .b = runtime_false, + }); + + testing.expectEqual([7]bool, @TypeOf(array)); + testing.expectEqual(true, array[4]); + testing.expectEqual(false, array[6]); + testing.expectEqual(false, array[2]); +} + +/// Cast an enum literal, value, or string to the enum value of type E +/// with the same name. +pub fn nameCast(comptime E: type, comptime value: anytype) E { + comptime { + const V = @TypeOf(value); + if (V == E) return value; + var name: ?[]const u8 = switch (@typeInfo(V)) { + .EnumLiteral, .Enum => @tagName(value), + .Pointer => if (std.meta.trait.isZigString(V)) value else null, + else => null, + }; + if (name) |n| { + if (@hasField(E, n)) { + return @field(E, n); + } + @compileError("Enum "++@typeName(E)++" has no field named "++n); + } + @compileError("Cannot cast from "++@typeName(@TypeOf(value))++" to "++@typeName(E)); + } +} + +test "std.enums.nameCast" { + const A = enum { a = 0, b = 1 }; + const B = enum { a = 1, b = 0 }; + testing.expectEqual(A.a, nameCast(A, .a)); + testing.expectEqual(A.a, nameCast(A, A.a)); + testing.expectEqual(A.a, nameCast(A, B.a)); + testing.expectEqual(A.a, nameCast(A, "a")); + testing.expectEqual(A.a, nameCast(A, @as(*const[1]u8, "a"))); + testing.expectEqual(A.a, nameCast(A, @as([:0]const u8, "a"))); + testing.expectEqual(A.a, nameCast(A, @as([]const u8, "a"))); + + testing.expectEqual(B.a, nameCast(B, .a)); + testing.expectEqual(B.a, nameCast(B, A.a)); + testing.expectEqual(B.a, nameCast(B, B.a)); + testing.expectEqual(B.a, nameCast(B, "a")); + + testing.expectEqual(B.b, nameCast(B, .b)); + testing.expectEqual(B.b, nameCast(B, A.b)); + testing.expectEqual(B.b, nameCast(B, B.b)); + testing.expectEqual(B.b, nameCast(B, "b")); +} + +/// A set of enum elements, backed by a bitfield. If the enum +/// is not dense, a mapping will be constructed from enum values +/// to dense indices. This type does no dynamic allocation and +/// can be copied by value. +pub fn EnumSet(comptime E: type) type { + const mixin = struct { + fn EnumSetExt(comptime Self: type) type { + const Indexer = Self.Indexer; + return struct { + /// Initializes the set using a struct of bools + pub fn init(init_values: EnumFieldStruct(E, bool, false)) Self { + var result = Self{}; + comptime var i: usize = 0; + inline while (i < Self.len) : (i += 1) { + comptime const key = Indexer.keyForIndex(i); + comptime const tag = @tagName(key); + if (@field(init_values, tag)) { + result.bits.set(i); + } + } + return result; + } + }; + } + }; + return IndexedSet(EnumIndexer(E), mixin.EnumSetExt); +} + +/// A map keyed by an enum, backed by a bitfield and a dense array. +/// If the enum is not dense, a mapping will be constructed from +/// enum values to dense indices. This type does no dynamic +/// allocation and can be copied by value. +pub fn EnumMap(comptime E: type, comptime V: type) type { + const mixin = struct { + fn EnumMapExt(comptime Self: type) type { + const Indexer = Self.Indexer; + return struct { + /// Initializes the map using a sparse struct of optionals + pub fn init(init_values: EnumFieldStruct(E, ?V, @as(?V, null))) Self { + var result = Self{}; + comptime var i: usize = 0; + inline while (i < Self.len) : (i += 1) { + comptime const key = Indexer.keyForIndex(i); + comptime const tag = @tagName(key); + if (@field(init_values, tag)) |*v| { + result.bits.set(i); + result.values[i] = v.*; + } + } + return result; + } + /// Initializes a full mapping with all keys set to value. + /// Consider using EnumArray instead if the map will remain full. + pub fn initFull(value: V) Self { + var result = Self{ + .bits = Self.BitSet.initFull(), + .values = undefined, + }; + std.mem.set(V, &result.values, value); + return result; + } + /// Initializes a full mapping with supplied values. + /// Consider using EnumArray instead if the map will remain full. + pub fn initFullWith(init_values: EnumFieldStruct(E, V, @as(?V, null))) Self { + return initFullWithDefault(@as(?V, null), init_values); + } + /// Initializes a full mapping with a provided default. + /// Consider using EnumArray instead if the map will remain full. + pub fn initFullWithDefault(comptime default: ?V, init_values: EnumFieldStruct(E, V, default)) Self { + var result = Self{ + .bits = Self.BitSet.initFull(), + .values = undefined, + }; + comptime var i: usize = 0; + inline while (i < Self.len) : (i += 1) { + comptime const key = Indexer.keyForIndex(i); + comptime const tag = @tagName(key); + result.values[i] = @field(init_values, tag); + } + return result; + } + }; + } + }; + return IndexedMap(EnumIndexer(E), V, mixin.EnumMapExt); +} + +/// An array keyed by an enum, backed by a dense array. +/// If the enum is not dense, a mapping will be constructed from +/// enum values to dense indices. This type does no dynamic +/// allocation and can be copied by value. +pub fn EnumArray(comptime E: type, comptime V: type) type { + const mixin = struct { + fn EnumArrayExt(comptime Self: type) type { + const Indexer = Self.Indexer; + return struct { + /// Initializes all values in the enum array + pub fn init(init_values: EnumFieldStruct(E, V, @as(?V, null))) Self { + return initDefault(@as(?V, null), init_values); + } + + /// Initializes values in the enum array, with the specified default. + pub fn initDefault(comptime default: ?V, init_values: EnumFieldStruct(E, V, default)) Self { + var result = Self{ .values = undefined }; + comptime var i: usize = 0; + inline while (i < Self.len) : (i += 1) { + const key = comptime Indexer.keyForIndex(i); + const tag = @tagName(key); + result.values[i] = @field(init_values, tag); + } + return result; + } + }; + } + }; + return IndexedArray(EnumIndexer(E), V, mixin.EnumArrayExt); +} + +/// Pass this function as the Ext parameter to Indexed* if you +/// do not want to attach any extensions. This parameter was +/// originally an optional, but optional generic functions +/// seem to be broken at the moment. +/// TODO: Once #8169 is fixed, consider switching this param +/// back to an optional. +pub fn NoExtension(comptime Self: type) type { + return NoExt; +} +const NoExt = struct{}; + +/// A set type with an Indexer mapping from keys to indices. +/// Presence or absence is stored as a dense bitfield. This +/// type does no allocation and can be copied by value. +pub fn IndexedSet(comptime I: type, comptime Ext: fn(type)type) type { + comptime ensureIndexer(I); + return struct { + const Self = @This(); + + pub usingnamespace Ext(Self); + + /// The indexing rules for converting between keys and indices. + pub const Indexer = I; + /// The element type for this set. + pub const Key = Indexer.Key; + + const BitSet = std.StaticBitSet(Indexer.count); + + /// The maximum number of items in this set. + pub const len = Indexer.count; + + bits: BitSet = BitSet.initEmpty(), + + /// Returns a set containing all possible keys. + pub fn initFull() Self { + return .{ .bits = BitSet.initFull() }; + } + + /// Returns the number of keys in the set. + pub fn count(self: Self) usize { + return self.bits.count(); + } + + /// Checks if a key is in the set. + pub fn contains(self: Self, key: Key) bool { + return self.bits.isSet(Indexer.indexOf(key)); + } + + /// Puts a key in the set. + pub fn insert(self: *Self, key: Key) void { + self.bits.set(Indexer.indexOf(key)); + } + + /// Removes a key from the set. + pub fn remove(self: *Self, key: Key) void { + self.bits.unset(Indexer.indexOf(key)); + } + + /// Changes the presence of a key in the set to match the passed bool. + pub fn setPresent(self: *Self, key: Key, present: bool) void { + self.bits.setValue(Indexer.indexOf(key), present); + } + + /// Toggles the presence of a key in the set. If the key is in + /// the set, removes it. Otherwise adds it. + pub fn toggle(self: *Self, key: Key) void { + self.bits.toggle(Indexer.indexOf(key)); + } + + /// Toggles the presence of all keys in the passed set. + pub fn toggleSet(self: *Self, other: Self) void { + self.bits.toggleSet(other.bits); + } + + /// Toggles all possible keys in the set. + pub fn toggleAll(self: *Self) void { + self.bits.toggleAll(); + } + + /// Adds all keys in the passed set to this set. + pub fn setUnion(self: *Self, other: Self) void { + self.bits.setUnion(other.bits); + } + + /// Removes all keys which are not in the passed set. + pub fn setIntersection(self: *Self, other: Self) void { + self.bits.setIntersection(other.bits); + } + + /// Returns an iterator over this set, which iterates in + /// index order. Modifications to the set during iteration + /// may or may not be observed by the iterator, but will + /// not invalidate it. + pub fn iterator(self: *Self) Iterator { + return .{ .inner = self.bits.iterator(.{}) }; + } + + pub const Iterator = struct { + inner: BitSet.Iterator(.{}), + + pub fn next(self: *Iterator) ?Key { + return if (self.inner.next()) |index| + Indexer.keyForIndex(index) + else null; + } + }; + }; +} + +/// A map from keys to values, using an index lookup. Uses a +/// bitfield to track presence and a dense array of values. +/// This type does no allocation and can be copied by value. +pub fn IndexedMap(comptime I: type, comptime V: type, comptime Ext: fn(type)type) type { + comptime ensureIndexer(I); + return struct { + const Self = @This(); + + pub usingnamespace Ext(Self); + + /// The index mapping for this map + pub const Indexer = I; + /// The key type used to index this map + pub const Key = Indexer.Key; + /// The value type stored in this map + pub const Value = V; + /// The number of possible keys in the map + pub const len = Indexer.count; + + const BitSet = std.StaticBitSet(Indexer.count); + + /// Bits determining whether items are in the map + bits: BitSet = BitSet.initEmpty(), + /// Values of items in the map. If the associated + /// bit is zero, the value is undefined. + values: [Indexer.count]Value = undefined, + + /// The number of items in the map. + pub fn count(self: Self) usize { + return self.bits.count(); + } + + /// Checks if the map contains an item. + pub fn contains(self: Self, key: Key) bool { + return self.bits.isSet(Indexer.indexOf(key)); + } + + /// Gets the value associated with a key. + /// If the key is not in the map, returns null. + pub fn get(self: Self, key: Key) ?Value { + const index = Indexer.indexOf(key); + return if (self.bits.isSet(index)) self.values[index] else null; + } + + /// Gets the value associated with a key, which must + /// exist in the map. + pub fn getAssertContains(self: Self, key: Key) Value { + const index = Indexer.indexOf(key); + assert(self.bits.isSet(index)); + return self.values[index]; + } + + /// Gets the address of the value associated with a key. + /// If the key is not in the map, returns null. + pub fn getPtr(self: *Self, key: Key) ?*Value { + const index = Indexer.indexOf(key); + return if (self.bits.isSet(index)) &self.values[index] else null; + } + + /// Gets the address of the const value associated with a key. + /// If the key is not in the map, returns null. + pub fn getPtrConst(self: *const Self, key: Key) ?*const Value { + const index = Indexer.indexOf(key); + return if (self.bits.isSet(index)) &self.values[index] else null; + } + + /// Gets the address of the value associated with a key. + /// The key must be present in the map. + pub fn getPtrAssertContains(self: *Self, key: Key) *Value { + const index = Indexer.indexOf(key); + assert(self.bits.isSet(index)); + return &self.values[index]; + } + + /// Adds the key to the map with the supplied value. + /// If the key is already in the map, overwrites the value. + pub fn put(self: *Self, key: Key, value: Value) void { + const index = Indexer.indexOf(key); + self.bits.set(index); + self.values[index] = value; + } + + /// Adds the key to the map with an undefined value. + /// If the key is already in the map, the value becomes undefined. + /// A pointer to the value is returned, which should be + /// used to initialize the value. + pub fn putUninitialized(self: *Self, key: Key) *Value { + const index = Indexer.indexOf(key); + self.bits.set(index); + self.values[index] = undefined; + return &self.values[index]; + } + + /// Sets the value associated with the key in the map, + /// and returns the old value. If the key was not in + /// the map, returns null. + pub fn fetchPut(self: *Self, key: Key, value: Value) ?Value { + const index = Indexer.indexOf(key); + const result: ?Value = if (self.bits.isSet(index)) self.values[index] else null; + self.bits.set(index); + self.values[index] = value; + return result; + } + + /// Removes a key from the map. If the key was not in the map, + /// does nothing. + pub fn remove(self: *Self, key: Key) void { + const index = Indexer.indexOf(key); + self.bits.unset(index); + self.values[index] = undefined; + } + + /// Removes a key from the map, and returns the old value. + /// If the key was not in the map, returns null. + pub fn fetchRemove(self: *Self, key: Key) ?Value { + const index = Indexer.indexOf(key); + const result: ?Value = if (self.bits.isSet(index)) self.values[index] else null; + self.bits.unset(index); + self.values[index] = undefined; + return result; + } + + /// Returns an iterator over the map, which visits items in index order. + /// Modifications to the underlying map may or may not be observed by + /// the iterator, but will not invalidate it. + pub fn iterator(self: *Self) Iterator { + return .{ + .inner = self.bits.iterator(.{}), + .values = &self.values, + }; + } + + /// An entry in the map. + pub const Entry = struct { + /// The key associated with this entry. + /// Modifying this key will not change the map. + key: Key, + + /// A pointer to the value in the map associated + /// with this key. Modifications through this + /// pointer will modify the underlying data. + value: *Value, + }; + + pub const Iterator = struct { + inner: BitSet.Iterator(.{}), + values: *[Indexer.count]Value, + + pub fn next(self: *Iterator) ?Entry { + return if (self.inner.next()) |index| + Entry{ + .key = Indexer.keyForIndex(index), + .value = &self.values[index], + } + else null; + } + }; + }; +} + +/// A dense array of values, using an indexed lookup. +/// This type does no allocation and can be copied by value. +pub fn IndexedArray(comptime I: type, comptime V: type, comptime Ext: fn(type)type) type { + comptime ensureIndexer(I); + return struct { + const Self = @This(); + + pub usingnamespace Ext(Self); + + /// The index mapping for this map + pub const Indexer = I; + /// The key type used to index this map + pub const Key = Indexer.Key; + /// The value type stored in this map + pub const Value = V; + /// The number of possible keys in the map + pub const len = Indexer.count; + + values: [Indexer.count]Value, + + pub fn initUndefined() Self { + return Self{ .values = undefined }; + } + + pub fn initFill(v: Value) Self { + var self: Self = undefined; + std.mem.set(Value, &self.values, v); + return self; + } + + /// Returns the value in the array associated with a key. + pub fn get(self: Self, key: Key) Value { + return self.values[Indexer.indexOf(key)]; + } + + /// Returns a pointer to the slot in the array associated with a key. + pub fn getPtr(self: *Self, key: Key) *Value { + return &self.values[Indexer.indexOf(key)]; + } + + /// Returns a const pointer to the slot in the array associated with a key. + pub fn getPtrConst(self: *const Self, key: Key) *const Value { + return &self.values[Indexer.indexOf(key)]; + } + + /// Sets the value in the slot associated with a key. + pub fn set(self: *Self, key: Key, value: Value) void { + self.values[Indexer.indexOf(key)] = value; + } + + /// Iterates over the items in the array, in index order. + pub fn iterator(self: *Self) Iterator { + return .{ + .values = &self.values, + }; + } + + /// An entry in the array. + pub const Entry = struct { + /// The key associated with this entry. + /// Modifying this key will not change the array. + key: Key, + + /// A pointer to the value in the array associated + /// with this key. Modifications through this + /// pointer will modify the underlying data. + value: *Value, + }; + + pub const Iterator = struct { + index: usize = 0, + values: *[Indexer.count]Value, + + pub fn next(self: *Iterator) ?Entry { + const index = self.index; + if (index < Indexer.count) { + self.index += 1; + return Entry{ + .key = Indexer.keyForIndex(index), + .value = &self.values[index], + }; + } + return null; + } + }; + }; +} + +/// Verifies that a type is a valid Indexer, providing a helpful +/// compile error if not. An Indexer maps a comptime known set +/// of keys to a dense set of zero-based indices. +/// The indexer interface must look like this: +/// ``` +/// struct { +/// /// The key type which this indexer converts to indices +/// pub const Key: type, +/// /// The number of indexes in the dense mapping +/// pub const count: usize, +/// /// Converts from a key to an index +/// pub fn indexOf(Key) usize; +/// /// Converts from an index to a key +/// pub fn keyForIndex(usize) Key; +/// } +/// ``` +pub fn ensureIndexer(comptime T: type) void { + comptime { + if (!@hasDecl(T, "Key")) @compileError("Indexer must have decl Key: type."); + if (@TypeOf(T.Key) != type) @compileError("Indexer.Key must be a type."); + if (!@hasDecl(T, "count")) @compileError("Indexer must have decl count: usize."); + if (@TypeOf(T.count) != usize) @compileError("Indexer.count must be a usize."); + if (!@hasDecl(T, "indexOf")) @compileError("Indexer.indexOf must be a fn(Key)usize."); + if (@TypeOf(T.indexOf) != fn(T.Key)usize) @compileError("Indexer must have decl indexOf: fn(Key)usize."); + if (!@hasDecl(T, "keyForIndex")) @compileError("Indexer must have decl keyForIndex: fn(usize)Key."); + if (@TypeOf(T.keyForIndex) != fn(usize)T.Key) @compileError("Indexer.keyForIndex must be a fn(usize)Key."); + } +} + +test "std.enums.ensureIndexer" { + ensureIndexer(struct { + pub const Key = u32; + pub const count: usize = 8; + pub fn indexOf(k: Key) usize { + return @intCast(usize, k); + } + pub fn keyForIndex(index: usize) Key { + return @intCast(Key, index); + } + }); +} + +fn ascByValue(ctx: void, comptime a: EnumField, comptime b: EnumField) bool { + return a.value < b.value; +} +pub fn EnumIndexer(comptime E: type) type { + if (!@typeInfo(E).Enum.is_exhaustive) { + @compileError("Cannot create an enum indexer for a non-exhaustive enum."); + } + + const const_fields = uniqueFields(E); + var fields = const_fields[0..const_fields.len].*; + if (fields.len == 0) { + return struct { + pub const Key = E; + pub const count: usize = 0; + pub fn indexOf(e: E) usize { unreachable; } + pub fn keyForIndex(i: usize) E { unreachable; } + }; + } + std.sort.sort(EnumField, &fields, {}, ascByValue); + const min = fields[0].value; + const max = fields[fields.len-1].value; + if (max - min == fields.len-1) { + return struct { + pub const Key = E; + pub const count = fields.len; + pub fn indexOf(e: E) usize { + return @intCast(usize, @enumToInt(e) - min); + } + pub fn keyForIndex(i: usize) E { + // TODO fix addition semantics. This calculation + // gives up some safety to avoid artificially limiting + // the range of signed enum values to max_isize. + const enum_value = if (min < 0) @bitCast(isize, i) +% min else i + min; + return @intToEnum(E, @intCast(std.meta.Tag(E), enum_value)); + } + }; + } + + const keys = valuesFromFields(E, &fields); + + return struct { + pub const Key = E; + pub const count = fields.len; + pub fn indexOf(e: E) usize { + for (keys) |k, i| { + if (k == e) return i; + } + unreachable; + } + pub fn keyForIndex(i: usize) E { + return keys[i]; + } + }; +} + +test "std.enums.EnumIndexer dense zeroed" { + const E = enum{ b = 1, a = 0, c = 2 }; + const Indexer = EnumIndexer(E); + ensureIndexer(Indexer); + testing.expectEqual(E, Indexer.Key); + testing.expectEqual(@as(usize, 3), Indexer.count); + + testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a)); + testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b)); + testing.expectEqual(@as(usize, 2), Indexer.indexOf(.c)); + + testing.expectEqual(E.a, Indexer.keyForIndex(0)); + testing.expectEqual(E.b, Indexer.keyForIndex(1)); + testing.expectEqual(E.c, Indexer.keyForIndex(2)); +} + +test "std.enums.EnumIndexer dense positive" { + const E = enum(u4) { c = 6, a = 4, b = 5 }; + const Indexer = EnumIndexer(E); + ensureIndexer(Indexer); + testing.expectEqual(E, Indexer.Key); + testing.expectEqual(@as(usize, 3), Indexer.count); + + testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a)); + testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b)); + testing.expectEqual(@as(usize, 2), Indexer.indexOf(.c)); + + testing.expectEqual(E.a, Indexer.keyForIndex(0)); + testing.expectEqual(E.b, Indexer.keyForIndex(1)); + testing.expectEqual(E.c, Indexer.keyForIndex(2)); +} + +test "std.enums.EnumIndexer dense negative" { + const E = enum(i4) { a = -6, c = -4, b = -5 }; + const Indexer = EnumIndexer(E); + ensureIndexer(Indexer); + testing.expectEqual(E, Indexer.Key); + testing.expectEqual(@as(usize, 3), Indexer.count); + + testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a)); + testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b)); + testing.expectEqual(@as(usize, 2), Indexer.indexOf(.c)); + + testing.expectEqual(E.a, Indexer.keyForIndex(0)); + testing.expectEqual(E.b, Indexer.keyForIndex(1)); + testing.expectEqual(E.c, Indexer.keyForIndex(2)); +} + +test "std.enums.EnumIndexer sparse" { + const E = enum(i4) { a = -2, c = 6, b = 4 }; + const Indexer = EnumIndexer(E); + ensureIndexer(Indexer); + testing.expectEqual(E, Indexer.Key); + testing.expectEqual(@as(usize, 3), Indexer.count); + + testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a)); + testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b)); + testing.expectEqual(@as(usize, 2), Indexer.indexOf(.c)); + + testing.expectEqual(E.a, Indexer.keyForIndex(0)); + testing.expectEqual(E.b, Indexer.keyForIndex(1)); + testing.expectEqual(E.c, Indexer.keyForIndex(2)); +} + +test "std.enums.EnumIndexer repeats" { + const E = extern enum{ a = -2, c = 6, b = 4, b2 = 4 }; + const Indexer = EnumIndexer(E); + ensureIndexer(Indexer); + testing.expectEqual(E, Indexer.Key); + testing.expectEqual(@as(usize, 3), Indexer.count); + + testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a)); + testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b)); + testing.expectEqual(@as(usize, 2), Indexer.indexOf(.c)); + + testing.expectEqual(E.a, Indexer.keyForIndex(0)); + testing.expectEqual(E.b, Indexer.keyForIndex(1)); + testing.expectEqual(E.c, Indexer.keyForIndex(2)); +} + +test "std.enums.EnumSet" { + const E = extern enum { a, b, c, d, e = 0 }; + const Set = EnumSet(E); + testing.expectEqual(E, Set.Key); + testing.expectEqual(EnumIndexer(E), Set.Indexer); + testing.expectEqual(@as(usize, 4), Set.len); + + // Empty sets + const empty = Set{}; + comptime testing.expect(empty.count() == 0); + + var empty_b = Set.init(.{}); + testing.expect(empty_b.count() == 0); + + const empty_c = comptime Set.init(.{}); + comptime testing.expect(empty_c.count() == 0); + + const full = Set.initFull(); + testing.expect(full.count() == Set.len); + + const full_b = comptime Set.initFull(); + comptime testing.expect(full_b.count() == Set.len); + + testing.expectEqual(false, empty.contains(.a)); + testing.expectEqual(false, empty.contains(.b)); + testing.expectEqual(false, empty.contains(.c)); + testing.expectEqual(false, empty.contains(.d)); + testing.expectEqual(false, empty.contains(.e)); + { + var iter = empty_b.iterator(); + testing.expectEqual(@as(?E, null), iter.next()); + } + + var mut = Set.init(.{ + .a=true, .c=true, + }); + testing.expectEqual(@as(usize, 2), mut.count()); + testing.expectEqual(true, mut.contains(.a)); + testing.expectEqual(false, mut.contains(.b)); + testing.expectEqual(true, mut.contains(.c)); + testing.expectEqual(false, mut.contains(.d)); + testing.expectEqual(true, mut.contains(.e)); // aliases a + { + var it = mut.iterator(); + testing.expectEqual(@as(?E, .a), it.next()); + testing.expectEqual(@as(?E, .c), it.next()); + testing.expectEqual(@as(?E, null), it.next()); + } + + mut.toggleAll(); + testing.expectEqual(@as(usize, 2), mut.count()); + testing.expectEqual(false, mut.contains(.a)); + testing.expectEqual(true, mut.contains(.b)); + testing.expectEqual(false, mut.contains(.c)); + testing.expectEqual(true, mut.contains(.d)); + testing.expectEqual(false, mut.contains(.e)); // aliases a + { + var it = mut.iterator(); + testing.expectEqual(@as(?E, .b), it.next()); + testing.expectEqual(@as(?E, .d), it.next()); + testing.expectEqual(@as(?E, null), it.next()); + } + + mut.toggleSet(Set.init(.{ .a=true, .b=true })); + testing.expectEqual(@as(usize, 2), mut.count()); + testing.expectEqual(true, mut.contains(.a)); + testing.expectEqual(false, mut.contains(.b)); + testing.expectEqual(false, mut.contains(.c)); + testing.expectEqual(true, mut.contains(.d)); + testing.expectEqual(true, mut.contains(.e)); // aliases a + + mut.setUnion(Set.init(.{ .a=true, .b=true })); + testing.expectEqual(@as(usize, 3), mut.count()); + testing.expectEqual(true, mut.contains(.a)); + testing.expectEqual(true, mut.contains(.b)); + testing.expectEqual(false, mut.contains(.c)); + testing.expectEqual(true, mut.contains(.d)); + + mut.remove(.c); + mut.remove(.b); + testing.expectEqual(@as(usize, 2), mut.count()); + testing.expectEqual(true, mut.contains(.a)); + testing.expectEqual(false, mut.contains(.b)); + testing.expectEqual(false, mut.contains(.c)); + testing.expectEqual(true, mut.contains(.d)); + + mut.setIntersection(Set.init(.{ .a=true, .b=true })); + testing.expectEqual(@as(usize, 1), mut.count()); + testing.expectEqual(true, mut.contains(.a)); + testing.expectEqual(false, mut.contains(.b)); + testing.expectEqual(false, mut.contains(.c)); + testing.expectEqual(false, mut.contains(.d)); + + mut.insert(.a); + mut.insert(.b); + testing.expectEqual(@as(usize, 2), mut.count()); + testing.expectEqual(true, mut.contains(.a)); + testing.expectEqual(true, mut.contains(.b)); + testing.expectEqual(false, mut.contains(.c)); + testing.expectEqual(false, mut.contains(.d)); + + mut.setPresent(.a, false); + mut.toggle(.b); + mut.toggle(.c); + mut.setPresent(.d, true); + testing.expectEqual(@as(usize, 2), mut.count()); + testing.expectEqual(false, mut.contains(.a)); + testing.expectEqual(false, mut.contains(.b)); + testing.expectEqual(true, mut.contains(.c)); + testing.expectEqual(true, mut.contains(.d)); +} + +test "std.enums.EnumArray void" { + const E = extern enum { a, b, c, d, e = 0 }; + const ArrayVoid = EnumArray(E, void); + testing.expectEqual(E, ArrayVoid.Key); + testing.expectEqual(EnumIndexer(E), ArrayVoid.Indexer); + testing.expectEqual(void, ArrayVoid.Value); + testing.expectEqual(@as(usize, 4), ArrayVoid.len); + + const undef = ArrayVoid.initUndefined(); + var inst = ArrayVoid.initFill({}); + const inst2 = ArrayVoid.init(.{ .a = {}, .b = {}, .c = {}, .d = {} }); + const inst3 = ArrayVoid.initDefault({}, .{}); + + _ = inst.get(.a); + _ = inst.getPtr(.b); + _ = inst.getPtrConst(.c); + inst.set(.a, {}); + + var it = inst.iterator(); + testing.expectEqual(E.a, it.next().?.key); + testing.expectEqual(E.b, it.next().?.key); + testing.expectEqual(E.c, it.next().?.key); + testing.expectEqual(E.d, it.next().?.key); + testing.expect(it.next() == null); +} + +test "std.enums.EnumArray sized" { + const E = extern enum { a, b, c, d, e = 0 }; + const Array = EnumArray(E, usize); + testing.expectEqual(E, Array.Key); + testing.expectEqual(EnumIndexer(E), Array.Indexer); + testing.expectEqual(usize, Array.Value); + testing.expectEqual(@as(usize, 4), Array.len); + + const undef = Array.initUndefined(); + var inst = Array.initFill(5); + const inst2 = Array.init(.{ .a = 1, .b = 2, .c = 3, .d = 4 }); + const inst3 = Array.initDefault(6, .{.b = 4, .c = 2}); + + testing.expectEqual(@as(usize, 5), inst.get(.a)); + testing.expectEqual(@as(usize, 5), inst.get(.b)); + testing.expectEqual(@as(usize, 5), inst.get(.c)); + testing.expectEqual(@as(usize, 5), inst.get(.d)); + + testing.expectEqual(@as(usize, 1), inst2.get(.a)); + testing.expectEqual(@as(usize, 2), inst2.get(.b)); + testing.expectEqual(@as(usize, 3), inst2.get(.c)); + testing.expectEqual(@as(usize, 4), inst2.get(.d)); + + testing.expectEqual(@as(usize, 6), inst3.get(.a)); + testing.expectEqual(@as(usize, 4), inst3.get(.b)); + testing.expectEqual(@as(usize, 2), inst3.get(.c)); + testing.expectEqual(@as(usize, 6), inst3.get(.d)); + + testing.expectEqual(&inst.values[0], inst.getPtr(.a)); + testing.expectEqual(&inst.values[1], inst.getPtr(.b)); + testing.expectEqual(&inst.values[2], inst.getPtr(.c)); + testing.expectEqual(&inst.values[3], inst.getPtr(.d)); + + testing.expectEqual(@as(*const usize, &inst.values[0]), inst.getPtrConst(.a)); + testing.expectEqual(@as(*const usize, &inst.values[1]), inst.getPtrConst(.b)); + testing.expectEqual(@as(*const usize, &inst.values[2]), inst.getPtrConst(.c)); + testing.expectEqual(@as(*const usize, &inst.values[3]), inst.getPtrConst(.d)); + + inst.set(.c, 8); + testing.expectEqual(@as(usize, 5), inst.get(.a)); + testing.expectEqual(@as(usize, 5), inst.get(.b)); + testing.expectEqual(@as(usize, 8), inst.get(.c)); + testing.expectEqual(@as(usize, 5), inst.get(.d)); + + var it = inst.iterator(); + const Entry = Array.Entry; + testing.expectEqual(@as(?Entry, Entry{ + .key = .a, + .value = &inst.values[0], + }), it.next()); + testing.expectEqual(@as(?Entry, Entry{ + .key = .b, + .value = &inst.values[1], + }), it.next()); + testing.expectEqual(@as(?Entry, Entry{ + .key = .c, + .value = &inst.values[2], + }), it.next()); + testing.expectEqual(@as(?Entry, Entry{ + .key = .d, + .value = &inst.values[3], + }), it.next()); + testing.expectEqual(@as(?Entry, null), it.next()); +} + +test "std.enums.EnumMap void" { + const E = extern enum { a, b, c, d, e = 0 }; + const Map = EnumMap(E, void); + testing.expectEqual(E, Map.Key); + testing.expectEqual(EnumIndexer(E), Map.Indexer); + testing.expectEqual(void, Map.Value); + testing.expectEqual(@as(usize, 4), Map.len); + + const b = Map.initFull({}); + testing.expectEqual(@as(usize, 4), b.count()); + + const c = Map.initFullWith(.{ .a = {}, .b = {}, .c = {}, .d = {} }); + testing.expectEqual(@as(usize, 4), c.count()); + + const d = Map.initFullWithDefault({}, .{ .b = {} }); + testing.expectEqual(@as(usize, 4), d.count()); + + var a = Map.init(.{ .b = {}, .d = {} }); + testing.expectEqual(@as(usize, 2), a.count()); + testing.expectEqual(false, a.contains(.a)); + testing.expectEqual(true, a.contains(.b)); + testing.expectEqual(false, a.contains(.c)); + testing.expectEqual(true, a.contains(.d)); + testing.expect(a.get(.a) == null); + testing.expect(a.get(.b) != null); + testing.expect(a.get(.c) == null); + testing.expect(a.get(.d) != null); + testing.expect(a.getPtr(.a) == null); + testing.expect(a.getPtr(.b) != null); + testing.expect(a.getPtr(.c) == null); + testing.expect(a.getPtr(.d) != null); + testing.expect(a.getPtrConst(.a) == null); + testing.expect(a.getPtrConst(.b) != null); + testing.expect(a.getPtrConst(.c) == null); + testing.expect(a.getPtrConst(.d) != null); + _ = a.getPtrAssertContains(.b); + _ = a.getAssertContains(.d); + + a.put(.a, {}); + a.put(.a, {}); + a.putUninitialized(.c).* = {}; + a.putUninitialized(.c).* = {}; + + testing.expectEqual(@as(usize, 4), a.count()); + testing.expect(a.get(.a) != null); + testing.expect(a.get(.b) != null); + testing.expect(a.get(.c) != null); + testing.expect(a.get(.d) != null); + + a.remove(.a); + _ = a.fetchRemove(.c); + + var iter = a.iterator(); + const Entry = Map.Entry; + testing.expectEqual(E.b, iter.next().?.key); + testing.expectEqual(E.d, iter.next().?.key); + testing.expect(iter.next() == null); +} + +test "std.enums.EnumMap sized" { + const E = extern enum { a, b, c, d, e = 0 }; + const Map = EnumMap(E, usize); + testing.expectEqual(E, Map.Key); + testing.expectEqual(EnumIndexer(E), Map.Indexer); + testing.expectEqual(usize, Map.Value); + testing.expectEqual(@as(usize, 4), Map.len); + + const b = Map.initFull(5); + testing.expectEqual(@as(usize, 4), b.count()); + testing.expect(b.contains(.a)); + testing.expect(b.contains(.b)); + testing.expect(b.contains(.c)); + testing.expect(b.contains(.d)); + testing.expectEqual(@as(?usize, 5), b.get(.a)); + testing.expectEqual(@as(?usize, 5), b.get(.b)); + testing.expectEqual(@as(?usize, 5), b.get(.c)); + testing.expectEqual(@as(?usize, 5), b.get(.d)); + + const c = Map.initFullWith(.{ .a = 1, .b = 2, .c = 3, .d = 4 }); + testing.expectEqual(@as(usize, 4), c.count()); + testing.expect(c.contains(.a)); + testing.expect(c.contains(.b)); + testing.expect(c.contains(.c)); + testing.expect(c.contains(.d)); + testing.expectEqual(@as(?usize, 1), c.get(.a)); + testing.expectEqual(@as(?usize, 2), c.get(.b)); + testing.expectEqual(@as(?usize, 3), c.get(.c)); + testing.expectEqual(@as(?usize, 4), c.get(.d)); + + const d = Map.initFullWithDefault(6, .{ .b = 2, .c = 4 }); + testing.expectEqual(@as(usize, 4), d.count()); + testing.expect(d.contains(.a)); + testing.expect(d.contains(.b)); + testing.expect(d.contains(.c)); + testing.expect(d.contains(.d)); + testing.expectEqual(@as(?usize, 6), d.get(.a)); + testing.expectEqual(@as(?usize, 2), d.get(.b)); + testing.expectEqual(@as(?usize, 4), d.get(.c)); + testing.expectEqual(@as(?usize, 6), d.get(.d)); + + var a = Map.init(.{ .b = 2, .d = 4 }); + testing.expectEqual(@as(usize, 2), a.count()); + testing.expectEqual(false, a.contains(.a)); + testing.expectEqual(true, a.contains(.b)); + testing.expectEqual(false, a.contains(.c)); + testing.expectEqual(true, a.contains(.d)); + + testing.expectEqual(@as(?usize, null), a.get(.a)); + testing.expectEqual(@as(?usize, 2), a.get(.b)); + testing.expectEqual(@as(?usize, null), a.get(.c)); + testing.expectEqual(@as(?usize, 4), a.get(.d)); + + testing.expectEqual(@as(?*usize, null), a.getPtr(.a)); + testing.expectEqual(@as(?*usize, &a.values[1]), a.getPtr(.b)); + testing.expectEqual(@as(?*usize, null), a.getPtr(.c)); + testing.expectEqual(@as(?*usize, &a.values[3]), a.getPtr(.d)); + + testing.expectEqual(@as(?*const usize, null), a.getPtrConst(.a)); + testing.expectEqual(@as(?*const usize, &a.values[1]), a.getPtrConst(.b)); + testing.expectEqual(@as(?*const usize, null), a.getPtrConst(.c)); + testing.expectEqual(@as(?*const usize, &a.values[3]), a.getPtrConst(.d)); + + testing.expectEqual(@as(*const usize, &a.values[1]), a.getPtrAssertContains(.b)); + testing.expectEqual(@as(*const usize, &a.values[3]), a.getPtrAssertContains(.d)); + testing.expectEqual(@as(usize, 2), a.getAssertContains(.b)); + testing.expectEqual(@as(usize, 4), a.getAssertContains(.d)); + + a.put(.a, 3); + a.put(.a, 5); + a.putUninitialized(.c).* = 7; + a.putUninitialized(.c).* = 9; + + testing.expectEqual(@as(usize, 4), a.count()); + testing.expectEqual(@as(?usize, 5), a.get(.a)); + testing.expectEqual(@as(?usize, 2), a.get(.b)); + testing.expectEqual(@as(?usize, 9), a.get(.c)); + testing.expectEqual(@as(?usize, 4), a.get(.d)); + + a.remove(.a); + testing.expectEqual(@as(?usize, null), a.fetchRemove(.a)); + testing.expectEqual(@as(?usize, 9), a.fetchRemove(.c)); + a.remove(.c); + + var iter = a.iterator(); + const Entry = Map.Entry; + testing.expectEqual(@as(?Entry, Entry{ + .key = .b, .value = &a.values[1], + }), iter.next()); + testing.expectEqual(@as(?Entry, Entry{ + .key = .d, .value = &a.values[3], + }), iter.next()); + testing.expectEqual(@as(?Entry, null), iter.next()); +} diff --git a/lib/std/fmt.zig b/lib/std/fmt.zig index 90c0d98539..bfe28ef203 100644 --- a/lib/std/fmt.zig +++ b/lib/std/fmt.zig @@ -1250,9 +1250,9 @@ fn formatDuration(ns: u64, comptime fmt: []const u8, options: std.fmt.FormatOpti const kunits = ns_remaining * 1000 / unit.ns; if (kunits >= 1000) { try formatInt(kunits / 1000, 10, false, .{}, writer); - if (kunits > 1000) { + const frac = kunits % 1000; + if (frac > 0) { // Write up to 3 decimal places - const frac = kunits % 1000; var buf = [_]u8{ '.', 0, 0, 0 }; _ = formatIntBuf(buf[1..], frac, 10, false, .{ .fill = '0', .width = 3 }); var end: usize = 4; @@ -1286,9 +1286,14 @@ test "fmtDuration" { .{ .s = "1us", .d = std.time.ns_per_us }, .{ .s = "1.45us", .d = 1450 }, .{ .s = "1.5us", .d = 3 * std.time.ns_per_us / 2 }, + .{ .s = "14.5us", .d = 14500 }, + .{ .s = "145us", .d = 145000 }, .{ .s = "999.999us", .d = std.time.ns_per_ms - 1 }, .{ .s = "1ms", .d = std.time.ns_per_ms + 1 }, .{ .s = "1.5ms", .d = 3 * std.time.ns_per_ms / 2 }, + .{ .s = "1.11ms", .d = 1110000 }, + .{ .s = "1.111ms", .d = 1111000 }, + .{ .s = "1.111ms", .d = 1111100 }, .{ .s = "999.999ms", .d = std.time.ns_per_s - 1 }, .{ .s = "1s", .d = std.time.ns_per_s }, .{ .s = "59.999s", .d = std.time.ns_per_min - 1 }, diff --git a/lib/std/fs.zig b/lib/std/fs.zig index 79385708af..1a02cd5b6b 100644 --- a/lib/std/fs.zig +++ b/lib/std/fs.zig @@ -50,13 +50,13 @@ pub const MAX_PATH_BYTES = switch (builtin.os.tag) { else => @compileError("Unsupported OS"), }; -pub const base64_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; +pub const base64_alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_".*; /// Base64 encoder, replacing the standard `+/` with `-_` so that it can be used in a file name on any filesystem. -pub const base64_encoder = base64.Base64Encoder.init(base64_alphabet, base64.standard_pad_char); +pub const base64_encoder = base64.Base64Encoder.init(base64_alphabet, null); /// Base64 decoder, replacing the standard `+/` with `-_` so that it can be used in a file name on any filesystem. -pub const base64_decoder = base64.Base64Decoder.init(base64_alphabet, base64.standard_pad_char); +pub const base64_decoder = base64.Base64Decoder.init(base64_alphabet, null); /// Whether or not async file system syscalls need a dedicated thread because the operating /// system does not support non-blocking I/O on the file system. @@ -77,7 +77,7 @@ pub fn atomicSymLink(allocator: *Allocator, existing_path: []const u8, new_path: const dirname = path.dirname(new_path) orelse "."; var rand_buf: [AtomicFile.RANDOM_BYTES]u8 = undefined; - const tmp_path = try allocator.alloc(u8, dirname.len + 1 + base64.Base64Encoder.calcSize(rand_buf.len)); + const tmp_path = try allocator.alloc(u8, dirname.len + 1 + base64_encoder.calcSize(rand_buf.len)); defer allocator.free(tmp_path); mem.copy(u8, tmp_path[0..], dirname); tmp_path[dirname.len] = path.sep; @@ -142,7 +142,7 @@ pub const AtomicFile = struct { const InitError = File.OpenError; const RANDOM_BYTES = 12; - const TMP_PATH_LEN = base64.Base64Encoder.calcSize(RANDOM_BYTES); + const TMP_PATH_LEN = base64_encoder.calcSize(RANDOM_BYTES); /// Note that the `Dir.atomicFile` API may be more handy than this lower-level function. pub fn init( diff --git a/lib/std/fs/path.zig b/lib/std/fs/path.zig index 776cb4040c..0bba522fb6 100644 --- a/lib/std/fs/path.zig +++ b/lib/std/fs/path.zig @@ -92,7 +92,7 @@ pub fn join(allocator: *Allocator, paths: []const []const u8) ![]u8 { /// Naively combines a series of paths with the native path seperator and null terminator. /// Allocates memory for the result, which must be freed by the caller. pub fn joinZ(allocator: *Allocator, paths: []const []const u8) ![:0]u8 { - const out = joinSepMaybeZ(allocator, sep, isSep, paths, true); + const out = try joinSepMaybeZ(allocator, sep, isSep, paths, true); return out[0 .. out.len - 1 :0]; } @@ -119,6 +119,16 @@ fn testJoinMaybeZPosix(paths: []const []const u8, expected: []const u8, zero: bo } test "join" { + { + const actual: []u8 = try join(testing.allocator, &[_][]const u8{}); + defer testing.allocator.free(actual); + testing.expectEqualSlices(u8, "", actual); + } + { + const actual: [:0]u8 = try joinZ(testing.allocator, &[_][]const u8{}); + defer testing.allocator.free(actual); + testing.expectEqualSlices(u8, "", actual); + } for (&[_]bool{ false, true }) |zero| { testJoinMaybeZWindows(&[_][]const u8{}, "", zero); testJoinMaybeZWindows(&[_][]const u8{ "c:\\a\\b", "c" }, "c:\\a\\b\\c", zero); diff --git a/lib/std/hash/auto_hash.zig b/lib/std/hash/auto_hash.zig index 4afc2b425b..e053e87efb 100644 --- a/lib/std/hash/auto_hash.zig +++ b/lib/std/hash/auto_hash.zig @@ -95,7 +95,7 @@ pub fn hash(hasher: anytype, key: anytype, comptime strat: HashStrategy) void { .EnumLiteral, .Frame, .Float, - => @compileError("cannot hash this type"), + => @compileError("unable to hash type " ++ @typeName(Key)), // Help the optimizer see that hashing an int is easy by inlining! // TODO Check if the situation is better after #561 is resolved. diff --git a/lib/std/macho.zig b/lib/std/macho.zig index 6785abffca..f66626bafe 100644 --- a/lib/std/macho.zig +++ b/lib/std/macho.zig @@ -1227,6 +1227,24 @@ pub const S_ATTR_EXT_RELOC = 0x200; /// section has local relocation entries pub const S_ATTR_LOC_RELOC = 0x100; +/// template of initial values for TLVs +pub const S_THREAD_LOCAL_REGULAR = 0x11; + +/// template of initial values for TLVs +pub const S_THREAD_LOCAL_ZEROFILL = 0x12; + +/// TLV descriptors +pub const S_THREAD_LOCAL_VARIABLES = 0x13; + +/// pointers to TLV descriptors +pub const S_THREAD_LOCAL_VARIABLE_POINTERS = 0x14; + +/// functions to call to initialize TLV values +pub const S_THREAD_LOCAL_INIT_FUNCTION_POINTERS = 0x15; + +/// 32-bit offsets to initializers +pub const S_INIT_FUNC_OFFSETS = 0x16; + pub const cpu_type_t = integer_t; pub const cpu_subtype_t = integer_t; pub const integer_t = c_int; @@ -1422,6 +1440,14 @@ pub const EXPORT_SYMBOL_FLAGS_KIND_WEAK_DEFINITION: u8 = 0x04; pub const EXPORT_SYMBOL_FLAGS_REEXPORT: u8 = 0x08; pub const EXPORT_SYMBOL_FLAGS_STUB_AND_RESOLVER: u8 = 0x10; +// An indirect symbol table entry is simply a 32bit index into the symbol table +// to the symbol that the pointer or stub is refering to. Unless it is for a +// non-lazy symbol pointer section for a defined symbol which strip(1) as +// removed. In which case it has the value INDIRECT_SYMBOL_LOCAL. If the +// symbol was also absolute INDIRECT_SYMBOL_ABS is or'ed with that. +pub const INDIRECT_SYMBOL_LOCAL: u32 = 0x80000000; +pub const INDIRECT_SYMBOL_ABS: u32 = 0x40000000; + // Codesign consts and structs taken from: // https://opensource.apple.com/source/xnu/xnu-6153.81.5/osfmk/kern/cs_blobs.h.auto.html @@ -1589,3 +1615,17 @@ pub const GenericBlob = extern struct { /// Total length of blob length: u32, }; + +/// The LC_DATA_IN_CODE load commands uses a linkedit_data_command +/// to point to an array of data_in_code_entry entries. Each entry +/// describes a range of data in a code section. +pub const data_in_code_entry = extern struct { + /// From mach_header to start of data range. + offset: u32, + + /// Number of bytes in data range. + length: u16, + + /// A DICE_KIND value. + kind: u16, +}; diff --git a/lib/std/mem.zig b/lib/std/mem.zig index 2bd5fdac7b..66505f5d29 100644 --- a/lib/std/mem.zig +++ b/lib/std/mem.zig @@ -1373,6 +1373,20 @@ test "mem.tokenize (multibyte)" { testing.expect(it.next() == null); } +test "mem.tokenize (reset)" { + var it = tokenize(" abc def ghi ", " "); + testing.expect(eql(u8, it.next().?, "abc")); + testing.expect(eql(u8, it.next().?, "def")); + testing.expect(eql(u8, it.next().?, "ghi")); + + it.reset(); + + testing.expect(eql(u8, it.next().?, "abc")); + testing.expect(eql(u8, it.next().?, "def")); + testing.expect(eql(u8, it.next().?, "ghi")); + testing.expect(it.next() == null); +} + /// Returns an iterator that iterates over the slices of `buffer` that /// are separated by bytes in `delimiter`. /// split("abc|def||ghi", "|") @@ -1471,6 +1485,11 @@ pub const TokenIterator = struct { return self.buffer[index..]; } + /// Resets the iterator to the initial token. + pub fn reset(self: *TokenIterator) void { + self.index = 0; + } + fn isSplitByte(self: TokenIterator, byte: u8) bool { for (self.delimiter_bytes) |delimiter_byte| { if (byte == delimiter_byte) { diff --git a/lib/std/meta.zig b/lib/std/meta.zig index fd3e03bdbd..cdc93e5d33 100644 --- a/lib/std/meta.zig +++ b/lib/std/meta.zig @@ -888,19 +888,20 @@ pub fn Vector(comptime len: u32, comptime child: type) type { /// Given a type and value, cast the value to the type as c would. /// This is for translate-c and is not intended for general use. pub fn cast(comptime DestType: type, target: anytype) DestType { - const TargetType = @TypeOf(target); + // this function should behave like transCCast in translate-c, except it's for macros + const SourceType = @TypeOf(target); switch (@typeInfo(DestType)) { - .Pointer => |dest_ptr| { - switch (@typeInfo(TargetType)) { + .Pointer => { + switch (@typeInfo(SourceType)) { .Int, .ComptimeInt => { return @intToPtr(DestType, target); }, - .Pointer => |ptr| { - return @ptrCast(DestType, @alignCast(dest_ptr.alignment, target)); + .Pointer => { + return castPtr(DestType, target); }, .Optional => |opt| { if (@typeInfo(opt.child) == .Pointer) { - return @ptrCast(DestType, @alignCast(dest_ptr.alignment, target)); + return castPtr(DestType, target); } }, else => {}, @@ -908,17 +909,16 @@ pub fn cast(comptime DestType: type, target: anytype) DestType { }, .Optional => |dest_opt| { if (@typeInfo(dest_opt.child) == .Pointer) { - const dest_ptr = @typeInfo(dest_opt.child).Pointer; - switch (@typeInfo(TargetType)) { + switch (@typeInfo(SourceType)) { .Int, .ComptimeInt => { return @intToPtr(DestType, target); }, .Pointer => { - return @ptrCast(DestType, @alignCast(dest_ptr.alignment, target)); + return castPtr(DestType, target); }, .Optional => |target_opt| { if (@typeInfo(target_opt.child) == .Pointer) { - return @ptrCast(DestType, @alignCast(dest_ptr.alignment, target)); + return castPtr(DestType, target); } }, else => {}, @@ -926,25 +926,25 @@ pub fn cast(comptime DestType: type, target: anytype) DestType { } }, .Enum => { - if (@typeInfo(TargetType) == .Int or @typeInfo(TargetType) == .ComptimeInt) { + if (@typeInfo(SourceType) == .Int or @typeInfo(SourceType) == .ComptimeInt) { return @intToEnum(DestType, target); } }, - .Int, .ComptimeInt => { - switch (@typeInfo(TargetType)) { + .Int => { + switch (@typeInfo(SourceType)) { .Pointer => { - return @intCast(DestType, @ptrToInt(target)); + return castInt(DestType, @ptrToInt(target)); }, .Optional => |opt| { if (@typeInfo(opt.child) == .Pointer) { - return @intCast(DestType, @ptrToInt(target)); + return castInt(DestType, @ptrToInt(target)); } }, .Enum => { - return @intCast(DestType, @enumToInt(target)); + return castInt(DestType, @enumToInt(target)); }, - .Int, .ComptimeInt => { - return @intCast(DestType, target); + .Int => { + return castInt(DestType, target); }, else => {}, } @@ -954,6 +954,34 @@ pub fn cast(comptime DestType: type, target: anytype) DestType { return @as(DestType, target); } +fn castInt(comptime DestType: type, target: anytype) DestType { + const dest = @typeInfo(DestType).Int; + const source = @typeInfo(@TypeOf(target)).Int; + + if (dest.bits < source.bits) + return @bitCast(DestType, @truncate(Int(source.signedness, dest.bits), target)) + else + return @bitCast(DestType, @as(Int(source.signedness, dest.bits), target)); +} + +fn castPtr(comptime DestType: type, target: anytype) DestType { + const dest = ptrInfo(DestType); + const source = ptrInfo(@TypeOf(target)); + + if (source.is_const and !dest.is_const or source.is_volatile and !dest.is_volatile) + return @intToPtr(DestType, @ptrToInt(target)) + else + return @ptrCast(DestType, @alignCast(dest.alignment, target)); +} + +fn ptrInfo(comptime PtrType: type) TypeInfo.Pointer { + return switch(@typeInfo(PtrType)){ + .Optional => |opt_info| @typeInfo(opt_info.child).Pointer, + .Pointer => |ptr_info| ptr_info, + else => unreachable, + }; +} + test "std.meta.cast" { const E = enum(u2) { Zero, @@ -977,6 +1005,11 @@ test "std.meta.cast" { testing.expectEqual(@as(u32, 4), cast(u32, @intToPtr(?*u32, 4))); testing.expectEqual(@as(u32, 10), cast(u32, @as(u64, 10))); testing.expectEqual(@as(u8, 2), cast(u8, E.Two)); + + testing.expectEqual(@bitCast(i32, @as(u32, 0x8000_0000)), cast(i32, @as(u32, 0x8000_0000))); + + testing.expectEqual(@intToPtr(*u8, 2), cast(*u8, @intToPtr(*const u8, 2))); + testing.expectEqual(@intToPtr(*u8, 2), cast(*u8, @intToPtr(*volatile u8, 2))); } /// Given a value returns its size as C's sizeof operator would. diff --git a/lib/std/meta/trait.zig b/lib/std/meta/trait.zig index e67f9b9bc4..481bfe212b 100644 --- a/lib/std/meta/trait.zig +++ b/lib/std/meta/trait.zig @@ -408,6 +408,84 @@ test "std.meta.trait.isTuple" { testing.expect(isTuple(@TypeOf(t3))); } +/// Returns true if the passed type will coerce to []const u8. +/// Any of the following are considered strings: +/// ``` +/// []const u8, [:S]const u8, *const [N]u8, *const [N:S]u8, +/// []u8, [:S]u8, *[:S]u8, *[N:S]u8. +/// ``` +/// These types are not considered strings: +/// ``` +/// u8, [N]u8, [*]const u8, [*:0]const u8, +/// [*]const [N]u8, []const u16, []const i8, +/// *const u8, ?[]const u8, ?*const [N]u8. +/// ``` +pub fn isZigString(comptime T: type) bool { + comptime { + // Only pointer types can be strings, no optionals + const info = @typeInfo(T); + if (info != .Pointer) return false; + + const ptr = &info.Pointer; + // Check for CV qualifiers that would prevent coerction to []const u8 + if (ptr.is_volatile or ptr.is_allowzero) return false; + + // If it's already a slice, simple check. + if (ptr.size == .Slice) { + return ptr.child == u8; + } + + // Otherwise check if it's an array type that coerces to slice. + if (ptr.size == .One) { + const child = @typeInfo(ptr.child); + if (child == .Array) { + const arr = &child.Array; + return arr.child == u8; + } + } + + return false; + } +} + +test "std.meta.trait.isZigString" { + testing.expect(isZigString([]const u8)); + testing.expect(isZigString([]u8)); + testing.expect(isZigString([:0]const u8)); + testing.expect(isZigString([:0]u8)); + testing.expect(isZigString([:5]const u8)); + testing.expect(isZigString([:5]u8)); + testing.expect(isZigString(*const [0]u8)); + testing.expect(isZigString(*[0]u8)); + testing.expect(isZigString(*const [0:0]u8)); + testing.expect(isZigString(*[0:0]u8)); + testing.expect(isZigString(*const [0:5]u8)); + testing.expect(isZigString(*[0:5]u8)); + testing.expect(isZigString(*const [10]u8)); + testing.expect(isZigString(*[10]u8)); + testing.expect(isZigString(*const [10:0]u8)); + testing.expect(isZigString(*[10:0]u8)); + testing.expect(isZigString(*const [10:5]u8)); + testing.expect(isZigString(*[10:5]u8)); + + testing.expect(!isZigString(u8)); + testing.expect(!isZigString([4]u8)); + testing.expect(!isZigString([4:0]u8)); + testing.expect(!isZigString([*]const u8)); + testing.expect(!isZigString([*]const [4]u8)); + testing.expect(!isZigString([*c]const u8)); + testing.expect(!isZigString([*c]const [4]u8)); + testing.expect(!isZigString([*:0]const u8)); + testing.expect(!isZigString([*:0]const u8)); + testing.expect(!isZigString(*[]const u8)); + testing.expect(!isZigString(?[]const u8)); + testing.expect(!isZigString(?*const [4]u8)); + testing.expect(!isZigString([]allowzero u8)); + testing.expect(!isZigString([]volatile u8)); + testing.expect(!isZigString(*allowzero [4]u8)); + testing.expect(!isZigString(*volatile [4]u8)); +} + pub fn hasDecls(comptime T: type, comptime names: anytype) bool { inline for (names) |name| { if (!@hasDecl(T, name)) diff --git a/lib/std/os.zig b/lib/std/os.zig index 362a58f7fb..9d9fd872a8 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -2879,7 +2879,7 @@ pub fn bind(sock: socket_t, addr: *const sockaddr, len: socklen_t) BindError!voi unreachable; } -const ListenError = error{ +pub const ListenError = error{ /// Another socket is already listening on the same port. /// For Internet domain sockets, the socket referred to by sockfd had not previously /// been bound to an address and, upon attempting to bind it to an ephemeral port, it @@ -5610,6 +5610,7 @@ pub fn recvfrom( EAGAIN => return error.WouldBlock, ENOMEM => return error.SystemResources, ECONNREFUSED => return error.ConnectionRefused, + ECONNRESET => return error.ConnectionResetByPeer, else => |err| return unexpectedErrno(err), } } @@ -5827,7 +5828,7 @@ pub fn tcsetattr(handle: fd_t, optional_action: TCSA, termios_p: termios) Termio } } -const IoCtl_SIOCGIFINDEX_Error = error{ +pub const IoCtl_SIOCGIFINDEX_Error = error{ FileSystem, InterfaceNotFound, } || UnexpectedError; diff --git a/lib/std/os/linux/io_uring.zig b/lib/std/os/linux/io_uring.zig index e900bdcd6a..4342beca00 100644 --- a/lib/std/os/linux/io_uring.zig +++ b/lib/std/os/linux/io_uring.zig @@ -1353,7 +1353,7 @@ test "timeout (after a relative time)" { .res = -linux.ETIME, .flags = 0, }, cqe); - testing.expectWithinMargin(@intToFloat(f64, ms), @intToFloat(f64, stopped - started), margin); + testing.expectApproxEqAbs(@intToFloat(f64, ms), @intToFloat(f64, stopped - started), margin); } test "timeout (after a number of completions)" { diff --git a/lib/std/os/linux/mips.zig b/lib/std/os/linux/mips.zig index 2622628533..ddb3103cfa 100644 --- a/lib/std/os/linux/mips.zig +++ b/lib/std/os/linux/mips.zig @@ -115,6 +115,9 @@ pub fn syscall5(number: SYS, arg1: usize, arg2: usize, arg3: usize, arg4: usize, ); } +// NOTE: The o32 calling convention requires the callee to reserve 16 bytes for +// the first four arguments even though they're passed in $a0-$a3. + pub fn syscall6( number: SYS, arg1: usize, @@ -146,6 +149,40 @@ pub fn syscall6( ); } +pub fn syscall7( + number: SYS, + arg1: usize, + arg2: usize, + arg3: usize, + arg4: usize, + arg5: usize, + arg6: usize, + arg7: usize, +) usize { + return asm volatile ( + \\ .set noat + \\ subu $sp, $sp, 32 + \\ sw %[arg5], 16($sp) + \\ sw %[arg6], 20($sp) + \\ sw %[arg7], 24($sp) + \\ syscall + \\ addu $sp, $sp, 32 + \\ blez $7, 1f + \\ subu $2, $0, $2 + \\ 1: + : [ret] "={$2}" (-> usize) + : [number] "{$2}" (@enumToInt(number)), + [arg1] "{$4}" (arg1), + [arg2] "{$5}" (arg2), + [arg3] "{$6}" (arg3), + [arg4] "{$7}" (arg4), + [arg5] "r" (arg5), + [arg6] "r" (arg6), + [arg7] "r" (arg7) + : "memory", "cc", "$7" + ); +} + /// This matches the libc clone function. pub extern fn clone(func: fn (arg: usize) callconv(.C) u8, stack: usize, flags: u32, arg: usize, ptid: *i32, tls: usize, ctid: *i32) usize; diff --git a/lib/std/os/uefi/tables/boot_services.zig b/lib/std/os/uefi/tables/boot_services.zig index b96881fcc2..2b3e896960 100644 --- a/lib/std/os/uefi/tables/boot_services.zig +++ b/lib/std/os/uefi/tables/boot_services.zig @@ -78,7 +78,8 @@ pub const BootServices = extern struct { /// Returns an array of handles that support a specified protocol. locateHandle: fn (LocateSearchType, ?*align(8) const Guid, ?*const c_void, *usize, [*]Handle) callconv(.C) Status, - locateDevicePath: Status, // TODO + /// Locates the handle to a device on the device path that supports the specified protocol + locateDevicePath: fn (*align(8) const Guid, **const DevicePathProtocol, *?Handle) callconv(.C) Status, installConfigurationTable: Status, // TODO /// Loads an EFI image into memory. diff --git a/lib/std/os/windows/user32.zig b/lib/std/os/windows/user32.zig index 186a1af59f..9a058f35c0 100644 --- a/lib/std/os/windows/user32.zig +++ b/lib/std/os/windows/user32.zig @@ -373,7 +373,7 @@ pub fn createWindowExA(dwExStyle: u32, lpClassName: [*:0]const u8, lpWindowName: } pub extern "user32" fn CreateWindowExW(dwExStyle: DWORD, lpClassName: [*:0]const u16, lpWindowName: [*:0]const u16, dwStyle: DWORD, X: i32, Y: i32, nWidth: i32, nHeight: i32, hWindParent: ?HWND, hMenu: ?HMENU, hInstance: HINSTANCE, lpParam: ?LPVOID) callconv(WINAPI) ?HWND; -pub var pfnCreateWindowExW: @TypeOf(RegisterClassExW) = undefined; +pub var pfnCreateWindowExW: @TypeOf(CreateWindowExW) = undefined; pub fn createWindowExW(dwExStyle: u32, lpClassName: [*:0]const u16, lpWindowName: [*:0]const u16, dwStyle: u32, X: i32, Y: i32, nWidth: i32, nHeight: i32, hWindParent: ?HWND, hMenu: ?HMENU, hInstance: HINSTANCE, lpParam: ?*c_void) !HWND { const function = selectSymbol(CreateWindowExW, pfnCreateWindowExW, .win2k); const window = function(dwExStyle, lpClassName, lpWindowName, dwStyle, X, Y, nWidth, nHeight, hWindParent, hMenu, hInstance, lpParam); diff --git a/lib/std/special/build_runner.zig b/lib/std/special/build_runner.zig index 0b7baf0fc1..70aa3c8dc6 100644 --- a/lib/std/special/build_runner.zig +++ b/lib/std/special/build_runner.zig @@ -60,6 +60,7 @@ pub fn main() !void { const stderr_stream = io.getStdErr().writer(); const stdout_stream = io.getStdOut().writer(); + var install_prefix: ?[]const u8 = null; while (nextArg(args, &arg_idx)) |arg| { if (mem.startsWith(u8, arg, "-D")) { const option_contents = arg[2..]; @@ -82,7 +83,7 @@ pub fn main() !void { } else if (mem.eql(u8, arg, "-h") or mem.eql(u8, arg, "--help")) { return usage(builder, false, stdout_stream); } else if (mem.eql(u8, arg, "--prefix")) { - builder.install_prefix = nextArg(args, &arg_idx) orelse { + install_prefix = nextArg(args, &arg_idx) orelse { warn("Expected argument after --prefix\n\n", .{}); return usageAndErr(builder, false, stderr_stream); }; @@ -134,7 +135,7 @@ pub fn main() !void { } } - builder.resolveInstallPrefix(); + builder.resolveInstallPrefix(install_prefix); try runBuild(builder); if (builder.validateUserInputDidItFail()) @@ -162,8 +163,7 @@ fn runBuild(builder: *Builder) anyerror!void { fn usage(builder: *Builder, already_ran_build: bool, out_stream: anytype) !void { // run the build script to collect the options if (!already_ran_build) { - builder.setInstallPrefix(null); - builder.resolveInstallPrefix(); + builder.resolveInstallPrefix(null); try runBuild(builder); } diff --git a/lib/std/std.zig b/lib/std/std.zig index a7e5bcb682..82249af157 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -20,6 +20,9 @@ pub const ComptimeStringMap = @import("comptime_string_map.zig").ComptimeStringM pub const DynLib = @import("dynamic_library.zig").DynLib; pub const DynamicBitSet = bit_set.DynamicBitSet; pub const DynamicBitSetUnmanaged = bit_set.DynamicBitSetUnmanaged; +pub const EnumArray = enums.EnumArray; +pub const EnumMap = enums.EnumMap; +pub const EnumSet = enums.EnumSet; pub const HashMap = hash_map.HashMap; pub const HashMapUnmanaged = hash_map.HashMapUnmanaged; pub const MultiArrayList = @import("multi_array_list.zig").MultiArrayList; @@ -54,6 +57,7 @@ pub const cstr = @import("cstr.zig"); pub const debug = @import("debug.zig"); pub const dwarf = @import("dwarf.zig"); pub const elf = @import("elf.zig"); +pub const enums = @import("enums.zig"); pub const event = @import("event.zig"); pub const fifo = @import("fifo.zig"); pub const fmt = @import("fmt.zig"); diff --git a/lib/std/testing.zig b/lib/std/testing.zig index 1d89155a58..eb2b6e87b3 100644 --- a/lib/std/testing.zig +++ b/lib/std/testing.zig @@ -200,67 +200,69 @@ pub fn expectFmt(expected: []const u8, comptime template: []const u8, args: anyt return error.TestFailed; } -/// This function is intended to be used only in tests. When the actual value is not -/// within the margin of the expected value, -/// prints diagnostics to stderr to show exactly how they are not equal, then aborts. +pub const expectWithinMargin = @compileError("expectWithinMargin is deprecated, use expectApproxEqAbs or expectApproxEqRel"); +pub const expectWithinEpsilon = @compileError("expectWithinEpsilon is deprecated, use expectApproxEqAbs or expectApproxEqRel"); + +/// This function is intended to be used only in tests. When the actual value is +/// not approximately equal to the expected value, prints diagnostics to stderr +/// to show exactly how they are not equal, then aborts. +/// See `math.approxEqAbs` for more informations on the tolerance parameter. /// The types must be floating point -pub fn expectWithinMargin(expected: anytype, actual: @TypeOf(expected), margin: @TypeOf(expected)) void { - std.debug.assert(margin >= 0.0); +pub fn expectApproxEqAbs(expected: anytype, actual: @TypeOf(expected), tolerance: @TypeOf(expected)) void { + const T = @TypeOf(expected); + + switch (@typeInfo(T)) { + .Float => if (!math.approxEqAbs(T, expected, actual, tolerance)) + std.debug.panic("actual {}, not within absolute tolerance {} of expected {}", .{ actual, tolerance, expected }), + + .ComptimeFloat => @compileError("Cannot approximately compare two comptime_float values"), - switch (@typeInfo(@TypeOf(actual))) { - .Float, - .ComptimeFloat, - => { - if (@fabs(expected - actual) > margin) { - std.debug.panic("actual {}, not within margin {} of expected {}", .{ actual, margin, expected }); - } - }, else => @compileError("Unable to compare non floating point values"), } } -test "expectWithinMargin" { +test "expectApproxEqAbs" { inline for ([_]type{ f16, f32, f64, f128 }) |T| { const pos_x: T = 12.0; const pos_y: T = 12.06; const neg_x: T = -12.0; const neg_y: T = -12.06; - expectWithinMargin(pos_x, pos_y, 0.1); - expectWithinMargin(neg_x, neg_y, 0.1); + expectApproxEqAbs(pos_x, pos_y, 0.1); + expectApproxEqAbs(neg_x, neg_y, 0.1); } } -/// This function is intended to be used only in tests. When the actual value is not -/// within the epsilon of the expected value, -/// prints diagnostics to stderr to show exactly how they are not equal, then aborts. +/// This function is intended to be used only in tests. When the actual value is +/// not approximately equal to the expected value, prints diagnostics to stderr +/// to show exactly how they are not equal, then aborts. +/// See `math.approxEqRel` for more informations on the tolerance parameter. /// The types must be floating point -pub fn expectWithinEpsilon(expected: anytype, actual: @TypeOf(expected), epsilon: @TypeOf(expected)) void { - std.debug.assert(epsilon >= 0.0 and epsilon <= 1.0); +pub fn expectApproxEqRel(expected: anytype, actual: @TypeOf(expected), tolerance: @TypeOf(expected)) void { + const T = @TypeOf(expected); + + switch (@typeInfo(T)) { + .Float => if (!math.approxEqRel(T, expected, actual, tolerance)) + std.debug.panic("actual {}, not within relative tolerance {} of expected {}", .{ actual, tolerance, expected }), + + .ComptimeFloat => @compileError("Cannot approximately compare two comptime_float values"), - // Relative epsilon test. - const margin = math.max(math.fabs(expected), math.fabs(actual)) * epsilon; - switch (@typeInfo(@TypeOf(actual))) { - .Float, - .ComptimeFloat, - => { - if (@fabs(expected - actual) > margin) { - std.debug.panic("actual {}, not within epsilon {}, of expected {}", .{ actual, epsilon, expected }); - } - }, else => @compileError("Unable to compare non floating point values"), } } -test "expectWithinEpsilon" { +test "expectApproxEqRel" { inline for ([_]type{ f16, f32, f64, f128 }) |T| { + const eps_value = comptime math.epsilon(T); + const sqrt_eps_value = comptime math.sqrt(eps_value); + const pos_x: T = 12.0; - const pos_y: T = 13.2; + const pos_y: T = pos_x + 2 * eps_value; const neg_x: T = -12.0; - const neg_y: T = -13.2; + const neg_y: T = neg_x - 2 * eps_value; - expectWithinEpsilon(pos_x, pos_y, 0.1); - expectWithinEpsilon(neg_x, neg_y, 0.1); + expectApproxEqRel(pos_x, pos_y, sqrt_eps_value); + expectApproxEqRel(neg_x, neg_y, sqrt_eps_value); } } @@ -296,7 +298,7 @@ pub const TmpDir = struct { sub_path: [sub_path_len]u8, const random_bytes_count = 12; - const sub_path_len = std.base64.Base64Encoder.calcSize(random_bytes_count); + const sub_path_len = std.fs.base64_encoder.calcSize(random_bytes_count); pub fn cleanup(self: *TmpDir) void { self.dir.close(); diff --git a/lib/std/zig/parser_test.zig b/lib/std/zig/parser_test.zig index b6bd2844a4..2a343a6edc 100644 --- a/lib/std/zig/parser_test.zig +++ b/lib/std/zig/parser_test.zig @@ -4,6 +4,31 @@ // The MIT license requires this copyright notice to be included in all copies // and substantial portions of the software. +test "zig fmt: respect line breaks in struct field value declaration" { + try testCanonical( + \\const Foo = struct { + \\ bar: u32 = + \\ 42, + \\ bar: u32 = + \\ // a comment + \\ 42, + \\ bar: u32 = + \\ 42, + \\ // a comment + \\ bar: []const u8 = + \\ \\ foo + \\ \\ bar + \\ \\ baz + \\ , + \\ bar: u32 = + \\ blk: { + \\ break :blk 42; + \\ }, + \\}; + \\ + ); +} + // TODO Remove this after zig 0.9.0 is released. test "zig fmt: rewrite inline functions as callconv(.Inline)" { try testTransform( @@ -3038,6 +3063,54 @@ test "zig fmt: switch" { \\} \\ ); + + try testTransform( + \\test { + \\ switch (x) { + \\ foo => + \\ "bar", + \\ } + \\} + \\ + , + \\test { + \\ switch (x) { + \\ foo => "bar", + \\ } + \\} + \\ + ); +} + +test "zig fmt: switch multiline string" { + try testCanonical( + \\test "switch multiline string" { + \\ const x: u32 = 0; + \\ const str = switch (x) { + \\ 1 => "one", + \\ 2 => + \\ \\ Comma after the multiline string + \\ \\ is needed + \\ , + \\ 3 => "three", + \\ else => "else", + \\ }; + \\ + \\ const Union = union(enum) { + \\ Int: i64, + \\ Float: f64, + \\ }; + \\ + \\ const str = switch (u) { + \\ Union.Int => |int| + \\ \\ Comma after the multiline string + \\ \\ is needed + \\ , + \\ Union.Float => |*float| unreachable, + \\ }; + \\} + \\ + ); } test "zig fmt: while" { @@ -3068,6 +3141,11 @@ test "zig fmt: while" { \\ while (i < 10) : ({ \\ i += 1; \\ j += 1; + \\ }) continue; + \\ + \\ while (i < 10) : ({ + \\ i += 1; + \\ j += 1; \\ }) { \\ continue; \\ } @@ -3184,6 +3262,156 @@ test "zig fmt: for" { ); } +test "zig fmt: for if" { + try testCanonical( + \\test { + \\ for (a) |x| if (x) f(x); + \\ + \\ for (a) |x| if (x) + \\ f(x); + \\ + \\ for (a) |x| if (x) { + \\ f(x); + \\ }; + \\ + \\ for (a) |x| + \\ if (x) + \\ f(x); + \\ + \\ for (a) |x| + \\ if (x) { + \\ f(x); + \\ }; + \\} + \\ + ); +} + +test "zig fmt: if for" { + try testCanonical( + \\test { + \\ if (a) for (x) |x| f(x); + \\ + \\ if (a) for (x) |x| + \\ f(x); + \\ + \\ if (a) for (x) |x| { + \\ f(x); + \\ }; + \\ + \\ if (a) + \\ for (x) |x| + \\ f(x); + \\ + \\ if (a) + \\ for (x) |x| { + \\ f(x); + \\ }; + \\} + \\ + ); +} + +test "zig fmt: while if" { + try testCanonical( + \\test { + \\ while (a) if (x) f(x); + \\ + \\ while (a) if (x) + \\ f(x); + \\ + \\ while (a) if (x) { + \\ f(x); + \\ }; + \\ + \\ while (a) + \\ if (x) + \\ f(x); + \\ + \\ while (a) + \\ if (x) { + \\ f(x); + \\ }; + \\} + \\ + ); +} + +test "zig fmt: if while" { + try testCanonical( + \\test { + \\ if (a) while (x) : (cont) f(x); + \\ + \\ if (a) while (x) : (cont) + \\ f(x); + \\ + \\ if (a) while (x) : (cont) { + \\ f(x); + \\ }; + \\ + \\ if (a) + \\ while (x) : (cont) + \\ f(x); + \\ + \\ if (a) + \\ while (x) : (cont) { + \\ f(x); + \\ }; + \\} + \\ + ); +} + +test "zig fmt: while for" { + try testCanonical( + \\test { + \\ while (a) for (x) |x| f(x); + \\ + \\ while (a) for (x) |x| + \\ f(x); + \\ + \\ while (a) for (x) |x| { + \\ f(x); + \\ }; + \\ + \\ while (a) + \\ for (x) |x| + \\ f(x); + \\ + \\ while (a) + \\ for (x) |x| { + \\ f(x); + \\ }; + \\} + \\ + ); +} + +test "zig fmt: for while" { + try testCanonical( + \\test { + \\ for (a) |a| while (x) |x| f(x); + \\ + \\ for (a) |a| while (x) |x| + \\ f(x); + \\ + \\ for (a) |a| while (x) |x| { + \\ f(x); + \\ }; + \\ + \\ for (a) |a| + \\ while (x) |x| + \\ f(x); + \\ + \\ for (a) |a| + \\ while (x) |x| { + \\ f(x); + \\ }; + \\} + \\ + ); +} + test "zig fmt: if" { try testCanonical( \\test "if" { @@ -3233,6 +3461,82 @@ test "zig fmt: if" { ); } +test "zig fmt: fix single statement if/for/while line breaks" { + try testTransform( + \\test { + \\ if (cond) a + \\ else b; + \\ + \\ if (cond) + \\ a + \\ else b; + \\ + \\ for (xs) |x| foo() + \\ else bar(); + \\ + \\ for (xs) |x| + \\ foo() + \\ else bar(); + \\ + \\ while (a) : (b) foo() + \\ else bar(); + \\ + \\ while (a) : (b) + \\ foo() + \\ else bar(); + \\} + \\ + , + \\test { + \\ if (cond) a else b; + \\ + \\ if (cond) + \\ a + \\ else + \\ b; + \\ + \\ for (xs) |x| foo() else bar(); + \\ + \\ for (xs) |x| + \\ foo() + \\ else + \\ bar(); + \\ + \\ while (a) : (b) foo() else bar(); + \\ + \\ while (a) : (b) + \\ foo() + \\ else + \\ bar(); + \\} + \\ + ); +} + +test "zig fmt: anon struct/array literal in if" { + try testCanonical( + \\test { + \\ const a = if (cond) .{ + \\ 1, 2, + \\ 3, 4, + \\ } else .{ + \\ 1, + \\ 2, + \\ 3, + \\ }; + \\ + \\ const rl_and_tag: struct { rl: ResultLoc, tag: zir.Inst.Tag } = if (any_payload_is_ref) .{ + \\ .rl = .ref, + \\ .tag = .switchbr_ref, + \\ } else .{ + \\ .rl = .none, + \\ .tag = .switchbr, + \\ }; + \\} + \\ + ); +} + test "zig fmt: defer" { try testCanonical( \\test "defer" { @@ -3820,6 +4124,7 @@ test "zig fmt: comments in ternary ifs" { \\ // Comment \\ 1 \\else + \\ // Comment \\ 0; \\ \\pub extern "c" fn printf(format: [*:0]const u8, ...) c_int; @@ -3827,6 +4132,20 @@ test "zig fmt: comments in ternary ifs" { ); } +test "zig fmt: while statement in blockless if" { + try testCanonical( + \\pub fn main() void { + \\ const zoom_node = if (focused_node == layout_first) + \\ while (it.next()) |node| { + \\ if (!node.view.pending.float and !node.view.pending.fullscreen) break node; + \\ } else null + \\ else + \\ focused_node; + \\} + \\ + ); +} + test "zig fmt: test comments in field access chain" { try testCanonical( \\pub const str = struct { diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig index 30e83e9a7c..640f25829a 100644 --- a/lib/std/zig/render.zig +++ b/lib/std/zig/render.zig @@ -1018,147 +1018,14 @@ fn renderWhile(gpa: *Allocator, ais: *Ais, tree: ast.Tree, while_node: ast.full. try renderToken(ais, tree, inline_token, .space); // inline } - try renderToken(ais, tree, while_node.ast.while_token, .space); // if + try renderToken(ais, tree, while_node.ast.while_token, .space); // if/for/while try renderToken(ais, tree, while_node.ast.while_token + 1, .none); // lparen try renderExpression(gpa, ais, tree, while_node.ast.cond_expr, .none); // condition - const then_tag = node_tags[while_node.ast.then_expr]; - if (nodeIsBlock(then_tag) and !nodeIsIf(then_tag)) { - if (while_node.payload_token) |payload_token| { - try renderToken(ais, tree, payload_token - 2, .space); // rparen - try renderToken(ais, tree, payload_token - 1, .none); // | - const ident = blk: { - if (token_tags[payload_token] == .asterisk) { - try renderToken(ais, tree, payload_token, .none); // * - break :blk payload_token + 1; - } else { - break :blk payload_token; - } - }; - try renderToken(ais, tree, ident, .none); // identifier - const pipe = blk: { - if (token_tags[ident + 1] == .comma) { - try renderToken(ais, tree, ident + 1, .space); // , - try renderToken(ais, tree, ident + 2, .none); // index - break :blk ident + 3; - } else { - break :blk ident + 1; - } - }; - const brace_space = if (while_node.ast.cont_expr == 0 and ais.isLineOverIndented()) - Space.newline - else - Space.space; - try renderToken(ais, tree, pipe, brace_space); // | - } else { - const rparen = tree.lastToken(while_node.ast.cond_expr) + 1; - const brace_space = if (while_node.ast.cont_expr == 0 and ais.isLineOverIndented()) - Space.newline - else - Space.space; - try renderToken(ais, tree, rparen, brace_space); // rparen - } - if (while_node.ast.cont_expr != 0) { - const rparen = tree.lastToken(while_node.ast.cont_expr) + 1; - const lparen = tree.firstToken(while_node.ast.cont_expr) - 1; - try renderToken(ais, tree, lparen - 1, .space); // : - try renderToken(ais, tree, lparen, .none); // lparen - try renderExpression(gpa, ais, tree, while_node.ast.cont_expr, .none); - const brace_space: Space = if (ais.isLineOverIndented()) .newline else .space; - try renderToken(ais, tree, rparen, brace_space); // rparen - } - if (while_node.ast.else_expr != 0) { - try renderExpression(gpa, ais, tree, while_node.ast.then_expr, Space.space); - try renderToken(ais, tree, while_node.else_token, .space); // else - if (while_node.error_token) |error_token| { - try renderToken(ais, tree, error_token - 1, .none); // | - try renderToken(ais, tree, error_token, .none); // identifier - try renderToken(ais, tree, error_token + 1, .space); // | - } - return renderExpression(gpa, ais, tree, while_node.ast.else_expr, space); - } else { - return renderExpression(gpa, ais, tree, while_node.ast.then_expr, space); - } - } - - const rparen = tree.lastToken(while_node.ast.cond_expr) + 1; - const last_then_token = tree.lastToken(while_node.ast.then_expr); - const src_has_newline = !tree.tokensOnSameLine(rparen, last_then_token); - - if (src_has_newline) { - if (while_node.payload_token) |payload_token| { - try renderToken(ais, tree, payload_token - 2, .space); // rparen - try renderToken(ais, tree, payload_token - 1, .none); // | - const ident = blk: { - if (token_tags[payload_token] == .asterisk) { - try renderToken(ais, tree, payload_token, .none); // * - break :blk payload_token + 1; - } else { - break :blk payload_token; - } - }; - try renderToken(ais, tree, ident, .none); // identifier - const pipe = blk: { - if (token_tags[ident + 1] == .comma) { - try renderToken(ais, tree, ident + 1, .space); // , - try renderToken(ais, tree, ident + 2, .none); // index - break :blk ident + 3; - } else { - break :blk ident + 1; - } - }; - const after_space: Space = if (while_node.ast.cont_expr != 0) .space else .newline; - try renderToken(ais, tree, pipe, after_space); // | - } else { - ais.pushIndent(); - const after_space: Space = if (while_node.ast.cont_expr != 0) .space else .newline; - try renderToken(ais, tree, rparen, after_space); // rparen - ais.popIndent(); - } - if (while_node.ast.cont_expr != 0) { - const cont_rparen = tree.lastToken(while_node.ast.cont_expr) + 1; - const cont_lparen = tree.firstToken(while_node.ast.cont_expr) - 1; - try renderToken(ais, tree, cont_lparen - 1, .space); // : - try renderToken(ais, tree, cont_lparen, .none); // lparen - try renderExpression(gpa, ais, tree, while_node.ast.cont_expr, .none); - try renderToken(ais, tree, cont_rparen, .newline); // rparen - } - if (while_node.ast.else_expr != 0) { - ais.pushIndent(); - try renderExpression(gpa, ais, tree, while_node.ast.then_expr, Space.newline); - ais.popIndent(); - const else_is_block = nodeIsBlock(node_tags[while_node.ast.else_expr]); - if (else_is_block) { - try renderToken(ais, tree, while_node.else_token, .space); // else - if (while_node.error_token) |error_token| { - try renderToken(ais, tree, error_token - 1, .none); // | - try renderToken(ais, tree, error_token, .none); // identifier - try renderToken(ais, tree, error_token + 1, .space); // | - } - return renderExpression(gpa, ais, tree, while_node.ast.else_expr, space); - } else { - if (while_node.error_token) |error_token| { - try renderToken(ais, tree, while_node.else_token, .space); // else - try renderToken(ais, tree, error_token - 1, .none); // | - try renderToken(ais, tree, error_token, .none); // identifier - try renderToken(ais, tree, error_token + 1, .space); // | - } else { - try renderToken(ais, tree, while_node.else_token, .newline); // else - } - try renderExpressionIndented(gpa, ais, tree, while_node.ast.else_expr, space); - return; - } - } else { - try renderExpressionIndented(gpa, ais, tree, while_node.ast.then_expr, space); - return; - } - } - - // Render everything on a single line. + var last_prefix_token = tree.lastToken(while_node.ast.cond_expr) + 1; // rparen if (while_node.payload_token) |payload_token| { - assert(payload_token - 2 == rparen); - try renderToken(ais, tree, payload_token - 2, .space); // ) + try renderToken(ais, tree, last_prefix_token, .space); try renderToken(ais, tree, payload_token - 1, .none); // | const ident = blk: { if (token_tags[payload_token] == .asterisk) { @@ -1178,33 +1045,67 @@ fn renderWhile(gpa: *Allocator, ais: *Ais, tree: ast.Tree, while_node: ast.full. break :blk ident + 1; } }; - try renderToken(ais, tree, pipe, .space); // | - } else { - try renderToken(ais, tree, rparen, .space); // ) + last_prefix_token = pipe; } if (while_node.ast.cont_expr != 0) { - const cont_rparen = tree.lastToken(while_node.ast.cont_expr) + 1; - const cont_lparen = tree.firstToken(while_node.ast.cont_expr) - 1; - try renderToken(ais, tree, cont_lparen - 1, .space); // : - try renderToken(ais, tree, cont_lparen, .none); // lparen + try renderToken(ais, tree, last_prefix_token, .space); + const lparen = tree.firstToken(while_node.ast.cont_expr) - 1; + try renderToken(ais, tree, lparen - 1, .space); // : + try renderToken(ais, tree, lparen, .none); // lparen try renderExpression(gpa, ais, tree, while_node.ast.cont_expr, .none); - try renderToken(ais, tree, cont_rparen, .space); // rparen + last_prefix_token = tree.lastToken(while_node.ast.cont_expr) + 1; // rparen + } + + const then_expr_is_block = nodeIsBlock(node_tags[while_node.ast.then_expr]); + const indent_then_expr = !then_expr_is_block and + !tree.tokensOnSameLine(last_prefix_token, tree.firstToken(while_node.ast.then_expr)); + if (indent_then_expr or (then_expr_is_block and ais.isLineOverIndented())) { + ais.pushIndentNextLine(); + try renderToken(ais, tree, last_prefix_token, .newline); + ais.popIndent(); + } else { + try renderToken(ais, tree, last_prefix_token, .space); } if (while_node.ast.else_expr != 0) { - try renderExpression(gpa, ais, tree, while_node.ast.then_expr, .space); - try renderToken(ais, tree, while_node.else_token, .space); // else + const first_else_expr_tok = tree.firstToken(while_node.ast.else_expr); + + if (indent_then_expr) { + ais.pushIndent(); + try renderExpression(gpa, ais, tree, while_node.ast.then_expr, .newline); + ais.popIndent(); + } else { + try renderExpression(gpa, ais, tree, while_node.ast.then_expr, .space); + } + + var last_else_token = while_node.else_token; if (while_node.error_token) |error_token| { + try renderToken(ais, tree, while_node.else_token, .space); // else try renderToken(ais, tree, error_token - 1, .none); // | try renderToken(ais, tree, error_token, .none); // identifier - try renderToken(ais, tree, error_token + 1, .space); // | + last_else_token = error_token + 1; // | } - return renderExpression(gpa, ais, tree, while_node.ast.else_expr, space); + const indent_else_expr = indent_then_expr and + !nodeIsBlock(node_tags[while_node.ast.else_expr]) and + !nodeIsIfForWhileSwitch(node_tags[while_node.ast.else_expr]); + if (indent_else_expr) { + ais.pushIndentNextLine(); + try renderToken(ais, tree, last_else_token, .newline); + ais.popIndent(); + try renderExpressionIndented(gpa, ais, tree, while_node.ast.else_expr, space); + } else { + try renderToken(ais, tree, last_else_token, .space); + try renderExpression(gpa, ais, tree, while_node.ast.else_expr, space); + } } else { - return renderExpression(gpa, ais, tree, while_node.ast.then_expr, space); + if (indent_then_expr) { + try renderExpressionIndented(gpa, ais, tree, while_node.ast.then_expr, space); + } else { + try renderExpression(gpa, ais, tree, while_node.ast.then_expr, space); + } } } @@ -1258,8 +1159,29 @@ fn renderContainerField( try renderToken(ais, tree, rparen_token, .space); // ) } const eq_token = tree.firstToken(field.ast.value_expr) - 1; - try renderToken(ais, tree, eq_token, .space); // = - return renderExpressionComma(gpa, ais, tree, field.ast.value_expr, space); // value + const eq_space: Space = if (tree.tokensOnSameLine(eq_token, eq_token + 1)) .space else .newline; + { + ais.pushIndent(); + try renderToken(ais, tree, eq_token, eq_space); // = + ais.popIndent(); + } + + if (eq_space == .space) + return renderExpressionComma(gpa, ais, tree, field.ast.value_expr, space); // value + + const token_tags = tree.tokens.items(.tag); + const maybe_comma = tree.lastToken(field.ast.value_expr) + 1; + + if (token_tags[maybe_comma] == .comma) { + ais.pushIndent(); + try renderExpression(gpa, ais, tree, field.ast.value_expr, .none); // value + ais.popIndent(); + try renderToken(ais, tree, maybe_comma, space); + } else { + ais.pushIndent(); + try renderExpression(gpa, ais, tree, field.ast.value_expr, space); // value + ais.popIndent(); + } } fn renderBuiltinCall( @@ -1522,6 +1444,7 @@ fn renderSwitchCase( switch_case: ast.full.SwitchCase, space: Space, ) Error!void { + const node_tags = tree.nodes.items(.tag); const token_tags = tree.tokens.items(.tag); const trailing_comma = token_tags[switch_case.ast.arrow_token - 1] == .comma; @@ -1544,17 +1467,23 @@ fn renderSwitchCase( } // Render the arrow and everything after it - try renderToken(ais, tree, switch_case.ast.arrow_token, .space); + const pre_target_space = if (node_tags[switch_case.ast.target_expr] == .multiline_string_literal) + // Newline gets inserted when rendering the target expr. + Space.none + else + Space.space; + const after_arrow_space: Space = if (switch_case.payload_token == null) pre_target_space else .space; + try renderToken(ais, tree, switch_case.ast.arrow_token, after_arrow_space); if (switch_case.payload_token) |payload_token| { try renderToken(ais, tree, payload_token - 1, .none); // pipe if (token_tags[payload_token] == .asterisk) { try renderToken(ais, tree, payload_token, .none); // asterisk try renderToken(ais, tree, payload_token + 1, .none); // identifier - try renderToken(ais, tree, payload_token + 2, .space); // pipe + try renderToken(ais, tree, payload_token + 2, pre_target_space); // pipe } else { try renderToken(ais, tree, payload_token, .none); // identifier - try renderToken(ais, tree, payload_token + 1, .space); // pipe + try renderToken(ais, tree, payload_token + 1, pre_target_space); // pipe } } @@ -2493,6 +2422,21 @@ fn nodeIsBlock(tag: ast.Node.Tag) bool { .block_semicolon, .block_two, .block_two_semicolon, + .struct_init_dot, + .struct_init_dot_comma, + .struct_init_dot_two, + .struct_init_dot_two_comma, + .array_init_dot, + .array_init_dot_comma, + .array_init_dot_two, + .array_init_dot_two_comma, + => true, + else => false, + }; +} + +fn nodeIsIfForWhileSwitch(tag: ast.Node.Tag) bool { + return switch (tag) { .@"if", .if_simple, .@"for", @@ -2507,13 +2451,6 @@ fn nodeIsBlock(tag: ast.Node.Tag) bool { }; } -fn nodeIsIf(tag: ast.Node.Tag) bool { - return switch (tag) { - .@"if", .if_simple => true, - else => false, - }; -} - fn nodeCausesSliceOpSpace(tag: ast.Node.Tag) bool { return switch (tag) { .@"catch", |
