Diffstat (limited to 'lib/std/http/Client.zig')
-rw-r--r--  lib/std/http/Client.zig  941
1 file changed, 799 insertions(+), 142 deletions(-)
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index 76073c0ce3..6e1b2cb226 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -1,35 +1,35 @@
-//! TODO: send connection: keep-alive and LRU cache a configurable number of
-//! open connections to skip DNS and TLS handshake for subsequent requests.
-//!
-//! This API is *not* thread safe.
+//! Connecting and opening requests are threadsafe. Individual requests are not.
const std = @import("../std.zig");
-const mem = std.mem;
-const assert = std.debug.assert;
+const testing = std.testing;
const http = std.http;
+const mem = std.mem;
const net = std.net;
-const Client = @This();
const Uri = std.Uri;
-const Allocator = std.mem.Allocator;
-const testing = std.testing;
+const Allocator = mem.Allocator;
+const assert = std.debug.assert;
-pub const Request = @import("Client/Request.zig");
-pub const Response = @import("Client/Response.zig");
+const Client = @This();
+const proto = @import("protocol.zig");
pub const default_connection_pool_size = 32;
-const connection_pool_size = std.options.http_connection_pool_size;
+pub const connection_pool_size = std.options.http_connection_pool_size;
-/// Used for tcpConnectToHost and storing HTTP headers when an externally
-/// managed buffer is not provided.
allocator: Allocator,
ca_bundle: std.crypto.Certificate.Bundle = .{},
+ca_bundle_mutex: std.Thread.Mutex = .{},
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,
+/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},
+proxy: ?HttpProxy = null,
+
+/// A set of linked lists of connections that can be reused.
pub const ConnectionPool = struct {
+ /// The criteria for a connection to be considered a match.
pub const Criteria = struct {
host: []const u8,
port: u16,
@@ -40,10 +40,12 @@ pub const ConnectionPool = struct {
pub const Node = Queue.Node;
mutex: std.Thread.Mutex = .{},
+ /// Open connections that are currently in use.
used: Queue = .{},
+ /// Open connections that are not currently in use.
free: Queue = .{},
free_len: usize = 0,
- free_size: usize = default_connection_pool_size,
+ free_size: usize = connection_pool_size,
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
/// If no connection is found, null is returned.
@@ -55,7 +57,7 @@ pub const ConnectionPool = struct {
while (next) |node| : (next = node.prev) {
if ((node.data.protocol == .tls) != criteria.is_tls) continue;
if (node.data.port != criteria.port) continue;
- if (std.mem.eql(u8, node.data.host, criteria.host)) continue;
+ if (!mem.eql(u8, node.data.host, criteria.host)) continue;
pool.acquireUnsafe(node);
return node;
@@ -89,7 +91,7 @@ pub const ConnectionPool = struct {
pool.used.remove(node);
if (node.data.closing) {
- node.data.close(client);
+ node.data.deinit(client);
return client.allocator.destroy(node);
}
@@ -97,12 +99,17 @@ pub const ConnectionPool = struct {
if (pool.free_len + 1 >= pool.free_size) {
const popped = pool.free.popFirst() orelse unreachable;
- popped.data.close(client);
+ popped.data.deinit(client);
return client.allocator.destroy(popped);
}
- pool.free.append(node);
+ if (node.data.proxied) {
+ pool.free.prepend(node); // proxied connections go to the end of the queue, so direct connections are always tried first
+ } else {
+ pool.free.append(node);
+ }
+
pool.free_len += 1;
}
@@ -122,7 +129,7 @@ pub const ConnectionPool = struct {
defer client.allocator.destroy(node);
next = node.next;
- node.data.close(client);
+ node.data.deinit(client);
}
next = pool.used.first;
@@ -130,54 +137,114 @@ pub const ConnectionPool = struct {
defer client.allocator.destroy(node);
next = node.next;
- node.data.close(client);
+ node.data.deinit(client);
}
pool.* = undefined;
}
};
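
The free-list size above is read from `std.options.http_connection_pool_size`, so a program can tune how many idle connections are kept. A minimal sketch, assuming the option keeps this name and the usual root-level `std_options` override applies:

    // root source file (e.g. main.zig)
    pub const std_options = struct {
        // keep at most 8 idle connections in std.http.Client's pool
        pub const http_connection_pool_size = 8;
    };
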
-pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
-pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
-pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{});
-
+/// An interface to either a plain or TLS connection.
pub const Connection = struct {
+ pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
+ pub const Protocol = enum { plain, tls };
+
stream: net.Stream,
/// undefined unless protocol is tls.
- tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
+ tls_client: *std.crypto.tls.Client,
+
protocol: Protocol,
host: []u8,
port: u16,
- // This connection has been part of a non keepalive request and cannot be added to the pool.
+ proxied: bool = false,
closing: bool = false,
- pub const Protocol = enum { plain, tls };
+ read_start: u16 = 0,
+ read_end: u16 = 0,
+ read_buf: [buffer_size]u8 = undefined,
+
+ pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+ return switch (conn.protocol) {
+ .plain => conn.stream.readAtLeast(buffer, len),
+ .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
+ } catch |err| {
+ // TODO: https://github.com/ziglang/zig/issues/2473
+ if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
+
+ switch (err) {
+ error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
+ error.ConnectionTimedOut => return error.ConnectionTimedOut,
+ error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+ else => return error.UnexpectedReadFailure,
+ }
+ };
+ }
- pub fn read(conn: *Connection, buffer: []u8) !usize {
- switch (conn.protocol) {
- .plain => return conn.stream.read(buffer),
- .tls => return conn.tls_client.read(conn.stream, buffer),
- }
+ pub fn fill(conn: *Connection) ReadError!void {
+ if (conn.read_end != conn.read_start) return;
+
+ const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
+ if (nread == 0) return error.EndOfStream;
+ conn.read_start = 0;
+ conn.read_end = @intCast(u16, nread);
}
- pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
- switch (conn.protocol) {
- .plain => return conn.stream.readAtLeast(buffer, len),
- .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
+ pub fn peek(conn: *Connection) []const u8 {
+ return conn.read_buf[conn.read_start..conn.read_end];
+ }
+
+ pub fn drop(conn: *Connection, num: u16) void {
+ conn.read_start += num;
+ }
+
+ pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+ assert(len <= buffer.len);
+
+ var out_index: u16 = 0;
+ while (out_index < len) {
+ const available_read = conn.read_end - conn.read_start;
+ const available_buffer = buffer.len - out_index;
+
+ if (available_read > available_buffer) { // partially read buffered data
+ @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
+ out_index += @intCast(u16, available_buffer);
+ conn.read_start += @intCast(u16, available_buffer);
+
+ break;
+ } else if (available_read > 0) { // fully read buffered data
+ @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
+ out_index += available_read;
+ conn.read_start += available_read;
+
+ if (out_index >= len) break;
+ }
+
+ const leftover_buffer = available_buffer - available_read;
+ const leftover_len = len - out_index;
+
+ if (leftover_buffer > conn.read_buf.len) {
+ // skip the buffer if the output is large enough
+ return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
+ }
+
+ try conn.fill();
}
+
+ return out_index;
}
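
The buffered helpers above are meant to compose: `fill` tops up `read_buf`, `peek` exposes the unread bytes, and `drop` marks them consumed, while `readAtLeast` drains the buffer before falling back to the raw stream. A sketch of the peek/drop pattern (`parseSomeBytes` is a hypothetical consumer; `Request.wait` below uses `proto.HeadersParser.checkCompleteHead` in this position):

    try conn.fill();                       // buffer at least one byte, or error.EndOfStream
    const buffered = conn.peek();          // bytes between read_start and read_end
    const used = parseSomeBytes(buffered); // hypothetical consumer; returns bytes consumed
    conn.drop(@intCast(u16, used));        // advance read_start past what was consumed
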
- pub const ReadError = net.Stream.ReadError || error{
- TlsConnectionTruncated,
- TlsRecordOverflow,
- TlsDecodeError,
+ pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+ return conn.readAtLeast(buffer, 1);
+ }
+
+ pub const ReadError = error{
+ TlsFailure,
TlsAlert,
- TlsBadRecordMac,
- Overflow,
- TlsBadLength,
- TlsIllegalParameter,
- TlsUnexpectedMessage,
+ ConnectionTimedOut,
+ ConnectionResetByPeer,
+ UnexpectedReadFailure,
+ EndOfStream,
};
pub const Reader = std.io.Reader(*Connection, ReadError, read);
@@ -187,20 +254,30 @@ pub const Connection = struct {
}
pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
- switch (conn.protocol) {
- .plain => return conn.stream.writeAll(buffer),
- .tls => return conn.tls_client.writeAll(conn.stream, buffer),
- }
+ return switch (conn.protocol) {
+ .plain => conn.stream.writeAll(buffer),
+ .tls => conn.tls_client.writeAll(conn.stream, buffer),
+ } catch |err| switch (err) {
+ error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+ else => return error.UnexpectedWriteFailure,
+ };
}
pub fn write(conn: *Connection, buffer: []const u8) !usize {
- switch (conn.protocol) {
- .plain => return conn.stream.write(buffer),
- .tls => return conn.tls_client.write(conn.stream, buffer),
- }
+ return switch (conn.protocol) {
+ .plain => conn.stream.write(buffer),
+ .tls => conn.tls_client.write(conn.stream, buffer),
+ } catch |err| switch (err) {
+ error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+ else => return error.UnexpectedWriteFailure,
+ };
}
- pub const WriteError = net.Stream.WriteError || error{};
+ pub const WriteError = error{
+ ConnectionResetByPeer,
+ UnexpectedWriteFailure,
+ };
+
pub const Writer = std.io.Writer(*Connection, WriteError, write);
pub fn writer(conn: *Connection) Writer {
@@ -215,11 +292,569 @@ pub const Connection = struct {
}
conn.stream.close();
+ }
+ pub fn deinit(conn: *Connection, client: *const Client) void {
+ conn.close(client);
client.allocator.free(conn.host);
}
};
+/// The mode of transport for requests.
+pub const RequestTransfer = union(enum) {
+ content_length: u64,
+ chunked: void,
+ none: void,
+};
+
+/// The decompressor for response messages.
+pub const Compression = union(enum) {
+ pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.TransferReader);
+ pub const GzipDecompressor = std.compress.gzip.Decompress(Request.TransferReader);
+ pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{});
+
+ deflate: DeflateDecompressor,
+ gzip: GzipDecompressor,
+ zstd: ZstdDecompressor,
+ none: void,
+};
+
+/// An HTTP response originating from a server.
+pub const Response = struct {
+ pub const ParseError = Allocator.Error || error{
+ HttpHeadersInvalid,
+ HttpHeaderContinuationsUnsupported,
+ HttpTransferEncodingUnsupported,
+ HttpConnectionHeaderUnsupported,
+ InvalidContentLength,
+ CompressionNotSupported,
+ };
+
+ pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void {
+ var it = mem.tokenizeAny(u8, bytes[0 .. bytes.len - 4], "\r\n");
+
+ const first_line = it.next() orelse return error.HttpHeadersInvalid;
+ if (first_line.len < 12)
+ return error.HttpHeadersInvalid;
+
+ const version: http.Version = switch (int64(first_line[0..8])) {
+ int64("HTTP/1.0") => .@"HTTP/1.0",
+ int64("HTTP/1.1") => .@"HTTP/1.1",
+ else => return error.HttpHeadersInvalid,
+ };
+ if (first_line[8] != ' ') return error.HttpHeadersInvalid;
+ const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
+ const reason = mem.trimLeft(u8, first_line[12..], " ");
+
+ res.version = version;
+ res.status = status;
+ res.reason = reason;
+
+ while (it.next()) |line| {
+ if (line.len == 0) return error.HttpHeadersInvalid;
+ switch (line[0]) {
+ ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+ else => {},
+ }
+
+ var line_it = mem.tokenizeAny(u8, line, ": ");
+ const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
+ const header_value = line_it.rest();
+
+ try res.headers.append(header_name, header_value);
+
+ if (trailing) continue;
+
+ if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+ if (res.content_length != null) return error.HttpHeadersInvalid;
+ res.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+ // Transfer-Encoding: second, first
+ // Transfer-Encoding: deflate, chunked
+ var iter = mem.splitBackwardsScalar(u8, header_value, ',');
+
+ if (iter.next()) |first| {
+ const trimmed = mem.trim(u8, first, " ");
+
+ if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
+ if (res.transfer_encoding != null) return error.HttpHeadersInvalid;
+ res.transfer_encoding = te;
+ } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ if (res.transfer_compression != null) return error.HttpHeadersInvalid;
+ res.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |second| {
+ if (res.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
+
+ const trimmed = mem.trim(u8, second, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ res.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+ if (res.transfer_compression != null) return error.HttpHeadersInvalid;
+
+ const trimmed = mem.trim(u8, header_value, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ res.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+ }
+ }
+
+ inline fn int64(array: *const [8]u8) u64 {
+ return @bitCast(u64, array.*);
+ }
+
+ fn parseInt3(nnn: @Vector(3, u8)) u10 {
+ const zero: @Vector(3, u8) = .{ '0', '0', '0' };
+ const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
+ return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
+ }
+
+ test parseInt3 {
+ const expectEqual = testing.expectEqual;
+ try expectEqual(@as(u10, 0), parseInt3("000".*));
+ try expectEqual(@as(u10, 418), parseInt3("418".*));
+ try expectEqual(@as(u10, 999), parseInt3("999".*));
+ }
+
+ version: http.Version,
+ status: http.Status,
+ reason: []const u8,
+
+ content_length: ?u64 = null,
+ transfer_encoding: ?http.TransferEncoding = null,
+ transfer_compression: ?http.ContentEncoding = null,
+
+ headers: http.Headers,
+ parser: proto.HeadersParser,
+ compression: Compression = .none,
+ skip: bool = false,
+};
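
Once `Request.wait` (below) has filled these fields, they can be inspected directly. A small sketch, assuming `req` is a `Request` whose `wait` already returned:

    std.debug.print("HTTP {d} {s}\n", .{ @enumToInt(req.response.status), req.response.reason });
    if (req.response.content_length) |len| std.debug.print("expecting {d} body bytes\n", .{len});
    if (req.response.headers.getFirstValue("content-type")) |ct| std.debug.print("content-type: {s}\n", .{ct});
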
+
+/// An HTTP request that has been sent.
+///
+/// Order of operations: request -> start[ -> write -> finish] -> wait -> read
+pub const Request = struct {
+ uri: Uri,
+ client: *Client,
+ connection: *ConnectionPool.Node,
+
+ method: http.Method,
+ version: http.Version = .@"HTTP/1.1",
+ headers: http.Headers,
+ transfer_encoding: RequestTransfer = .none,
+
+ redirects_left: u32,
+ handle_redirects: bool,
+
+ response: Response,
+
+ /// Used as an allocator for resolving redirect locations.
+ arena: std.heap.ArenaAllocator,
+
+ /// Frees all resources associated with the request.
+ pub fn deinit(req: *Request) void {
+ switch (req.response.compression) {
+ .none => {},
+ .deflate => |*deflate| deflate.deinit(),
+ .gzip => |*gzip| gzip.deinit(),
+ .zstd => |*zstd| zstd.deinit(),
+ }
+
+ req.response.headers.deinit();
+
+ if (req.response.parser.header_bytes_owned) {
+ req.response.parser.header_bytes.deinit(req.client.allocator);
+ }
+
+ if (!req.response.parser.done) {
+ // If the response wasn't fully read, then we need to close the connection.
+ req.connection.data.closing = true;
+ }
+
+ req.client.connection_pool.release(req.client, req.connection);
+
+ req.arena.deinit();
+ req.* = undefined;
+ }
+
+ // This function must deallocate all resources associated with the request, or keep those which will be reused.
+ // It needs to be kept in sync with deinit and request.
+ fn redirect(req: *Request, uri: Uri) !void {
+ assert(req.response.parser.done);
+
+ switch (req.response.compression) {
+ .none => {},
+ .deflate => |*deflate| deflate.deinit(),
+ .gzip => |*gzip| gzip.deinit(),
+ .zstd => |*zstd| zstd.deinit(),
+ }
+
+ req.client.connection_pool.release(req.client, req.connection);
+
+ const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
+
+ const port: u16 = uri.port orelse switch (protocol) {
+ .plain => 80,
+ .tls => 443,
+ };
+
+ const host = uri.host orelse return error.UriMissingHost;
+
+ req.uri = uri;
+ req.connection = try req.client.connect(host, port, protocol);
+ req.redirects_left -= 1;
+ req.response.headers.clearRetainingCapacity();
+ req.response.parser.reset();
+
+ req.response = .{
+ .status = undefined,
+ .reason = undefined,
+ .version = undefined,
+ .headers = req.response.headers,
+ .parser = req.response.parser,
+ };
+ }
+
+ pub const StartError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
+
+ /// Send the request to the server.
+ pub fn start(req: *Request) StartError!void {
+ var buffered = std.io.bufferedWriter(req.connection.data.writer());
+ const w = buffered.writer();
+
+ try w.writeAll(@tagName(req.method));
+ try w.writeByte(' ');
+
+ if (req.method == .CONNECT) {
+ try w.writeAll(req.uri.host.?);
+ try w.writeByte(':');
+ try w.print("{}", .{req.uri.port.?});
+ } else if (req.connection.data.proxied) {
+ // proxied connections require the full uri
+ try w.print("{+/}", .{req.uri});
+ } else {
+ try w.print("{/}", .{req.uri});
+ }
+
+ try w.writeByte(' ');
+ try w.writeAll(@tagName(req.version));
+ try w.writeAll("\r\n");
+
+ if (!req.headers.contains("host")) {
+ try w.writeAll("Host: ");
+ try w.writeAll(req.uri.host.?);
+ try w.writeAll("\r\n");
+ }
+
+ if (!req.headers.contains("user-agent")) {
+ try w.writeAll("User-Agent: zig/");
+ try w.writeAll(@import("builtin").zig_version_string);
+ try w.writeAll(" (std.http)\r\n");
+ }
+
+ if (!req.headers.contains("connection")) {
+ try w.writeAll("Connection: keep-alive\r\n");
+ }
+
+ if (!req.headers.contains("accept-encoding")) {
+ try w.writeAll("Accept-Encoding: gzip, deflate, zstd\r\n");
+ }
+
+ if (!req.headers.contains("te")) {
+ try w.writeAll("TE: gzip, deflate, trailers\r\n");
+ }
+
+ const has_transfer_encoding = req.headers.contains("transfer-encoding");
+ const has_content_length = req.headers.contains("content-length");
+
+ if (!has_transfer_encoding and !has_content_length) {
+ switch (req.transfer_encoding) {
+ .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
+ .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
+ .none => {},
+ }
+ } else {
+ if (has_content_length) {
+ const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
+
+ req.transfer_encoding = .{ .content_length = content_length };
+ } else if (has_transfer_encoding) {
+ const transfer_encoding = req.headers.getFirstValue("transfer-encoding").?;
+ if (std.mem.eql(u8, transfer_encoding, "chunked")) {
+ req.transfer_encoding = .chunked;
+ } else {
+ return error.UnsupportedTransferEncoding;
+ }
+ } else {
+ req.transfer_encoding = .none;
+ }
+ }
+
+ try w.print("{}", .{req.headers});
+
+ try w.writeAll("\r\n");
+
+ try buffered.flush();
+ }
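
For a plain `GET http://example.com/index.html` with no caller-supplied headers, the head written by `start` comes out roughly as below (every line is CRLF-terminated and the head ends with an empty line; the zig version string is illustrative):

    GET /index.html HTTP/1.1
    Host: example.com
    User-Agent: zig/0.11.0-dev (std.http)
    Connection: keep-alive
    Accept-Encoding: gzip, deflate, zstd
    TE: gzip, deflate, trailers
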
+
+ pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
+
+ pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
+
+ pub fn transferReader(req: *Request) TransferReader {
+ return .{ .context = req };
+ }
+
+ pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
+ if (req.response.parser.done) return 0;
+
+ var index: usize = 0;
+ while (index == 0) {
+ const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
+ if (amt == 0 and req.response.parser.done) break;
+ index += amt;
+ }
+
+ return index;
+ }
+
+ pub const WaitError = RequestError || StartError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, CannotRedirect, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
+
+ /// Waits for a response from the server and parses any headers that are sent.
+ /// This function will block until the final response is received.
+ ///
+ /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow
+ /// redirects. If a request payload is present, then this function will error with error.CannotRedirect.
+ pub fn wait(req: *Request) WaitError!void {
+ while (true) { // handle redirects
+ while (true) { // read headers
+ try req.connection.data.fill();
+
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
+ req.connection.data.drop(@intCast(u16, nchecked));
+
+ if (req.response.parser.state.isContent()) break;
+ }
+
+ try req.response.parse(req.response.parser.header_bytes.items, false);
+
+ if (req.response.status == .switching_protocols) {
+ req.connection.data.closing = false;
+ req.response.parser.done = true;
+ }
+
+ if (req.method == .CONNECT and req.response.status == .ok) {
+ req.connection.data.closing = false;
+ req.response.parser.done = true;
+ }
+
+ // we default to using keep-alive if not provided
+ const req_connection = req.headers.getFirstValue("connection");
+ const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
+
+ const res_connection = req.response.headers.getFirstValue("connection");
+ const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
+ if (res_keepalive and (req_keepalive or req_connection == null)) {
+ req.connection.data.closing = false;
+ } else {
+ req.connection.data.closing = true;
+ }
+
+ if (req.response.transfer_encoding) |te| {
+ switch (te) {
+ .chunked => {
+ req.response.parser.next_chunk_length = 0;
+ req.response.parser.state = .chunk_head_size;
+ },
+ }
+ } else if (req.response.content_length) |cl| {
+ req.response.parser.next_chunk_length = cl;
+
+ if (cl == 0) req.response.parser.done = true;
+ } else {
+ req.response.parser.done = true;
+ }
+
+ // HEAD requests have no body
+ if (req.method == .HEAD) {
+ req.response.parser.done = true;
+ }
+
+ if (req.transfer_encoding == .none and req.response.status.class() == .redirect and req.handle_redirects) {
+ req.response.skip = true;
+
+ const empty = @as([*]u8, undefined)[0..0];
+ assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary
+
+ if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+
+ const location = req.response.headers.getFirstValue("location") orelse
+ return error.HttpRedirectMissingLocation;
+
+ const arena = req.arena.allocator();
+
+ const location_duped = try arena.dupe(u8, location);
+
+ const new_url = Uri.parse(location_duped) catch try Uri.parseWithoutScheme(location_duped);
+ const resolved_url = try req.uri.resolve(new_url, false, arena);
+
+ try req.redirect(resolved_url);
+
+ try req.start();
+ } else {
+ req.response.skip = false;
+ if (!req.response.parser.done) {
+ if (req.response.transfer_compression) |tc| switch (tc) {
+ .compress => return error.CompressionNotSupported,
+ .deflate => req.response.compression = .{
+ .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
+ },
+ .gzip => req.response.compression = .{
+ .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
+ },
+ .zstd => req.response.compression = .{
+ .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
+ },
+ };
+ }
+
+ if (req.response.status.class() == .redirect and req.handle_redirects and req.transfer_encoding != .none)
+ return error.CannotRedirect; // The request body has already been sent. The request is still in a valid state, but the redirect must be handled manually.
+
+ break;
+ }
+ }
+ }
+
+ pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };
+
+ pub const Reader = std.io.Reader(*Request, ReadError, read);
+
+ pub fn reader(req: *Request) Reader {
+ return .{ .context = req };
+ }
+
+ /// Reads data from the response body. Must be called after `wait`.
+ pub fn read(req: *Request, buffer: []u8) ReadError!usize {
+ const out_index = switch (req.response.compression) {
+ .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+ .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+ .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
+ else => try req.transferRead(buffer),
+ };
+
+ if (out_index == 0) {
+ const has_trail = !req.response.parser.state.isContent();
+
+ while (!req.response.parser.state.isContent()) { // read trailing headers
+ try req.connection.data.fill();
+
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
+ req.connection.data.drop(@intCast(u16, nchecked));
+ }
+
+ if (has_trail) {
+ req.response.headers.clearRetainingCapacity();
+
+ // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error.
+ // This will *only* fail for a malformed trailer.
+ req.response.parse(req.response.parser.header_bytes.items, true) catch return error.InvalidTrailers;
+ }
+ }
+
+ return out_index;
+ }
+
+ /// Reads data from the response body. Must be called after `wait`.
+ pub fn readAll(req: *Request, buffer: []u8) !usize {
+ var index: usize = 0;
+ while (index < buffer.len) {
+ const amt = try read(req, buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
+ }
+ return index;
+ }
+
+ pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
+
+ pub const Writer = std.io.Writer(*Request, WriteError, write);
+
+ pub fn writer(req: *Request) Writer {
+ return .{ .context = req };
+ }
+
+ /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
+ pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
+ switch (req.transfer_encoding) {
+ .chunked => {
+ try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
+ try req.connection.data.writeAll(bytes);
+ try req.connection.data.writeAll("\r\n");
+
+ return bytes.len;
+ },
+ .content_length => |*len| {
+ if (len.* < bytes.len) return error.MessageTooLong;
+
+ const amt = try req.connection.data.write(bytes);
+ len.* -= amt;
+ return amt;
+ },
+ .none => return error.NotWriteable,
+ }
+ }
+
+ pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try write(req, bytes[index..]);
+ }
+ }
+
+ pub const FinishError = WriteError || error{MessageNotCompleted};
+
+ /// Finish the body of a request. This notifies the server that you have no more data to send.
+ pub fn finish(req: *Request) FinishError!void {
+ switch (req.transfer_encoding) {
+ .chunked => try req.connection.data.writeAll("0\r\n\r\n"),
+ .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+ .none => {},
+ }
+ }
+};
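
Taken together with the order-of-operations note above (request -> start[ -> write -> finish] -> wait -> read), a minimal GET under this API might look like the sketch below; `gpa` and the URL are placeholders and error handling is elided:

    var client = Client{ .allocator = gpa };
    defer client.deinit();

    var headers = http.Headers{ .allocator = gpa };
    defer headers.deinit();

    const uri = try Uri.parse("http://example.com/");
    var req = try client.request(.GET, uri, headers, .{});
    defer req.deinit();

    try req.start(); // write the request line and headers
    try req.wait();  // read and parse the response head, following redirects by default

    var buf: [4096]u8 = undefined;
    const n = try req.read(&buf); // body bytes, decompressed when a content-encoding was negotiated

Sending a body uses the same flow; `transfer_encoding` picks the framing, and redirects must be handled manually once a body has been written (still a sketch):

    var post = try client.request(.POST, uri, headers, .{ .handle_redirects = false });
    defer post.deinit();
    post.transfer_encoding = .chunked;

    try post.start();
    try post.writeAll("{\"example\":true}");
    try post.finish(); // writes the terminating zero-length chunk
    try post.wait();
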
+
+pub const HttpProxy = struct {
+ pub const ProxyAuthentication = union(enum) {
+ basic: []const u8,
+ custom: []const u8,
+ };
+
+ protocol: Connection.Protocol,
+ host: []const u8,
+ port: ?u16 = null,
+
+ /// The value for the Proxy-Authorization header.
+ auth: ?ProxyAuthentication = null,
+};
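
Proxying is opt-in: when `client.proxy` is set, `connect` (below) routes new connections through the proxy and flags them `proxied`, which makes `start` emit absolute-form request targets. A sketch with placeholder host and port:

    client.proxy = .{
        .protocol = .plain,
        .host = "proxy.internal.example", // placeholder
        .port = 8080, // when null, defaults to 80 or 443 based on .protocol
    };
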
+
+/// Release all resources associated with the client.
+/// TODO: currently leaks all request-allocated data
pub fn deinit(client: *Client) void {
client.connection_pool.deinit(client);
@@ -227,9 +862,11 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
-pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
+pub const ConnectUnproxiedError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
+/// This function is threadsafe.
+pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectUnproxiedError!*ConnectionPool.Node {
if (client.connection_pool.findConnection(.{
.host = host,
.port = port,
@@ -241,19 +878,36 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
+ const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
+ error.ConnectionRefused => return error.ConnectionRefused,
+ error.NetworkUnreachable => return error.NetworkUnreachable,
+ error.ConnectionTimedOut => return error.ConnectionTimedOut,
+ error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+ error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure,
+ error.NameServerFailure => return error.NameServerFailure,
+ error.UnknownHostName => return error.UnknownHostName,
+ error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses,
+ else => return error.UnexpectedConnectFailure,
+ };
+ errdefer stream.close();
+
conn.data = .{
- .stream = try net.tcpConnectToHost(client.allocator, host, port),
+ .stream = stream,
.tls_client = undefined,
.protocol = protocol,
+
.host = try client.allocator.dupe(u8, host),
.port = port,
};
+ errdefer client.allocator.free(conn.data.host);
switch (protocol) {
.plain => {},
.tls => {
conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
- conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
+ errdefer client.allocator.destroy(conn.data.tls_client);
+
+ conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
conn.data.tls_client.allow_truncation_attacks = true;
@@ -265,24 +919,76 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
return conn;
}
-pub const RequestError = ConnectError || Connection.WriteError || error{
+// Prevents a dependency loop in request()
+const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused };
+pub const ConnectError = ConnectErrorPartial || RequestError;
+
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+ if (client.connection_pool.findConnection(.{
+ .host = host,
+ .port = port,
+ .is_tls = protocol == .tls,
+ })) |node|
+ return node;
+
+ if (client.proxy) |proxy| {
+ const proxy_port: u16 = proxy.port orelse switch (proxy.protocol) {
+ .plain => 80,
+ .tls => 443,
+ };
+
+ const conn = try client.connectUnproxied(proxy.host, proxy_port, proxy.protocol);
+ conn.data.proxied = true;
+
+ return conn;
+ } else {
+ return client.connectUnproxied(host, port, protocol);
+ }
+}
+
+pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{
UnsupportedUrlScheme,
UriMissingHost,
- CertificateAuthorityBundleTooBig,
- InvalidPadding,
- MissingEndCertificateMarker,
- Unseekable,
- EndOfStream,
+ CertificateBundleLoadFailure,
+ UnsupportedTransferEncoding,
};
-pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
- const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http"))
- .plain
- else if (mem.eql(u8, uri.scheme, "https"))
- .tls
- else
- return error.UnsupportedUrlScheme;
+pub const Options = struct {
+ version: http.Version = .@"HTTP/1.1",
+
+ handle_redirects: bool = true,
+ max_redirects: u32 = 3,
+ header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
+
+ /// Must be an already acquired connection.
+ connection: ?*ConnectionPool.Node = null,
+
+ pub const HeaderStrategy = union(enum) {
+ /// In this case, the client's Allocator will be used to store the
+ /// entire HTTP header. This value is the maximum total size of
+ /// HTTP headers allowed, otherwise
+ /// error.HttpHeadersExceededSizeLimit is returned from read().
+ dynamic: usize,
+ /// This is used to store the entire HTTP header. If the HTTP
+ /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
+ /// is returned from read(). When this is used, `error.OutOfMemory`
+ /// cannot be returned from `read()`.
+ static: []u8,
+ };
+};
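
Callers that want the response head stored without dynamic allocation can pair the `.static` strategy with a caller-owned buffer; a sketch:

    var header_buf: [16 * 1024]u8 = undefined;
    var req = try client.request(.GET, uri, headers, .{
        .header_strategy = .{ .static = &header_buf },
    });
    defer req.deinit();
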
+
+pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
+ .{ "http", .plain },
+ .{ "ws", .plain },
+ .{ "https", .tls },
+ .{ "wss", .tls },
+});
+
+/// Form an HTTP request to a server; the request is not sent until `start` is called.
+/// This function is threadsafe.
+pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: Options) RequestError!Request {
+ const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
const port: u16 = uri.port orelse switch (protocol) {
.plain => 80,
@@ -291,92 +997,43 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
const host = uri.host orelse return error.UriMissingHost;
- if (client.next_https_rescan_certs and protocol == .tls) {
- client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
- defer client.connection_pool.mutex.unlock();
+ if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) {
+ client.ca_bundle_mutex.lock();
+ defer client.ca_bundle_mutex.unlock();
if (client.next_https_rescan_certs) {
- try client.ca_bundle.rescan(client.allocator);
- client.next_https_rescan_certs = false;
+ client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure;
+ @atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
}
}
+ const conn = options.connection orelse try client.connect(host, port, protocol);
+
var req: Request = .{
.uri = uri,
.client = client,
+ .connection = conn,
.headers = headers,
- .connection = try client.connect(host, port, protocol),
+ .method = method,
+ .version = options.version,
.redirects_left = options.max_redirects,
.handle_redirects = options.handle_redirects,
- .compression_init = false,
- .response = switch (options.header_strategy) {
- .dynamic => |max| Response.initDynamic(max),
- .static => |buf| Response.initStatic(buf),
+ .response = .{
+ .status = undefined,
+ .reason = undefined,
+ .version = undefined,
+ .headers = http.Headers{ .allocator = client.allocator, .owned = false },
+ .parser = switch (options.header_strategy) {
+ .dynamic => |max| proto.HeadersParser.initDynamic(max),
+ .static => |buf| proto.HeadersParser.initStatic(buf),
+ },
},
.arena = undefined,
};
+ errdefer req.deinit();
req.arena = std.heap.ArenaAllocator.init(client.allocator);
- {
- var buffered = std.io.bufferedWriter(req.connection.data.writer());
- const writer = buffered.writer();
-
- const escaped_path = try Uri.escapePath(client.allocator, uri.path);
- defer client.allocator.free(escaped_path);
-
- const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
- defer if (escaped_query) |q| client.allocator.free(q);
-
- const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
- defer if (escaped_fragment) |f| client.allocator.free(f);
-
- try writer.writeAll(@tagName(headers.method));
- try writer.writeByte(' ');
- if (escaped_path.len == 0) {
- try writer.writeByte('/');
- } else {
- try writer.writeAll(escaped_path);
- }
- if (escaped_query) |q| {
- try writer.writeByte('?');
- try writer.writeAll(q);
- }
- if (escaped_fragment) |f| {
- try writer.writeByte('#');
- try writer.writeAll(f);
- }
- try writer.writeByte(' ');
- try writer.writeAll(@tagName(headers.version));
- try writer.writeAll("\r\nHost: ");
- try writer.writeAll(host);
- try writer.writeAll("\r\nUser-Agent: ");
- try writer.writeAll(headers.user_agent);
- if (headers.connection == .close) {
- try writer.writeAll("\r\nConnection: close");
- } else {
- try writer.writeAll("\r\nConnection: keep-alive");
- }
- try writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd");
-
- switch (headers.transfer_encoding) {
- .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"),
- .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}),
- .none => {},
- }
-
- for (headers.custom) |header| {
- try writer.writeAll("\r\n");
- try writer.writeAll(header.name);
- try writer.writeAll(": ");
- try writer.writeAll(header.value);
- }
-
- try writer.writeAll("\r\n\r\n");
-
- try buffered.flush();
- }
-
return req;
}
@@ -390,5 +1047,5 @@ test {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
- _ = Request;
+ std.testing.refAllDecls(@This());
}