diff options
Diffstat (limited to 'lib/std')
| -rw-r--r-- | lib/std/Uri.zig | 112 | ||||
| -rw-r--r-- | lib/std/crypto/tls/Client.zig | 2 | ||||
| -rw-r--r-- | lib/std/http.zig | 6 | ||||
| -rw-r--r-- | lib/std/http/Client.zig | 762 | ||||
| -rw-r--r-- | lib/std/http/Headers.zig | 11 | ||||
| -rw-r--r-- | lib/std/http/Server.zig | 77 | ||||
| -rw-r--r-- | lib/std/http/protocol.zig | 6 | ||||
| -rw-r--r-- | lib/std/std.zig | 11 |
8 files changed, 687 insertions, 300 deletions
diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 21f05be2b8..3f277d0cc6 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -208,24 +208,45 @@ pub fn parseWithoutScheme(text: []const u8) ParseError!Uri { return uri; } -pub fn format( +pub const WriteToStreamOptions = struct { + /// When true, include the scheme part of the URI. + scheme: bool = false, + + /// When true, include the user and password part of the URI. Ignored if `authority` is false. + authentication: bool = false, + + /// When true, include the authority part of the URI. + authority: bool = false, + + /// When true, include the path part of the URI. + path: bool = false, + + /// When true, include the query part of the URI. Ignored when `path` is false. + query: bool = false, + + /// When true, include the fragment part of the URI. Ignored when `path` is false. + fragment: bool = false, + + /// When true, do not escape any part of the URI. + raw: bool = false, +}; + +pub fn writeToStream( uri: Uri, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, + options: WriteToStreamOptions, writer: anytype, ) @TypeOf(writer).Error!void { - _ = options; - - const needs_absolute = comptime std.mem.indexOf(u8, fmt, "+") != null; - const needs_path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0; - const raw_uri = comptime std.mem.indexOf(u8, fmt, "r") != null; - const needs_fragment = comptime std.mem.indexOf(u8, fmt, "#") != null; - - if (needs_absolute) { + if (options.scheme) { try writer.writeAll(uri.scheme); try writer.writeAll(":"); - if (uri.host) |host| { + + if (options.authority and uri.host != null) { try writer.writeAll("//"); + } + } + + if (options.authority) { + if (options.authentication and uri.host != null) { if (uri.user) |user| { try writer.writeAll(user); if (uri.password) |password| { @@ -234,7 +255,9 @@ pub fn format( } try writer.writeAll("@"); } + } + if (uri.host) |host| { try writer.writeAll(host); if (uri.port) |port| { @@ -244,39 +267,62 @@ pub fn format( } } - if (needs_path) { + if (options.path) { if (uri.path.len == 0) { try writer.writeAll("/"); + } else if (options.raw) { + try writer.writeAll(uri.path); } else { - if (raw_uri) { - try writer.writeAll(uri.path); - } else { - try Uri.writeEscapedPath(writer, uri.path); - } + try writeEscapedPath(writer, uri.path); } - if (uri.query) |q| { + if (options.query) if (uri.query) |q| { try writer.writeAll("?"); - if (raw_uri) { + if (options.raw) { try writer.writeAll(q); } else { - try Uri.writeEscapedQuery(writer, q); + try writeEscapedQuery(writer, q); } - } + }; - if (needs_fragment) { - if (uri.fragment) |f| { - try writer.writeAll("#"); - if (raw_uri) { - try writer.writeAll(f); - } else { - try Uri.writeEscapedQuery(writer, f); - } + if (options.fragment) if (uri.fragment) |f| { + try writer.writeAll("#"); + if (options.raw) { + try writer.writeAll(f); + } else { + try writeEscapedQuery(writer, f); } - } + }; } } +pub fn format( + uri: Uri, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, +) @TypeOf(writer).Error!void { + _ = options; + + const scheme = comptime std.mem.indexOf(u8, fmt, ":") != null or fmt.len == 0; + const authentication = comptime std.mem.indexOf(u8, fmt, "@") != null or fmt.len == 0; + const authority = comptime std.mem.indexOf(u8, fmt, "+") != null or fmt.len == 0; + const path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0; + const query = comptime std.mem.indexOf(u8, fmt, "?") != null or fmt.len == 0; + const fragment = comptime std.mem.indexOf(u8, fmt, "#") != null or fmt.len == 0; + const raw = comptime std.mem.indexOf(u8, fmt, "r") != null or fmt.len == 0; + + return writeToStream(uri, .{ + .scheme = scheme, + .authentication = authentication, + .authority = authority, + .path = path, + .query = query, + .fragment = fragment, + .raw = raw, + }, writer); +} + /// Parses the URI or returns an error. /// The return value will contain unescaped strings pointing into the /// original `text`. Each component that is provided, will be non-`null`. @@ -711,7 +757,7 @@ test "URI query escaping" { const parsed = try Uri.parse(address); // format the URI to escape it - const formatted_uri = try std.fmt.allocPrint(std.testing.allocator, "{}", .{parsed}); + const formatted_uri = try std.fmt.allocPrint(std.testing.allocator, "{/?}", .{parsed}); defer std.testing.allocator.free(formatted_uri); try std.testing.expectEqualStrings("/?response-content-type=application%2Foctet-stream", formatted_uri); } @@ -729,6 +775,6 @@ test "format" { }; var buf = std.ArrayList(u8).init(std.testing.allocator); defer buf.deinit(); - try uri.format("+/", .{}, buf.writer()); + try uri.format(":/?#", .{}, buf.writer()); try std.testing.expectEqualSlices(u8, "file:/foo/bar/baz", buf.items); } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 37306dd37f..7671d06469 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -881,7 +881,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { /// The `iovecs` parameter is mutable because this function needs to mutate the fields in /// order to handle partial reads from the underlying stream layer. pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { - return readvAtLeast(c, stream, iovecs); + return readvAtLeast(c, stream, iovecs, 1); } /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. diff --git a/lib/std/http.zig b/lib/std/http.zig index 532e5a6de8..9487e82106 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -35,7 +35,8 @@ pub const Method = enum(u64) { // TODO: should be u192 or u256, but neither is s /// Asserts that `s` is 24 or fewer bytes. pub fn parse(s: []const u8) u64 { var x: u64 = 0; - @memcpy(std.mem.asBytes(&x)[0..s.len], s); + const len = @min(s.len, @sizeOf(@TypeOf(x))); + @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); return x; } @@ -289,14 +290,17 @@ pub const Status = enum(u10) { pub const TransferEncoding = enum { chunked, + none, // compression is intentionally omitted here, as std.http.Client stores it as content-encoding }; pub const ContentEncoding = enum { identity, compress, + @"x-compress", deflate, gzip, + @"x-gzip", zstd, }; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1e244af15e..2a6f12103f 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -13,12 +13,16 @@ const assert = std.debug.assert; const Client = @This(); const proto = @import("protocol.zig"); -pub const default_connection_pool_size = 32; -pub const connection_pool_size = std.options.http_connection_pool_size; +pub const disable_tls = std.options.http_disable_tls; +/// Allocator used for all allocations made by the client. +/// +/// This allocator must be thread-safe. allocator: Allocator, -ca_bundle: std.crypto.Certificate.Bundle = .{}, + +ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, + /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, @@ -26,7 +30,11 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, -proxy: ?HttpProxy = null, +/// This is the proxy that will handle http:// connections. It *must not* be modified when the client has any active connections. +http_proxy: ?Proxy = null, + +/// This is the proxy that will handle https:// connections. It *must not* be modified when the client has any active connections. +https_proxy: ?Proxy = null, /// A set of linked lists of connections that can be reused. pub const ConnectionPool = struct { @@ -34,7 +42,7 @@ pub const ConnectionPool = struct { pub const Criteria = struct { host: []const u8, port: u16, - is_tls: bool, + protocol: Connection.Protocol, }; const Queue = std.DoublyLinkedList(Connection); @@ -46,22 +54,24 @@ pub const ConnectionPool = struct { /// Open connections that are not currently in use. free: Queue = .{}, free_len: usize = 0, - free_size: usize = connection_pool_size, + free_size: usize = 32, /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. - pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node { + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); var next = pool.free.last; while (next) |node| : (next = node.prev) { - if ((node.data.protocol == .tls) != criteria.is_tls) continue; + if (node.data.protocol != criteria.protocol) continue; if (node.data.port != criteria.port) continue; - if (!mem.eql(u8, node.data.host, criteria.host)) continue; + + // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) + if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); - return node; + return &node.data; } return null; @@ -85,23 +95,28 @@ pub const ConnectionPool = struct { /// Tries to release a connection back to the connection pool. This function is threadsafe. /// If the connection is marked as closing, it will be closed instead. - pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void { + /// + /// The allocator must be the owner of all nodes in this pool. + /// The allocator must be the owner of all resources associated with the connection. + pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); + const node = @fieldParentPtr(Node, "data", connection); + pool.used.remove(node); - if (node.data.closing) { - node.data.deinit(client); - return client.allocator.destroy(node); + if (node.data.closing or pool.free_size == 0) { + node.data.close(allocator); + return allocator.destroy(node); } if (pool.free_len >= pool.free_size) { const popped = pool.free.popFirst() orelse unreachable; pool.free_len -= 1; - popped.data.deinit(client); - client.allocator.destroy(popped); + popped.data.close(allocator); + allocator.destroy(popped); } if (node.data.proxied) { @@ -121,23 +136,43 @@ pub const ConnectionPool = struct { pool.used.append(node); } - pub fn deinit(pool: *ConnectionPool, client: *Client) void { + /// Resizes the connection pool. This function is threadsafe. + /// + /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.first; + _ = next; + while (pool.free_len > new_size) { + const popped = pool.free.popFirst() orelse unreachable; + pool.free_len -= 1; + + popped.data.close(allocator); + allocator.destroy(popped); + } + + pool.free_size = new_size; + } + + pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { pool.mutex.lock(); var next = pool.free.first; while (next) |node| { - defer client.allocator.destroy(node); + defer allocator.destroy(node); next = node.next; - node.data.deinit(client); + node.data.close(allocator); } next = pool.used.first; while (next) |node| { - defer client.allocator.destroy(node); + defer allocator.destroy(node); next = node.next; - node.data.deinit(client); + node.data.close(allocator); } pool.* = undefined; @@ -147,11 +182,13 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + const BufferSize = std.math.IntFittingRange(0, buffer_size); + pub const Protocol = enum { plain, tls }; stream: net.Stream, /// undefined unless protocol is tls. - tls_client: *std.crypto.tls.Client, + tls_client: if (!disable_tls) *std.crypto.tls.Client else void, protocol: Protocol, host: []u8, @@ -160,16 +197,15 @@ pub const Connection = struct { proxied: bool = false, closing: bool = false, - read_start: u16 = 0, - read_end: u16 = 0, + read_start: BufferSize = 0, + read_end: BufferSize = 0, + write_end: BufferSize = 0, read_buf: [buffer_size]u8 = undefined, + write_buf: [buffer_size]u8 = undefined, - pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), - } catch |err| { - // TODO: https://github.com/ziglang/zig/issues/2473 + pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { + return conn.tls_client.readv(conn.stream, buffers) catch |err| { + // https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; switch (err) { @@ -181,61 +217,69 @@ pub const Connection = struct { }; } + pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.readvDirectTls(buffers); + } + + return conn.stream.readv(buffers) catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; + } + pub fn fill(conn: *Connection) ReadError!void { if (conn.read_end != conn.read_start) return; - const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); + var iovecs = [1]std.os.iovec{ + .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); if (nread == 0) return error.EndOfStream; conn.read_start = 0; - conn.read_end = @as(u16, @intCast(nread)); + conn.read_end = @intCast(nread); } pub fn peek(conn: *Connection) []const u8 { return conn.read_buf[conn.read_start..conn.read_end]; } - pub fn drop(conn: *Connection, num: u16) void { + pub fn drop(conn: *Connection, num: BufferSize) void { conn.read_start += num; } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + const available_read = conn.read_end - conn.read_start; + const available_buffer = buffer.len; - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; + if (available_read > available_buffer) { // partially read buffered data + @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); + conn.read_start += @intCast(available_buffer); - if (out_index >= len) break; - } + return available_buffer; + } else if (available_read > 0) { // fully read buffered data + @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); + conn.read_start += available_read; - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; + return available_read; + } - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); - } + var iovecs = [2]std.os.iovec{ + .{ .iov_base = buffer.ptr, .iov_len = buffer.len }, + .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, + }; + const nread = try conn.readvDirect(&iovecs); - try conn.fill(); + if (nread > buffer.len) { + conn.read_start = 0; + conn.read_end = @intCast(nread - buffer.len); + return buffer.len; } - return out_index; - } - - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); + return nread; } pub const ReadError = error{ @@ -253,26 +297,49 @@ pub const Connection = struct { return Reader{ .context = conn }; } - pub fn writeAll(conn: *Connection, buffer: []const u8) !void { - return switch (conn.protocol) { - .plain => conn.stream.writeAll(buffer), - .tls => conn.tls_client.writeAll(conn.stream, buffer), - } catch |err| switch (err) { + pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { + return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; } - pub fn write(conn: *Connection, buffer: []const u8) !usize { - return switch (conn.protocol) { - .plain => conn.stream.write(buffer), - .tls => conn.tls_client.write(conn.stream, buffer), - } catch |err| switch (err) { + pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { + if (conn.protocol == .tls) { + if (disable_tls) unreachable; + + return conn.writeAllDirectTls(buffer); + } + + return conn.stream.writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; } + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + if (conn.write_end + buffer.len > conn.write_buf.len) { + try conn.flush(); + + if (buffer.len > conn.write_buf.len) { + try conn.writeAllDirect(buffer); + return buffer.len; + } + } + + @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); + conn.write_end += @intCast(buffer.len); + + return buffer.len; + } + + pub fn flush(conn: *Connection) WriteError!void { + if (conn.write_end == 0) return; + + try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + conn.write_end = 0; + } + pub const WriteError = error{ ConnectionResetByPeer, UnexpectedWriteFailure, @@ -284,19 +351,17 @@ pub const Connection = struct { return Writer{ .context = conn }; } - pub fn close(conn: *Connection, client: *const Client) void { + pub fn close(conn: *Connection, allocator: Allocator) void { if (conn.protocol == .tls) { + if (disable_tls) unreachable; + // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; - client.allocator.destroy(conn.tls_client); + allocator.destroy(conn.tls_client); } conn.stream.close(); - } - - pub fn deinit(conn: *Connection, client: *const Client) void { - conn.close(client); - client.allocator.free(conn.host); + allocator.free(conn.host); } }; @@ -331,7 +396,7 @@ pub const Response = struct { }; pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void { - var it = mem.tokenizeAny(u8, bytes[0 .. bytes.len - 4], "\r\n"); + var it = mem.tokenizeAny(u8, bytes, "\r\n"); const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 12) @@ -350,6 +415,8 @@ pub const Response = struct { res.status = status; res.reason = reason; + res.headers.clearRetainingCapacity(); + while (it.next()) |line| { if (line.len == 0) return error.HttpHeadersInvalid; switch (line[0]) { @@ -365,46 +432,42 @@ pub const Response = struct { if (trailing) continue; - if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { // Transfer-Encoding: second, first // Transfer-Encoding: deflate, chunked var iter = mem.splitBackwardsScalar(u8, header_value, ','); - if (iter.next()) |first| { - const trimmed = mem.trim(u8, first, " "); + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (res.transfer_encoding != null) return error.HttpHeadersInvalid; - res.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (res.transfer_compression != null) return error.HttpHeadersInvalid; - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + res.transfer_encoding = transfer; - if (iter.next()) |second| { - if (res.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + next = iter.next(); + } - const trimmed = mem.trim(u8, second, " "); + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.transfer_compression = transfer; } else { return error.HttpTransferEncodingUnsupported; } } if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; + + res.content_length = content_length; } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != null) return error.HttpHeadersInvalid; + if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; const trimmed = mem.trim(u8, header_value, " "); @@ -440,13 +503,21 @@ pub const Response = struct { status: http.Status, reason: []const u8, + /// If present, the number of bytes in the response body. content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, + /// If present, the transfer encoding of the response body, otherwise none. + transfer_encoding: http.TransferEncoding = .none, + + /// If present, the compression of the response body, otherwise identity (no compression). + transfer_compression: http.ContentEncoding = .identity, + + /// The headers received from the server. headers: http.Headers, parser: proto.HeadersParser, compression: Compression = .none, + + /// Whether the response body should be skipped. Any data read from the response body will be discarded. skip: bool = false, }; @@ -457,15 +528,18 @@ pub const Request = struct { uri: Uri, client: *Client, /// is null when this connection is released - connection: ?*ConnectionPool.Node, + connection: ?*Connection, method: http.Method, version: http.Version = .@"HTTP/1.1", headers: http.Headers, + + /// The transfer encoding of the request body. transfer_encoding: RequestTransfer = .none, redirects_left: u32, handle_redirects: bool, + handle_continue: bool, response: Response, @@ -491,9 +565,9 @@ pub const Request = struct { if (req.connection) |connection| { if (!req.response.parser.done) { // If the response wasn't fully read, then we need to close the connection. - connection.data.closing = true; + connection.closing = true; } - req.client.connection_pool.release(req.client, connection); + req.client.connection_pool.release(req.client.allocator, connection); } req.arena.deinit(); @@ -512,7 +586,7 @@ pub const Request = struct { .zstd => |*zstd| zstd.deinit(), } - req.client.connection_pool.release(req.client, req.connection.?); + req.client.connection_pool.release(req.client.allocator, req.connection.?); req.connection = null; const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; @@ -539,42 +613,33 @@ pub const Request = struct { }; } - pub const StartError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; - pub const StartOptions = struct { - /// Specifies that the uri should be used as is + pub const SendOptions = struct { + /// Specifies that the uri should be used as is. You guarantee that the uri is already escaped. raw_uri: bool = false, }; - /// Send the request to the server. - pub fn start(req: *Request, options: StartOptions) StartError!void { + /// Send the HTTP request headers to the server. + pub fn send(req: *Request, options: SendOptions) SendError!void { if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding; - var buffered = std.io.bufferedWriter(req.connection.?.data.writer()); - const w = buffered.writer(); + const w = req.connection.?.writer(); try req.method.write(w); try w.writeByte(' '); if (req.method == .CONNECT) { - try w.writeAll(req.uri.host.?); - try w.writeByte(':'); - try w.print("{}", .{req.uri.port.?}); + try req.uri.writeToStream(.{ .authority = true }, w); } else { - if (req.connection.?.data.proxied) { - // proxied connections require the full uri - if (options.raw_uri) { - try w.print("{+/r}", .{req.uri}); - } else { - try w.print("{+/}", .{req.uri}); - } - } else { - if (options.raw_uri) { - try w.print("{/r}", .{req.uri}); - } else { - try w.print("{/}", .{req.uri}); - } - } + try req.uri.writeToStream(.{ + .scheme = req.connection.?.proxied, + .authentication = req.connection.?.proxied, + .authority = req.connection.?.proxied, + .path = true, + .query = true, + .raw = options.raw_uri, + }, w); } try w.writeByte(' '); try w.writeAll(@tagName(req.version)); @@ -582,7 +647,7 @@ pub const Request = struct { if (!req.headers.contains("host")) { try w.writeAll("Host: "); - try w.writeAll(req.uri.host.?); + try req.uri.writeToStream(.{ .authority = true }, w); try w.writeAll("\r\n"); } @@ -614,17 +679,17 @@ pub const Request = struct { .none => {}, } } else { - if (has_content_length) { - const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - - req.transfer_encoding = .{ .content_length = content_length }; - } else if (has_transfer_encoding) { + if (has_transfer_encoding) { const transfer_encoding = req.headers.getFirstValue("transfer-encoding").?; if (std.mem.eql(u8, transfer_encoding, "chunked")) { req.transfer_encoding = .chunked; } else { return error.UnsupportedTransferEncoding; } + } else if (has_content_length) { + const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; + + req.transfer_encoding = .{ .content_length = content_length }; } else { req.transfer_encoding = .none; } @@ -639,9 +704,27 @@ pub const Request = struct { try w.writeAll("\r\n"); } + if (req.connection.?.proxied) { + const proxy_headers: ?http.Headers = switch (req.connection.?.protocol) { + .plain => if (req.client.http_proxy) |proxy| proxy.headers else null, + .tls => if (req.client.https_proxy) |proxy| proxy.headers else null, + }; + + if (proxy_headers) |headers| { + for (headers.list.items) |entry| { + if (entry.value.len == 0) continue; + + try w.writeAll(entry.name); + try w.writeAll(": "); + try w.writeAll(entry.value); + try w.writeAll("\r\n"); + } + } + } + try w.writeAll("\r\n"); - try buffered.flush(); + try req.connection.?.flush(); } const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; @@ -657,7 +740,7 @@ pub const Request = struct { var index: usize = 0; while (index == 0) { - const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip); + const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip); if (amt == 0 and req.response.parser.done) break; index += amt; } @@ -665,20 +748,22 @@ pub const Request = struct { return index; } - pub const WaitError = RequestError || StartError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, RedirectRequiresResend, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported }; + pub const WaitError = RequestError || SendError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, RedirectRequiresResend, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported }; /// Waits for a response from the server and parses any headers that are sent. /// This function will block until the final response is received. /// /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow /// redirects. If a request payload is present, then this function will error with error.RedirectRequiresResend. + /// + /// Must be called after `start` and, if any data was written to the request body, then also after `finish`. pub fn wait(req: *Request) WaitError!void { while (true) { // handle redirects while (true) { // read headers - try req.connection.?.data.fill(); + try req.connection.?.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); - req.connection.?.data.drop(@as(u16, @intCast(nchecked))); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); if (req.response.parser.state.isContent()) break; } @@ -688,12 +773,16 @@ pub const Request = struct { if (req.response.status == .@"continue") { req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response req.response.parser.reset(); + + if (req.handle_continue) + continue; + break; } // we're switching protocols, so this connection is no longer doing http if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) { - req.connection.?.data.closing = false; + req.connection.?.closing = false; req.response.parser.done = true; } @@ -704,13 +793,14 @@ pub const Request = struct { const res_connection = req.response.headers.getFirstValue("connection"); const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); if (res_keepalive and (req_keepalive or req_connection == null)) { - req.connection.?.data.closing = false; + req.connection.?.closing = false; } else { - req.connection.?.data.closing = true; + req.connection.?.closing = true; } - if (req.response.transfer_encoding) |te| { - switch (te) { + if (req.response.transfer_encoding != .none) { + switch (req.response.transfer_encoding) { + .none => unreachable, .chunked => { req.response.parser.next_chunk_length = 0; req.response.parser.state = .chunk_head_size; @@ -774,23 +864,23 @@ pub const Request = struct { try req.redirect(resolved_url); - try req.start(.{}); + try req.send(.{}); } else { req.response.skip = false; if (!req.response.parser.done) { - if (req.response.transfer_compression) |tc| switch (tc) { + switch (req.response.transfer_compression) { .identity => req.response.compression = .none, - .compress => return error.CompressionNotSupported, + .compress, .@"x-compress" => return error.CompressionNotSupported, .deflate => req.response.compression = .{ .deflate = std.compress.zlib.decompressStream(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed, }, - .gzip => req.response.compression = .{ + .gzip, .@"x-gzip" => req.response.compression = .{ .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => req.response.compression = .{ .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), }, - }; + } } break; @@ -806,7 +896,7 @@ pub const Request = struct { return .{ .context = req }; } - /// Reads data from the response body. Must be called after `do`. + /// Reads data from the response body. Must be called after `wait`. pub fn read(req: *Request, buffer: []u8) ReadError!usize { const out_index = switch (req.response.compression) { .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, @@ -819,15 +909,13 @@ pub const Request = struct { const has_trail = !req.response.parser.state.isContent(); while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.data.fill(); + try req.connection.?.fill(); - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek()); - req.connection.?.data.drop(@as(u16, @intCast(nchecked))); + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek()); + req.connection.?.drop(@intCast(nchecked)); } if (has_trail) { - req.response.headers.clearRetainingCapacity(); - // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. // This will *only* fail for a malformed trailer. req.response.parse(req.response.parser.header_bytes.items, true) catch return error.InvalidTrailers; @@ -837,7 +925,7 @@ pub const Request = struct { return out_index; } - /// Reads data from the response body. Must be called after `do`. + /// Reads data from the response body. Must be called after `wait`. pub fn readAll(req: *Request, buffer: []u8) !usize { var index: usize = 0; while (index < buffer.len) { @@ -856,20 +944,21 @@ pub const Request = struct { return .{ .context = req }; } - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn write(req: *Request, bytes: []const u8) WriteError!usize { switch (req.transfer_encoding) { .chunked => { - try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.data.writeAll(bytes); - try req.connection.?.data.writeAll("\r\n"); + try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); + try req.connection.?.writer().writeAll(bytes); + try req.connection.?.writer().writeAll("\r\n"); return bytes.len; }, .content_length => |*len| { if (len.* < bytes.len) return error.MessageTooLong; - const amt = try req.connection.?.data.write(bytes); + const amt = try req.connection.?.write(bytes); len.* -= amt; return amt; }, @@ -877,6 +966,8 @@ pub const Request = struct { } } + /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { @@ -887,50 +978,169 @@ pub const Request = struct { pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. + /// Must be called after `start`. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { - .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"), + .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } + + try req.connection.?.flush(); } }; -pub const HttpProxy = struct { - pub const ProxyAuthentication = union(enum) { - basic: []const u8, - custom: []const u8, - }; +pub const Proxy = struct { + allocator: Allocator, + headers: http.Headers, protocol: Connection.Protocol, host: []const u8, - port: ?u16 = null, + port: u16, - /// The value for the Proxy-Authorization header. - auth: ?ProxyAuthentication = null, + supports_connect: bool = true, }; /// Release all associated resources with the client. -/// TODO: currently leaks all request allocated data +/// +/// All pending requests must be de-initialized and all active connections released +/// before calling this function. pub fn deinit(client: *Client) void { - client.connection_pool.deinit(client); + assert(client.connection_pool.used.first == null); // There are still active requests. + + client.connection_pool.deinit(client.allocator); + + if (client.http_proxy) |*proxy| { + proxy.allocator.free(proxy.host); + proxy.headers.deinit(); + } + + if (client.https_proxy) |*proxy| { + proxy.allocator.free(proxy.host); + proxy.headers.deinit(); + } + + if (!disable_tls) + client.ca_bundle.deinit(client.allocator); - client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub const ConnectUnproxiedError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +/// Uses the *_proxy environment variable to set any unset proxies for the client. +/// This function *must not* be called when the client has any active connections. +pub fn loadDefaultProxies(client: *Client) !void { + // Prevent any new connections from being created. + client.connection_pool.mutex.lock(); + defer client.connection_pool.mutex.unlock(); + + assert(client.connection_pool.used.first == null); // There are still active requests. + + if (client.http_proxy == null) http: { + const content: []const u8 = if (std.process.hasEnvVarConstant("http_proxy")) + try std.process.getEnvVarOwned(client.allocator, "http_proxy") + else if (std.process.hasEnvVarConstant("HTTP_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "HTTP_PROXY") + else if (std.process.hasEnvVarConstant("all_proxy")) + try std.process.getEnvVarOwned(client.allocator, "all_proxy") + else if (std.process.hasEnvVarConstant("ALL_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "ALL_PROXY") + else + break :http; + defer client.allocator.free(content); + + const uri = try Uri.parse(content); + + const protocol = protocol_map.get(uri.scheme) orelse break :http; // Unknown scheme, ignore + const host = if (uri.host) |host| try client.allocator.dupe(u8, host) else break :http; // Missing host, ignore + client.http_proxy = .{ + .allocator = client.allocator, + .headers = .{ .allocator = client.allocator }, + + .protocol = protocol, + .host = host, + .port = uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }, + }; + + if (uri.user != null and uri.password != null) { + const prefix_len = "Basic ".len; + + const unencoded = try std.fmt.allocPrint(client.allocator, "{s}:{s}", .{ uri.user.?, uri.password.? }); + defer client.allocator.free(unencoded); + + const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len) + prefix_len); + defer client.allocator.free(buffer); + + const result = std.base64.standard.Encoder.encode(buffer[prefix_len..], unencoded); + @memcpy(buffer[0..prefix_len], "Basic "); + + try client.http_proxy.?.headers.append("proxy-authorization", result); + } + } + + if (client.https_proxy == null) https: { + const content: []const u8 = if (std.process.hasEnvVarConstant("https_proxy")) + try std.process.getEnvVarOwned(client.allocator, "https_proxy") + else if (std.process.hasEnvVarConstant("HTTPS_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "HTTPS_PROXY") + else if (std.process.hasEnvVarConstant("all_proxy")) + try std.process.getEnvVarOwned(client.allocator, "all_proxy") + else if (std.process.hasEnvVarConstant("ALL_PROXY")) + try std.process.getEnvVarOwned(client.allocator, "ALL_PROXY") + else + break :https; + defer client.allocator.free(content); + + const uri = try Uri.parse(content); + + const protocol = protocol_map.get(uri.scheme) orelse break :https; // Unknown scheme, ignore + const host = if (uri.host) |host| try client.allocator.dupe(u8, host) else break :https; // Missing host, ignore + client.http_proxy = .{ + .allocator = client.allocator, + .headers = .{ .allocator = client.allocator }, + + .protocol = protocol, + .host = host, + .port = uri.port orelse switch (protocol) { + .plain => 80, + .tls => 443, + }, + }; + + if (uri.user != null and uri.password != null) { + const prefix_len = "Basic ".len; + + const unencoded = try std.fmt.allocPrint(client.allocator, "{s}:{s}", .{ uri.user.?, uri.password.? }); + defer client.allocator.free(unencoded); + + const buffer = try client.allocator.alloc(u8, std.base64.standard.Encoder.calcSize(unencoded.len) + prefix_len); + defer client.allocator.free(buffer); + + const result = std.base64.standard.Encoder.encode(buffer[prefix_len..], unencoded); + @memcpy(buffer[0..prefix_len], "Basic "); + + try client.https_proxy.?.headers.append("proxy-authorization", result); + } + } +} + +pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. /// This function is threadsafe. -pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectUnproxiedError!*ConnectionPool.Node { +pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection { if (client.connection_pool.findConnection(.{ .host = host, .port = port, - .is_tls = protocol == .tls, + .protocol = protocol, })) |node| return node; + if (disable_tls and protocol == .tls) + return error.TlsInitializationFailed; + const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; @@ -951,40 +1161,41 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: conn.data = .{ .stream = stream, .tls_client = undefined, - .protocol = protocol, + .protocol = protocol, .host = try client.allocator.dupe(u8, host), .port = port, }; errdefer client.allocator.free(conn.data.host); - switch (protocol) { - .plain => {}, - .tls => { - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.data.tls_client); + if (protocol == .tls) { + if (disable_tls) unreachable; - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; - }, + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + errdefer client.allocator.destroy(conn.data.tls_client); + + conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + conn.data.tls_client.allow_truncation_attacks = true; } client.connection_pool.addUsed(conn); - return conn; + return &conn.data; } pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError; -pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node { +/// Connect to `path` as a unix domain socket. This will reuse a connection if one is already open. +/// This function is threadsafe. +pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection { if (!net.has_unix_sockets) return error.Unsupported; if (client.connection_pool.findConnection(.{ .host = path, .port = 0, - .is_tls = false, + .protocol = .plain, })) |node| return node; @@ -1007,37 +1218,130 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti client.connection_pool.addUsed(conn); - return conn; + return &conn.data; } -// Prevents a dependency loop in request() -const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; +/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP CONNECT. This will reuse a connection if one is already open. +/// This function is threadsafe. +pub fn connectTunnel( + client: *Client, + proxy: *Proxy, + tunnel_host: []const u8, + tunnel_port: u16, +) !*Connection { + if (!proxy.supports_connect) return error.TunnelNotSupported; -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { if (client.connection_pool.findConnection(.{ - .host = host, - .port = port, - .is_tls = protocol == .tls, + .host = tunnel_host, + .port = tunnel_port, + .protocol = proxy.protocol, })) |node| return node; - if (client.proxy) |proxy| { - const proxy_port: u16 = proxy.port orelse switch (proxy.protocol) { - .plain => 80, - .tls => 443, + var maybe_valid = false; + (tunnel: { + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(client.allocator, conn); + } + + const uri = Uri{ + .scheme = "http", + .user = null, + .password = null, + .host = tunnel_host, + .port = tunnel_port, + .path = "", + .query = null, + .fragment = null, + }; + + // we can use a small buffer here because a CONNECT response should be very small + var buffer: [8096]u8 = undefined; + + var req = client.open(.CONNECT, uri, proxy.headers, .{ + .handle_redirects = false, + .connection = conn, + .header_strategy = .{ .static = &buffer }, + }) catch |err| { + std.log.debug("err {}", .{err}); + break :tunnel err; }; + defer req.deinit(); + + req.send(.{ .raw_uri = true }) catch |err| break :tunnel err; + req.wait() catch |err| break :tunnel err; + + if (req.response.status.class() == .server_error) { + maybe_valid = true; + break :tunnel error.ServerError; + } + + if (req.response.status != .ok) break :tunnel error.ConnectionRefused; - const conn = try client.connectUnproxied(proxy.host, proxy_port, proxy.protocol); - conn.data.proxied = true; + // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + req.connection = null; + + client.allocator.free(conn.host); + conn.host = try client.allocator.dupe(u8, tunnel_host); + errdefer client.allocator.free(conn.host); + + conn.port = tunnel_port; + conn.closing = false; return conn; - } else { - return client.connectUnproxied(host, port, protocol); + }) catch { + // something went wrong with the tunnel + proxy.supports_connect = maybe_valid; + return error.TunnelNotSupported; + }; +} + +// Prevents a dependency loop in request() +const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused }; +pub const ConnectError = ConnectErrorPartial || RequestError; + +/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. +/// +/// If a proxy is configured for the client, then the proxy will be used to connect to the host. +/// +/// This function is threadsafe. +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection { + // pointer required so that `supports_connect` can be updated if a CONNECT fails + const potential_proxy: ?*Proxy = switch (protocol) { + .plain => if (client.http_proxy) |*proxy_info| proxy_info else null, + .tls => if (client.https_proxy) |*proxy_info| proxy_info else null, + }; + + if (potential_proxy) |proxy| { + // don't attempt to proxy the proxy thru itself. + if (std.mem.eql(u8, proxy.host, host) and proxy.port == port and proxy.protocol == protocol) { + return client.connectTcp(host, port, protocol); + } + + if (proxy.supports_connect) tunnel: { + return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + error.TunnelNotSupported => break :tunnel, + else => |e| return e, + }; + } + + // fall back to using the proxy as a normal http proxy + const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + errdefer { + conn.closing = true; + client.connection_pool.release(conn); + } + + conn.proxied = true; + return conn; } + + return client.connectTcp(host, port, protocol); } -pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{ +pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || std.fmt.ParseIntError || Connection.WriteError || error{ UnsupportedUrlScheme, UriMissingHost, @@ -1048,12 +1352,20 @@ pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", + /// Automatically ignore 100 Continue responses. This assumes you don't care, and will have sent the body before you + /// wait for the response. + /// + /// If this is not the case AND you know the server will send a 100 Continue, set this to false and wait for a + /// response before sending the body. If you wait AND the server does not send a 100 Continue before you finish the + /// request, then the request *will* deadlock. + handle_continue: bool = true, + handle_redirects: bool = true, max_redirects: u32 = 3, header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 }, /// Must be an already acquired connection. - connection: ?*ConnectionPool.Node = null, + connection: ?*Connection = null, pub const StorageStrategy = union(enum) { /// In this case, the client's Allocator will be used to store the @@ -1076,14 +1388,14 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ .{ "wss", .tls }, }); -/// Form and send a http request to a server. +/// Open a connection to the host specified by `uri` and prepare to send a HTTP request. /// /// `uri` must remain alive during the entire request. /// `headers` is cloned and may be freed after this function returns. /// /// The caller is responsible for calling `deinit()` on the `Request`. /// This function is threadsafe. -pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: RequestOptions) RequestError!Request { +pub fn open(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: RequestOptions) RequestError!Request { const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = uri.port orelse switch (protocol) { @@ -1094,6 +1406,8 @@ pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Hea const host = uri.host orelse return error.UriMissingHost; if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) { + if (disable_tls) unreachable; + client.ca_bundle_mutex.lock(); defer client.ca_bundle_mutex.unlock(); @@ -1114,6 +1428,7 @@ pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Hea .version = options.version, .redirects_left = options.max_redirects, .handle_redirects = options.handle_redirects, + .handle_continue = options.handle_continue, .response = .{ .status = undefined, .reason = undefined, @@ -1178,6 +1493,9 @@ pub const FetchResult = struct { } }; +/// Perform a one-shot HTTP request with the provided options. +/// +/// This function is threadsafe. pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !FetchResult { const has_transfer_encoding = options.headers.contains("transfer-encoding"); const has_content_length = options.headers.contains("content-length"); @@ -1189,7 +1507,7 @@ pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !Fetc .uri => |u| u, }; - var req = try request(client, options.method, uri, options.headers, .{ + var req = try open(client, options.method, uri, options.headers, .{ .header_strategy = options.header_strategy, .handle_redirects = options.payload == .none, }); @@ -1206,7 +1524,7 @@ pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !Fetc .none => {}, } - try req.start(.{ .raw_uri = options.raw_uri }); + try req.send(.{ .raw_uri = options.raw_uri }); switch (options.payload) { .string => |str| try req.writeAll(str), diff --git a/lib/std/http/Headers.zig b/lib/std/http/Headers.zig index f69f7fef5f..d2e578b5ee 100644 --- a/lib/std/http/Headers.zig +++ b/lib/std/http/Headers.zig @@ -14,15 +14,18 @@ pub const CaseInsensitiveStringContext = struct { pub fn hash(self: @This(), s: []const u8) u64 { _ = self; var buf: [64]u8 = undefined; - var i: u8 = 0; + var i: usize = 0; var h = std.hash.Wyhash.init(0); - while (i < s.len) : (i += 64) { - const left = @min(64, s.len - i); - const ret = ascii.lowerString(buf[0..], s[i..][0..left]); + while (i + 64 < s.len) : (i += 64) { + const ret = ascii.lowerString(buf[0..], s[i..][0..64]); h.update(ret); } + const left = @min(64, s.len - i); + const ret = ascii.lowerString(buf[0..], s[i..][0..left]); + h.update(ret); + return h.final(); } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 56d14fe7fc..adb1ea0812 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -14,7 +14,7 @@ allocator: Allocator, socket: net.StreamServer, -/// An interface to either a plain or TLS connection. +/// An interface to a plain connection. pub const Connection = struct { pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; pub const Protocol = enum { plain }; @@ -178,7 +178,7 @@ pub const Request = struct { }; pub fn parse(req: *Request, bytes: []const u8) ParseError!void { - var it = mem.tokenizeAny(u8, bytes[0 .. bytes.len - 4], "\r\n"); + var it = mem.tokenizeAny(u8, bytes, "\r\n"); const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 10) @@ -228,27 +228,23 @@ pub const Request = struct { // Transfer-Encoding: deflate, chunked var iter = mem.splitBackwardsScalar(u8, header_value, ','); - if (iter.next()) |first| { - const trimmed = mem.trim(u8, first, " "); + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (req.transfer_encoding != null) return error.HttpHeadersInvalid; - req.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (req.transfer_compression != null) return error.HttpHeadersInvalid; - req.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (req.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding + req.transfer_encoding = transfer; - if (iter.next()) |second| { - if (req.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + next = iter.next(); + } - const trimmed = mem.trim(u8, second, " "); + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - req.transfer_compression = ce; + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported + req.transfer_compression = transfer; } else { return error.HttpTransferEncodingUnsupported; } @@ -256,7 +252,7 @@ pub const Request = struct { if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (req.transfer_compression != null) return error.HttpHeadersInvalid; + if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; const trimmed = mem.trim(u8, header_value, " "); @@ -277,9 +273,14 @@ pub const Request = struct { target: []const u8, version: http.Version, + /// The length of the request body, if known. content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, + + /// The transfer encoding of the request body, or .none if not present. + transfer_encoding: http.TransferEncoding = .none, + + /// The compression of the request body, or .identity (no compression) if not present. + transfer_compression: http.ContentEncoding = .identity, headers: http.Headers, parser: proto.HeadersParser, @@ -315,6 +316,7 @@ pub const Response = struct { finished, }; + /// Free all resources associated with this response. pub fn deinit(res: *Response) void { res.connection.close(); @@ -390,10 +392,10 @@ pub const Response = struct { } } - pub const DoError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + pub const SendError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; - /// Send the response headers. - pub fn do(res: *Response) DoError!void { + /// Send the HTTP response headers to the client. + pub fn send(res: *Response) SendError!void { switch (res.state) { .waited => res.state = .responded, .first, .start, .responded, .finished => unreachable, @@ -511,8 +513,9 @@ pub const Response = struct { res.request.headers = .{ .allocator = res.allocator, .owned = true }; try res.request.parse(res.request.parser.header_bytes.items); - if (res.request.transfer_encoding) |te| { - switch (te) { + if (res.request.transfer_encoding != .none) { + switch (res.request.transfer_encoding) { + .none => unreachable, .chunked => { res.request.parser.next_chunk_length = 0; res.request.parser.state = .chunk_head_size; @@ -527,19 +530,19 @@ pub const Response = struct { } if (!res.request.parser.done) { - if (res.request.transfer_compression) |tc| switch (tc) { + switch (res.request.transfer_compression) { .identity => res.request.compression = .none, - .compress => return error.CompressionNotSupported, + .compress, .@"x-compress" => return error.CompressionNotSupported, .deflate => res.request.compression = .{ .deflate = std.compress.zlib.decompressStream(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, - .gzip => res.request.compression = .{ + .gzip, .@"x-gzip" => res.request.compression = .{ .gzip = std.compress.gzip.decompress(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => res.request.compression = .{ .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), }, - }; + } } } @@ -551,6 +554,7 @@ pub const Response = struct { return .{ .context = res }; } + /// Reads data from the response body. Must be called after `wait`. pub fn read(res: *Response, buffer: []u8) ReadError!usize { switch (res.state) { .waited, .responded, .finished => {}, @@ -586,6 +590,7 @@ pub const Response = struct { return out_index; } + /// Reads data from the response body. Must be called after `wait`. pub fn readAll(res: *Response, buffer: []u8) !usize { var index: usize = 0; while (index < buffer.len) { @@ -605,6 +610,7 @@ pub const Response = struct { } /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn write(res: *Response, bytes: []const u8) WriteError!usize { switch (res.state) { .responded => {}, @@ -630,6 +636,8 @@ pub const Response = struct { } } + /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + /// Must be called after `start` and before `finish`. pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { @@ -640,6 +648,7 @@ pub const Response = struct { pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. + /// Must be called after `start`. pub fn finish(res: *Response) FinishError!void { switch (res.state) { .responded => res.state = .finished, @@ -654,6 +663,7 @@ pub const Response = struct { } }; +/// Create a new HTTP server. pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server { return .{ .allocator = allocator, @@ -661,6 +671,7 @@ pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server { }; } +/// Free all resources associated with this server. pub fn deinit(server: *Server) void { server.socket.deinit(); } @@ -756,13 +767,13 @@ test "HTTP server handles a chunked transfer coding request" { defer _ = res.reset(); try res.wait(); - try expect(res.request.transfer_encoding.? == .chunked); + try expect(res.request.transfer_encoding == .chunked); const server_body: []const u8 = "message from server!\n"; res.transfer_encoding = .{ .content_length = server_body.len }; try res.headers.append("content-type", "text/plain"); try res.headers.append("connection", "close"); - try res.do(); + try res.send(); var buf: [128]u8 = undefined; const n = try res.readAll(&buf); diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index a369c38581..74e0207f34 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -529,7 +529,7 @@ pub const HeadersParser = struct { try conn.fill(); const nread = @min(conn.peek().len, data_avail); - conn.drop(@as(u16, @intCast(nread))); + conn.drop(@intCast(nread)); r.next_chunk_length -= nread; if (r.next_chunk_length == 0) r.done = true; @@ -553,7 +553,7 @@ pub const HeadersParser = struct { try conn.fill(); const i = r.findChunkedLen(conn.peek()); - conn.drop(@as(u16, @intCast(i))); + conn.drop(@intCast(i)); switch (r.state) { .invalid => return error.HttpChunkInvalid, @@ -582,7 +582,7 @@ pub const HeadersParser = struct { try conn.fill(); const nread = @min(conn.peek().len, data_avail); - conn.drop(@as(u16, @intCast(nread))); + conn.drop(@intCast(nread)); r.next_chunk_length -= nread; } else if (out_avail > 0) { const can_read: usize = @intCast(@min(data_avail, out_avail)); diff --git a/lib/std/std.zig b/lib/std/std.zig index 16222e52da..7342cadbef 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -283,10 +283,15 @@ pub const options = struct { else false; - pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size")) - options_override.http_connection_pool_size + /// By default, std.http.Client will support HTTPS connections. Set this option to `true` to + /// disable TLS support. + /// + /// This will likely reduce the size of the binary, but it will also make it impossible to + /// make a HTTPS connection. + pub const http_disable_tls = if (@hasDecl(options_override, "http_disable_tls")) + options_override.http_disable_tls else - http.Client.default_connection_pool_size; + false; pub const side_channels_mitigations: crypto.SideChannelsMitigations = if (@hasDecl(options_override, "side_channels_mitigations")) options_override.side_channels_mitigations |
