diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2022-12-13 20:15:41 -0700 |
|---|---|---|
| committer | Andrew Kelley <andrew@ziglang.org> | 2023-01-02 16:57:15 -0700 |
| commit | d2f5d0b1990a160aa1d648531ea5b1df7b2acdce (patch) | |
| tree | ca92f3233708feb032d3fc3f667efcca7d4a296b /lib/std/crypto/Tls.zig | |
| parent | ba44513c2fe363b55b2c534be98179286b832b7e (diff) | |
| download | zig-d2f5d0b1990a160aa1d648531ea5b1df7b2acdce.tar.gz zig-d2f5d0b1990a160aa1d648531ea5b1df7b2acdce.zip | |
std.crypto.Tls: parse the ServerHello handshake
Diffstat (limited to 'lib/std/crypto/Tls.zig')
| -rw-r--r-- | lib/std/crypto/Tls.zig | 127 |
1 files changed, 114 insertions, 13 deletions
diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig index ab54b42b70..0dc6946003 100644 --- a/lib/std/crypto/Tls.zig +++ b/lib/std/crypto/Tls.zig @@ -8,6 +8,13 @@ const assert = std.debug.assert; state: State = .start, x25519_priv_key: [32]u8 = undefined, x25519_pub_key: [32]u8 = undefined, +x25519_server_pub_key: [32]u8 = undefined, + +const ProtocolVersion = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; const State = enum { /// In this state, all fields are undefined except state. @@ -186,6 +193,18 @@ const NamedGroup = enum(u16) { // * length: u24 // * data: opaque +// ServerHello: +// * ProtocolVersion legacy_version = 0x0303; +// * Random random; +// * opaque legacy_session_id_echo<0..32>; +// * CipherSuite cipher_suite; +// * uint8 legacy_compression_method = 0; +// * Extension extensions<6..2^16-1>; + +// Extension: +// * ExtensionType extension_type; +// * opaque extension_data<0..2^16-1>; + const CipherSuite = enum(u16) { TLS_AES_128_GCM_SHA256 = 0x1301, TLS_AES_256_GCM_SHA384 = 0x1302, @@ -259,10 +278,10 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { // Extension: key_share 0, 51, // ExtensionType.key_share - 0x00, 38, // byte length of this extension payload - 0x00, 36, // byte length of client_shares + 0, 38, // byte length of this extension payload + 0, 36, // byte length of client_shares 0x00, 0x1D, // NamedGroup.x25519 - 0x00, 32, // byte length of key_exchange + 0, 32, // byte length of key_exchange } ++ tls.x25519_pub_key ++ [_]u8{ // Extension: server_name @@ -313,21 +332,103 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void { try stream.writevAll(&iovecs); { - var buf: [1000]u8 = undefined; - const amt = try stream.read(&buf); - const resp = buf[0..amt]; - const ct = @intToEnum(ContentType, resp[0]); + var handshake_buf: [4000]u8 = undefined; + const plaintext = handshake_buf[0..5]; + const amt = try stream.readAtLeast(&handshake_buf, plaintext.len); + if (amt < plaintext.len) return error.EndOfStream; + const ct = @intToEnum(ContentType, plaintext[0]); + const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]); + const end = plaintext.len + frag_len; + if (end > handshake_buf.len) return error.TlsServerHelloTooBig; + if (amt < end) { + const amt2 = try stream.readAll(handshake_buf[amt..end]); + if (amt2 < plaintext.len) return error.EndOfStream; + } + const frag = handshake_buf[plaintext.len..end]; + if (ct == .alert) { - //const prot_ver = @bitCast(u16, resp[1..][0..2].*); - const len = std.mem.readIntBig(u16, resp[3..][0..2]); - const alert = resp[5..][0..len]; - const level = @intToEnum(AlertLevel, alert[0]); - const desc = @intToEnum(AlertDescription, alert[1]); + const level = @intToEnum(AlertLevel, frag[0]); + const desc = @intToEnum(AlertDescription, frag[1]); std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); std.process.exit(1); + } else if (ct == .handshake) { + if (frag[0] != @enumToInt(HandshakeType.server_hello)) { + return error.TlsUnexpectedMessage; + } + const length = mem.readIntBig(u24, frag[1..4]); + if (4 + length != frag.len) return error.TlsBadLength; + const hello = frag[4..]; + const legacy_version = mem.readIntBig(u16, hello[0..2]); + const random = hello[2..34].*; + _ = random; + const legacy_session_id_echo_len = hello[34]; + if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter; + const cipher_suite_int = mem.readIntBig(u16, hello[35..37]); + const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch + return error.TlsIllegalParameter; + std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)}); + const legacy_compression_method = hello[37]; + _ = legacy_compression_method; + const extensions_size = mem.readIntBig(u16, hello[38..40]); + if (40 + extensions_size != hello.len) return error.TlsBadLength; + var i: usize = 40; + var supported_version: u16 = 0; + var have_server_pub_key = false; + while (i < hello.len) { + const et = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const ext_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + const next_i = i + ext_size; + if (next_i > hello.len) return error.TlsBadLength; + switch (et) { + @enumToInt(ExtensionType.supported_versions) => { + if (supported_version != 0) return error.TlsIllegalParameter; + supported_version = mem.readIntBig(u16, hello[i..][0..2]); + }, + @enumToInt(ExtensionType.key_share) => { + if (have_server_pub_key) return error.TlsIllegalParameter; + const named_group = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + switch (named_group) { + @enumToInt(NamedGroup.x25519) => { + const key_size = mem.readIntBig(u16, hello[i..][0..2]); + i += 2; + if (key_size != 32) return error.TlsBadLength; + const encrypted_key = hello[i..][0..32].*; + const server_pub_key = try crypto.dh.X25519.scalarmult( + tls.x25519_priv_key, + encrypted_key, + ); + tls.x25519_server_pub_key = server_pub_key; + have_server_pub_key = true; + }, + else => { + std.debug.print("named group: {x}\n", .{named_group}); + return error.TlsIllegalParameter; + }, + } + }, + else => { + std.debug.print("unexpected extension: {x}\n", .{et}); + }, + } + i = next_i; + } + if (!have_server_pub_key) return error.TlsIllegalParameter; + const tls_version = if (supported_version == 0) legacy_version else supported_version; + switch (tls_version) { + @enumToInt(ProtocolVersion.tls_1_2) => { + std.debug.print("server wants TLS v1.2\n", .{}); + }, + @enumToInt(ProtocolVersion.tls_1_3) => { + std.debug.print("server wants TLS v1.3\n", .{}); + }, + else => return error.TlsIllegalParameter, + } } else { std.debug.print("content_type: {s}\n", .{@tagName(ct)}); - std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(resp) }); + std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) }); } } |
