diff options
Diffstat (limited to 'lib/std')
| -rw-r--r-- | lib/std/crypto/Tls.zig | 153 | ||||
| -rw-r--r-- | lib/std/http/Client.zig | 23 |
2 files changed, 122 insertions, 54 deletions
diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index 6b5374512b..19bd3442cf 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -9,12 +9,14 @@ application_cipher: ApplicationCipher, read_seq: u64, write_seq: u64, /// The size is enough to contain exactly one TLSCiphertext record. -partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8, +partially_read_buffer: [max_ciphertext_record_len]u8, /// The number of partially read bytes inside `partiall_read_buffer`. partially_read_len: u15, +eof: bool, pub const ciphertext_record_header_len = 5; pub const max_ciphertext_len = (1 << 14) + 256; +pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len; pub const ProtocolVersion = enum(u16) { tls_1_2 = 0x0303, @@ -416,7 +418,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { var cipher_params: CipherParams = undefined; - var handshake_buf: [4000]u8 = undefined; + var handshake_buf: [8000]u8 = undefined; var len: usize = 0; var i: usize = i: { const plaintext = handshake_buf[0..5]; @@ -554,8 +556,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { // std.fmt.fmtSliceHexLower(&hello_hash), // std.fmt.fmtSliceHexLower(&early_secret), // std.fmt.fmtSliceHexLower(&empty_hash), - // std.fmt.fmtSliceHexLower(&derived_secret), - // std.fmt.fmtSliceHexLower(&handshake_secret), + // std.fmt.fmtSliceHexLower(&hs_derived_secret), + // std.fmt.fmtSliceHexLower(&p.handshake_secret), // std.fmt.fmtSliceHexLower(&client_secret), // std.fmt.fmtSliceHexLower(&server_secret), //}); @@ -582,7 +584,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const end_hdr = i + 5; if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow; if (end_hdr > len) { + std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len }); len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len); + std.debug.print("new len: {d} bytes\n", .{len}); if (end_hdr > len) return error.EndOfStream; } const ct = @intToEnum(ContentType, handshake_buf[i]); @@ -593,9 +597,12 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]); i += 2; const end = i + record_size; + std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end }); if (end > handshake_buf.len) return error.TlsRecordOverflow; if (end > len) { + std.debug.print("read len={d} atleast={d}\n", .{ len, end - len }); len += try stream.readAtLeast(handshake_buf[len..], end - len); + std.debug.print("new len: {d} bytes\n", .{len}); if (end > len) return error.EndOfStream; } switch (ct) { @@ -604,7 +611,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage; }, .application_data => { - var cleartext_buf: [1000]u8 = undefined; + var cleartext_buf: [8000]u8 = undefined; const cleartext = switch (cipher_params) { inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { const P = @TypeOf(p.*); @@ -637,17 +644,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { }; const inner_ct = cleartext[cleartext.len - 1]; + std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)}); switch (inner_ct) { @enumToInt(ContentType.handshake) => { const handshake_len = mem.readIntBig(u24, cleartext[1..4]); - if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength; + if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength; + std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len }); switch (cleartext[0]) { @enumToInt(HandshakeType.encrypted_extensions) => { const ext_size = mem.readIntBig(u16, cleartext[4..6]); - if (ext_size != 0) { - @panic("TODO handle encrypted extensions"); - } - std.debug.print("empty encrypted extensions\n", .{}); + std.debug.print("{d} bytes of encrypted extensions\n", .{ + ext_size, + }); }, @enumToInt(HandshakeType.certificate) => { std.debug.print("cool certificate bro\n", .{}); @@ -688,22 +696,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { const nonce = p.client_handshake_iv; P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - { - var iovecs = [_]std.os.iovec_const{ - .{ - .iov_base = &client_change_cipher_spec_msg, - .iov_len = client_change_cipher_spec_msg.len, - }, - .{ - .iov_base = &finished_msg, - .iov_len = finished_msg.len, - }, - }; - try stream.writevAll(&iovecs); - } + //const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + _ = client_change_cipher_spec_msg; + const both_msgs = finished_msg; + try stream.writeAll(&both_msgs); const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{ + // std.fmt.fmtSliceHexLower(&p.master_secret), + // std.fmt.fmtSliceHexLower(&client_secret), + // std.fmt.fmtSliceHexLower(&server_secret), + //}); break :c @unionInit(ApplicationCipher, @tagName(tag), .{ .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), @@ -721,12 +725,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { @panic("TODO"); }, }; + std.debug.print("remaining bytes: {d}\n", .{len - end}); return .{ .application_cipher = app_cipher, - .read_seq = read_seq, - .write_seq = 1, + .read_seq = 0, + .write_seq = 0, .partially_read_buffer = undefined, .partially_read_len = 0, + .eof = false, }; }, else => { @@ -753,49 +759,67 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls { } pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { - var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined; + var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined; + // Due to the trailing inner content type byte in the ciphertext, we need + // an additional buffer for storing the cleartext into before encrypting. + var cleartext_buf: [max_ciphertext_len]u8 = undefined; var iovecs_buf: [5]std.os.iovec_const = undefined; var ciphertext_end: usize = 0; var iovec_end: usize = 0; var bytes_i: usize = 0; - switch (tls.application_cipher) { - inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| { + // How many bytes are taken up by overhead per record. + const overhead_len: usize = switch (tls.application_cipher) { + inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); + const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1; while (true) { - const ciphertext_len = @intCast(u16, @min( - @min(bytes.len - bytes_i, max_ciphertext_len), - ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end, + const encrypted_content_len = @intCast(u16, @min( + @min(bytes.len - bytes_i, max_ciphertext_len - 1), + ciphertext_buf.len - + ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1, )); - if (ciphertext_len == 0) return bytes_i; + if (encrypted_content_len == 0) break :l overhead_len; - const wrapped_len = ciphertext_len + P.AEAD.tag_length; - const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len]; + mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]); + cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data); + bytes_i += encrypted_content_len; + const ciphertext_len = encrypted_content_len + 1; + const cleartext = cleartext_buf[0..ciphertext_len]; - const ad = record[0..5]; - ciphertext_end += 5; + const record_start = ciphertext_end; + const ad = ciphertext_buf[ciphertext_end..][0..5]; + ad.* = + [_]u8{@enumToInt(ContentType.application_data)} ++ + int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ + int2(ciphertext_len + P.AEAD.tag_length); + ciphertext_end += ad.len; const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; ciphertext_end += ciphertext_len; const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += P.AEAD.tag_length; + ciphertext_end += auth_tag.len; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq)); tls.write_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand; - ad.* = - [_]u8{@enumToInt(ContentType.application_data)} ++ - int2(@enumToInt(ProtocolVersion.tls_1_2)) ++ - int2(wrapped_len); - const cleartext = bytes[bytes_i..ciphertext.len]; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - + //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{ + // tls.write_seq - 1, + // std.fmt.fmtSliceHexLower(&nonce), + // std.fmt.fmtSliceHexLower(&p.client_key), + // std.fmt.fmtSliceHexLower(&p.client_iv), + // std.fmt.fmtSliceHexLower(ad), + // std.fmt.fmtSliceHexLower(auth_tag), + // std.fmt.fmtSliceHexLower(&p.server_key), + // std.fmt.fmtSliceHexLower(&p.server_iv), + //}); + + const record = ciphertext_buf[record_start..ciphertext_end]; iovecs_buf[iovec_end] = .{ .iov_base = record.ptr, .iov_len = record.len, }; iovec_end += 1; - - bytes_i += ciphertext_len; } }, .TLS_CHACHA20_POLY1305_SHA256 => { @@ -807,7 +831,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { .TLS_AES_128_CCM_8_SHA256 => { @panic("TODO"); }, - } + }; // Ideally we would call writev exactly once here, however, we must ensure // that we don't return with a record partially written. @@ -815,9 +839,10 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize { var total_amt: usize = 0; while (true) { var amt = try stream.writev(iovecs_buf[i..iovec_end]); - total_amt += amt; while (amt >= iovecs_buf[i].iov_len) { - amt -= iovecs_buf[i].iov_len; + const encrypted_amt = iovecs_buf[i].iov_len; + total_amt += encrypted_amt - overhead_len; + amt -= encrypted_amt; i += 1; // Rely on the property that iovecs delineate records, meaning that // if amt equals zero here, we have fortunately found ourselves @@ -849,11 +874,17 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len); const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]); const frag = in_buf[0 .. prev_len + actual_read_len]; + if (frag.len == 0) { + tls.eof = true; + return 0; + } + std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len }); var in: usize = 0; var out: usize = 0; while (true) { if (in + ciphertext_record_header_len > frag.len) { + std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len }); return finishRead(tls, frag, in, out); } const ct = @intToEnum(ContentType, frag[in]); @@ -866,6 +897,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const end = in + record_size; if (end > frag.len) { if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; + std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len }); return finishRead(tls, frag, in, out); } switch (ct) { @@ -877,6 +909,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); + const ad = frag[in - 5 ..][0..5]; const ciphertext_len = record_size - P.AEAD.tag_length; const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; @@ -886,7 +919,12 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq)); tls.read_seq += 1; const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - const ad = frag[0..ciphertext_record_header_len]; + //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{ + // tls.read_seq - 1, + // std.fmt.fmtSliceHexLower(&nonce), + // std.fmt.fmtSliceHexLower(&p.server_key), + // std.fmt.fmtSliceHexLower(&p.server_iv), + //}); P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch return error.TlsBadRecordMac; break :c cleartext.len; @@ -902,15 +940,26 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize { }, }; - const inner_ct = buffer[out + cleartext_len - 1]; + const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]); switch (inner_ct) { - @enumToInt(ContentType.handshake) => { + .alert => { + const level = @intToEnum(AlertLevel, buffer[out]); + const desc = @intToEnum(AlertDescription, buffer[out + 1]); + if (desc == .close_notify) { + tls.eof = true; + return out; + } + std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); + return error.TlsAlert; + }, + .handshake => { std.debug.print("the server wants to keep shaking hands\n", .{}); }, - @enumToInt(ContentType.application_data) => { + .application_data => { out += cleartext_len - 1; }, else => { + std.debug.print("inner content type: {d}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index e7b056830a..2c92163435 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -62,6 +62,25 @@ pub const Request = struct { .https => return req.tls.read(req.stream, buffer), } } + + pub fn readAll(req: *Request, buffer: []u8) !usize { + return readAtLeast(req, buffer, buffer.len); + } + + pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { + var index: usize = 0; + while (index < len) { + const amt = try req.read(buffer[index..]); + if (amt == 0) { + switch (req.protocol) { + .http => break, + .https => if (req.tls.eof) break, + } + } + index += amt; + } + return index; + } }; pub fn deinit(client: *Client) void { @@ -92,7 +111,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { @tagName(options.method).len + 1 + options.path.len + - " HTTP/2\r\nHost: ".len + + " HTTP/1.1\r\nHost: ".len + options.host.len + "\r\nUpgrade-Insecure-Requests: 1\r\n".len + client.headers.items.len + @@ -101,7 +120,7 @@ pub fn request(client: *Client, options: Request.Options) !Request { req.headers.appendSliceAssumeCapacity(@tagName(options.method)); req.headers.appendSliceAssumeCapacity(" "); req.headers.appendSliceAssumeCapacity(options.path); - req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: "); + req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: "); req.headers.appendSliceAssumeCapacity(options.host); switch (options.protocol) { .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"), |
