aboutsummaryrefslogtreecommitdiff
path: root/lib/std
diff options
context:
space:
mode:
Diffstat (limited to 'lib/std')
-rw-r--r--lib/std/crypto/Tls.zig153
-rw-r--r--lib/std/http/Client.zig23
2 files changed, 122 insertions, 54 deletions
diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig
index 6b5374512b..19bd3442cf 100644
--- a/lib/std/crypto/Tls.zig
+++ b/lib/std/crypto/Tls.zig
@@ -9,12 +9,14 @@ application_cipher: ApplicationCipher,
read_seq: u64,
write_seq: u64,
/// The size is enough to contain exactly one TLSCiphertext record.
-partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8,
+partially_read_buffer: [max_ciphertext_record_len]u8,
/// The number of partially read bytes inside `partiall_read_buffer`.
partially_read_len: u15,
+eof: bool,
pub const ciphertext_record_header_len = 5;
pub const max_ciphertext_len = (1 << 14) + 256;
+pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
pub const ProtocolVersion = enum(u16) {
tls_1_2 = 0x0303,
@@ -416,7 +418,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
var cipher_params: CipherParams = undefined;
- var handshake_buf: [4000]u8 = undefined;
+ var handshake_buf: [8000]u8 = undefined;
var len: usize = 0;
var i: usize = i: {
const plaintext = handshake_buf[0..5];
@@ -554,8 +556,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
// std.fmt.fmtSliceHexLower(&hello_hash),
// std.fmt.fmtSliceHexLower(&early_secret),
// std.fmt.fmtSliceHexLower(&empty_hash),
- // std.fmt.fmtSliceHexLower(&derived_secret),
- // std.fmt.fmtSliceHexLower(&handshake_secret),
+ // std.fmt.fmtSliceHexLower(&hs_derived_secret),
+ // std.fmt.fmtSliceHexLower(&p.handshake_secret),
// std.fmt.fmtSliceHexLower(&client_secret),
// std.fmt.fmtSliceHexLower(&server_secret),
//});
@@ -582,7 +584,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
const end_hdr = i + 5;
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
if (end_hdr > len) {
+ std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len });
len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
+ std.debug.print("new len: {d} bytes\n", .{len});
if (end_hdr > len) return error.EndOfStream;
}
const ct = @intToEnum(ContentType, handshake_buf[i]);
@@ -593,9 +597,12 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
i += 2;
const end = i + record_size;
+ std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end });
if (end > handshake_buf.len) return error.TlsRecordOverflow;
if (end > len) {
+ std.debug.print("read len={d} atleast={d}\n", .{ len, end - len });
len += try stream.readAtLeast(handshake_buf[len..], end - len);
+ std.debug.print("new len: {d} bytes\n", .{len});
if (end > len) return error.EndOfStream;
}
switch (ct) {
@@ -604,7 +611,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
},
.application_data => {
- var cleartext_buf: [1000]u8 = undefined;
+ var cleartext_buf: [8000]u8 = undefined;
const cleartext = switch (cipher_params) {
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
const P = @TypeOf(p.*);
@@ -637,17 +644,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
};
const inner_ct = cleartext[cleartext.len - 1];
+ std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)});
switch (inner_ct) {
@enumToInt(ContentType.handshake) => {
const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
- if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
+ if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength;
+ std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len });
switch (cleartext[0]) {
@enumToInt(HandshakeType.encrypted_extensions) => {
const ext_size = mem.readIntBig(u16, cleartext[4..6]);
- if (ext_size != 0) {
- @panic("TODO handle encrypted extensions");
- }
- std.debug.print("empty encrypted extensions\n", .{});
+ std.debug.print("{d} bytes of encrypted extensions\n", .{
+ ext_size,
+ });
},
@enumToInt(HandshakeType.certificate) => {
std.debug.print("cool certificate bro\n", .{});
@@ -688,22 +696,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
const nonce = p.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
- {
- var iovecs = [_]std.os.iovec_const{
- .{
- .iov_base = &client_change_cipher_spec_msg,
- .iov_len = client_change_cipher_spec_msg.len,
- },
- .{
- .iov_base = &finished_msg,
- .iov_len = finished_msg.len,
- },
- };
- try stream.writevAll(&iovecs);
- }
+ //const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
+ _ = client_change_cipher_spec_msg;
+ const both_msgs = finished_msg;
+ try stream.writeAll(&both_msgs);
const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
+ //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{
+ // std.fmt.fmtSliceHexLower(&p.master_secret),
+ // std.fmt.fmtSliceHexLower(&client_secret),
+ // std.fmt.fmtSliceHexLower(&server_secret),
+ //});
break :c @unionInit(ApplicationCipher, @tagName(tag), .{
.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
@@ -721,12 +725,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
@panic("TODO");
},
};
+ std.debug.print("remaining bytes: {d}\n", .{len - end});
return .{
.application_cipher = app_cipher,
- .read_seq = read_seq,
- .write_seq = 1,
+ .read_seq = 0,
+ .write_seq = 0,
.partially_read_buffer = undefined,
.partially_read_len = 0,
+ .eof = false,
};
},
else => {
@@ -753,49 +759,67 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
}
pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
- var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined;
+ var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined;
+ // Due to the trailing inner content type byte in the ciphertext, we need
+ // an additional buffer for storing the cleartext into before encrypting.
+ var cleartext_buf: [max_ciphertext_len]u8 = undefined;
var iovecs_buf: [5]std.os.iovec_const = undefined;
var ciphertext_end: usize = 0;
var iovec_end: usize = 0;
var bytes_i: usize = 0;
- switch (tls.application_cipher) {
- inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| {
+ // How many bytes are taken up by overhead per record.
+ const overhead_len: usize = switch (tls.application_cipher) {
+ inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
+ const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1;
while (true) {
- const ciphertext_len = @intCast(u16, @min(
- @min(bytes.len - bytes_i, max_ciphertext_len),
- ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end,
+ const encrypted_content_len = @intCast(u16, @min(
+ @min(bytes.len - bytes_i, max_ciphertext_len - 1),
+ ciphertext_buf.len -
+ ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
));
- if (ciphertext_len == 0) return bytes_i;
+ if (encrypted_content_len == 0) break :l overhead_len;
- const wrapped_len = ciphertext_len + P.AEAD.tag_length;
- const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len];
+ mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
+ cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
+ bytes_i += encrypted_content_len;
+ const ciphertext_len = encrypted_content_len + 1;
+ const cleartext = cleartext_buf[0..ciphertext_len];
- const ad = record[0..5];
- ciphertext_end += 5;
+ const record_start = ciphertext_end;
+ const ad = ciphertext_buf[ciphertext_end..][0..5];
+ ad.* =
+ [_]u8{@enumToInt(ContentType.application_data)} ++
+ int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
+ int2(ciphertext_len + P.AEAD.tag_length);
+ ciphertext_end += ad.len;
const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
ciphertext_end += ciphertext_len;
const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
- ciphertext_end += P.AEAD.tag_length;
+ ciphertext_end += auth_tag.len;
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq));
tls.write_seq += 1;
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand;
- ad.* =
- [_]u8{@enumToInt(ContentType.application_data)} ++
- int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
- int2(wrapped_len);
- const cleartext = bytes[bytes_i..ciphertext.len];
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
-
+ //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{
+ // tls.write_seq - 1,
+ // std.fmt.fmtSliceHexLower(&nonce),
+ // std.fmt.fmtSliceHexLower(&p.client_key),
+ // std.fmt.fmtSliceHexLower(&p.client_iv),
+ // std.fmt.fmtSliceHexLower(ad),
+ // std.fmt.fmtSliceHexLower(auth_tag),
+ // std.fmt.fmtSliceHexLower(&p.server_key),
+ // std.fmt.fmtSliceHexLower(&p.server_iv),
+ //});
+
+ const record = ciphertext_buf[record_start..ciphertext_end];
iovecs_buf[iovec_end] = .{
.iov_base = record.ptr,
.iov_len = record.len,
};
iovec_end += 1;
-
- bytes_i += ciphertext_len;
}
},
.TLS_CHACHA20_POLY1305_SHA256 => {
@@ -807,7 +831,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
.TLS_AES_128_CCM_8_SHA256 => {
@panic("TODO");
},
- }
+ };
// Ideally we would call writev exactly once here, however, we must ensure
// that we don't return with a record partially written.
@@ -815,9 +839,10 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
var total_amt: usize = 0;
while (true) {
var amt = try stream.writev(iovecs_buf[i..iovec_end]);
- total_amt += amt;
while (amt >= iovecs_buf[i].iov_len) {
- amt -= iovecs_buf[i].iov_len;
+ const encrypted_amt = iovecs_buf[i].iov_len;
+ total_amt += encrypted_amt - overhead_len;
+ amt -= encrypted_amt;
i += 1;
// Rely on the property that iovecs delineate records, meaning that
// if amt equals zero here, we have fortunately found ourselves
@@ -849,11 +874,17 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len);
const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]);
const frag = in_buf[0 .. prev_len + actual_read_len];
+ if (frag.len == 0) {
+ tls.eof = true;
+ return 0;
+ }
+ std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len });
var in: usize = 0;
var out: usize = 0;
while (true) {
if (in + ciphertext_record_header_len > frag.len) {
+ std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len });
return finishRead(tls, frag, in, out);
}
const ct = @intToEnum(ContentType, frag[in]);
@@ -866,6 +897,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
const end = in + record_size;
if (end > frag.len) {
if (record_size > max_ciphertext_len) return error.TlsRecordOverflow;
+ std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len });
return finishRead(tls, frag, in, out);
}
switch (ct) {
@@ -877,6 +909,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
+ const ad = frag[in - 5 ..][0..5];
const ciphertext_len = record_size - P.AEAD.tag_length;
const ciphertext = frag[in..][0..ciphertext_len];
in += ciphertext_len;
@@ -886,7 +919,12 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq));
tls.read_seq += 1;
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
- const ad = frag[0..ciphertext_record_header_len];
+ //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
+ // tls.read_seq - 1,
+ // std.fmt.fmtSliceHexLower(&nonce),
+ // std.fmt.fmtSliceHexLower(&p.server_key),
+ // std.fmt.fmtSliceHexLower(&p.server_iv),
+ //});
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
return error.TlsBadRecordMac;
break :c cleartext.len;
@@ -902,15 +940,26 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
},
};
- const inner_ct = buffer[out + cleartext_len - 1];
+ const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
switch (inner_ct) {
- @enumToInt(ContentType.handshake) => {
+ .alert => {
+ const level = @intToEnum(AlertLevel, buffer[out]);
+ const desc = @intToEnum(AlertDescription, buffer[out + 1]);
+ if (desc == .close_notify) {
+ tls.eof = true;
+ return out;
+ }
+ std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
+ return error.TlsAlert;
+ },
+ .handshake => {
std.debug.print("the server wants to keep shaking hands\n", .{});
},
- @enumToInt(ContentType.application_data) => {
+ .application_data => {
out += cleartext_len - 1;
},
else => {
+ std.debug.print("inner content type: {d}\n", .{inner_ct});
return error.TlsUnexpectedMessage;
},
}
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index e7b056830a..2c92163435 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -62,6 +62,25 @@ pub const Request = struct {
.https => return req.tls.read(req.stream, buffer),
}
}
+
+ pub fn readAll(req: *Request, buffer: []u8) !usize {
+ return readAtLeast(req, buffer, buffer.len);
+ }
+
+ pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
+ var index: usize = 0;
+ while (index < len) {
+ const amt = try req.read(buffer[index..]);
+ if (amt == 0) {
+ switch (req.protocol) {
+ .http => break,
+ .https => if (req.tls.eof) break,
+ }
+ }
+ index += amt;
+ }
+ return index;
+ }
};
pub fn deinit(client: *Client) void {
@@ -92,7 +111,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
@tagName(options.method).len +
1 +
options.path.len +
- " HTTP/2\r\nHost: ".len +
+ " HTTP/1.1\r\nHost: ".len +
options.host.len +
"\r\nUpgrade-Insecure-Requests: 1\r\n".len +
client.headers.items.len +
@@ -101,7 +120,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
req.headers.appendSliceAssumeCapacity(@tagName(options.method));
req.headers.appendSliceAssumeCapacity(" ");
req.headers.appendSliceAssumeCapacity(options.path);
- req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: ");
+ req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
req.headers.appendSliceAssumeCapacity(options.host);
switch (options.protocol) {
.https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),