aboutsummaryrefslogtreecommitdiff
path: root/lib/std/crypto/Tls.zig
diff options
context:
space:
mode:
authorAndrew Kelley <andrew@ziglang.org>2022-12-13 20:15:41 -0700
committerAndrew Kelley <andrew@ziglang.org>2023-01-02 16:57:15 -0700
commitd2f5d0b1990a160aa1d648531ea5b1df7b2acdce (patch)
treeca92f3233708feb032d3fc3f667efcca7d4a296b /lib/std/crypto/Tls.zig
parentba44513c2fe363b55b2c534be98179286b832b7e (diff)
downloadzig-d2f5d0b1990a160aa1d648531ea5b1df7b2acdce.tar.gz
zig-d2f5d0b1990a160aa1d648531ea5b1df7b2acdce.zip
std.crypto.Tls: parse the ServerHello handshake
Diffstat (limited to 'lib/std/crypto/Tls.zig')
-rw-r--r--lib/std/crypto/Tls.zig127
1 files changed, 114 insertions, 13 deletions
diff --git a/lib/std/crypto/Tls.zig b/lib/std/crypto/Tls.zig
index ab54b42b70..0dc6946003 100644
--- a/lib/std/crypto/Tls.zig
+++ b/lib/std/crypto/Tls.zig
@@ -8,6 +8,13 @@ const assert = std.debug.assert;
state: State = .start,
x25519_priv_key: [32]u8 = undefined,
x25519_pub_key: [32]u8 = undefined,
+x25519_server_pub_key: [32]u8 = undefined,
+
+const ProtocolVersion = enum(u16) {
+ tls_1_2 = 0x0303,
+ tls_1_3 = 0x0304,
+ _,
+};
const State = enum {
/// In this state, all fields are undefined except state.
@@ -186,6 +193,18 @@ const NamedGroup = enum(u16) {
// * length: u24
// * data: opaque
+// ServerHello:
+// * ProtocolVersion legacy_version = 0x0303;
+// * Random random;
+// * opaque legacy_session_id_echo<0..32>;
+// * CipherSuite cipher_suite;
+// * uint8 legacy_compression_method = 0;
+// * Extension extensions<6..2^16-1>;
+
+// Extension:
+// * ExtensionType extension_type;
+// * opaque extension_data<0..2^16-1>;
+
const CipherSuite = enum(u16) {
TLS_AES_128_GCM_SHA256 = 0x1301,
TLS_AES_256_GCM_SHA384 = 0x1302,
@@ -259,10 +278,10 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
// Extension: key_share
0, 51, // ExtensionType.key_share
- 0x00, 38, // byte length of this extension payload
- 0x00, 36, // byte length of client_shares
+ 0, 38, // byte length of this extension payload
+ 0, 36, // byte length of client_shares
0x00, 0x1D, // NamedGroup.x25519
- 0x00, 32, // byte length of key_exchange
+ 0, 32, // byte length of key_exchange
} ++ tls.x25519_pub_key ++ [_]u8{
// Extension: server_name
@@ -313,21 +332,103 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
try stream.writevAll(&iovecs);
{
- var buf: [1000]u8 = undefined;
- const amt = try stream.read(&buf);
- const resp = buf[0..amt];
- const ct = @intToEnum(ContentType, resp[0]);
+ var handshake_buf: [4000]u8 = undefined;
+ const plaintext = handshake_buf[0..5];
+ const amt = try stream.readAtLeast(&handshake_buf, plaintext.len);
+ if (amt < plaintext.len) return error.EndOfStream;
+ const ct = @intToEnum(ContentType, plaintext[0]);
+ const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
+ const end = plaintext.len + frag_len;
+ if (end > handshake_buf.len) return error.TlsServerHelloTooBig;
+ if (amt < end) {
+ const amt2 = try stream.readAll(handshake_buf[amt..end]);
+ if (amt2 < plaintext.len) return error.EndOfStream;
+ }
+ const frag = handshake_buf[plaintext.len..end];
+
if (ct == .alert) {
- //const prot_ver = @bitCast(u16, resp[1..][0..2].*);
- const len = std.mem.readIntBig(u16, resp[3..][0..2]);
- const alert = resp[5..][0..len];
- const level = @intToEnum(AlertLevel, alert[0]);
- const desc = @intToEnum(AlertDescription, alert[1]);
+ const level = @intToEnum(AlertLevel, frag[0]);
+ const desc = @intToEnum(AlertDescription, frag[1]);
std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
std.process.exit(1);
+ } else if (ct == .handshake) {
+ if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
+ return error.TlsUnexpectedMessage;
+ }
+ const length = mem.readIntBig(u24, frag[1..4]);
+ if (4 + length != frag.len) return error.TlsBadLength;
+ const hello = frag[4..];
+ const legacy_version = mem.readIntBig(u16, hello[0..2]);
+ const random = hello[2..34].*;
+ _ = random;
+ const legacy_session_id_echo_len = hello[34];
+ if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
+ const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
+ const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+ return error.TlsIllegalParameter;
+ std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
+ const legacy_compression_method = hello[37];
+ _ = legacy_compression_method;
+ const extensions_size = mem.readIntBig(u16, hello[38..40]);
+ if (40 + extensions_size != hello.len) return error.TlsBadLength;
+ var i: usize = 40;
+ var supported_version: u16 = 0;
+ var have_server_pub_key = false;
+ while (i < hello.len) {
+ const et = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ const next_i = i + ext_size;
+ if (next_i > hello.len) return error.TlsBadLength;
+ switch (et) {
+ @enumToInt(ExtensionType.supported_versions) => {
+ if (supported_version != 0) return error.TlsIllegalParameter;
+ supported_version = mem.readIntBig(u16, hello[i..][0..2]);
+ },
+ @enumToInt(ExtensionType.key_share) => {
+ if (have_server_pub_key) return error.TlsIllegalParameter;
+ const named_group = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ switch (named_group) {
+ @enumToInt(NamedGroup.x25519) => {
+ const key_size = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ if (key_size != 32) return error.TlsBadLength;
+ const encrypted_key = hello[i..][0..32].*;
+ const server_pub_key = try crypto.dh.X25519.scalarmult(
+ tls.x25519_priv_key,
+ encrypted_key,
+ );
+ tls.x25519_server_pub_key = server_pub_key;
+ have_server_pub_key = true;
+ },
+ else => {
+ std.debug.print("named group: {x}\n", .{named_group});
+ return error.TlsIllegalParameter;
+ },
+ }
+ },
+ else => {
+ std.debug.print("unexpected extension: {x}\n", .{et});
+ },
+ }
+ i = next_i;
+ }
+ if (!have_server_pub_key) return error.TlsIllegalParameter;
+ const tls_version = if (supported_version == 0) legacy_version else supported_version;
+ switch (tls_version) {
+ @enumToInt(ProtocolVersion.tls_1_2) => {
+ std.debug.print("server wants TLS v1.2\n", .{});
+ },
+ @enumToInt(ProtocolVersion.tls_1_3) => {
+ std.debug.print("server wants TLS v1.3\n", .{});
+ },
+ else => return error.TlsIllegalParameter,
+ }
} else {
std.debug.print("content_type: {s}\n", .{@tagName(ct)});
- std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(resp) });
+ std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) });
}
}