author    Andrew Kelley <andrew@ziglang.org>  2023-04-09 10:44:52 -0400
committer GitHub <noreply@github.com>         2023-04-09 10:44:52 -0400
commit    2ee328995a70c5c446f24c5593e0fad760e6d839 (patch)
tree      0e547171b7790ffd182fc298d384ef614571e97e /lib/std/http/Client.zig
parent    c22a30ac99b9a2b92d9a8e926b9bf0c9dbc3d14e (diff)
parent    7f9a4625fda0b1a33177cdd66819f0a061c6b2da (diff)
Merge pull request #15123 from truemedian/http-server
std.http: add http server
Diffstat (limited to 'lib/std/http/Client.zig')
-rw-r--r--  lib/std/http/Client.zig  880
1 file changed, 759 insertions(+), 121 deletions(-)
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index 76073c0ce3..010c557f87 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -1,49 +1,105 @@
-//! TODO: send connection: keep-alive and LRU cache a configurable number of
-//! open connections to skip DNS and TLS handshake for subsequent requests.
-//!
-//! This API is *not* thread safe.
+//! Connecting and opening requests are threadsafe. Individual requests are not.
const std = @import("../std.zig");
-const mem = std.mem;
-const assert = std.debug.assert;
+const testing = std.testing;
const http = std.http;
+const mem = std.mem;
const net = std.net;
-const Client = @This();
const Uri = std.Uri;
-const Allocator = std.mem.Allocator;
-const testing = std.testing;
+const Allocator = mem.Allocator;
+const assert = std.debug.assert;
-pub const Request = @import("Client/Request.zig");
-pub const Response = @import("Client/Response.zig");
+const Client = @This();
+const proto = @import("protocol.zig");
pub const default_connection_pool_size = 32;
-const connection_pool_size = std.options.http_connection_pool_size;
+pub const connection_pool_size = std.options.http_connection_pool_size;
-/// Used for tcpConnectToHost and storing HTTP headers when an externally
-/// managed buffer is not provided.
allocator: Allocator,
ca_bundle: std.crypto.Certificate.Bundle = .{},
+ca_bundle_mutex: std.Thread.Mutex = .{},
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,
+/// The pool of connections that can be reused (and currently in use).
connection_pool: ConnectionPool = .{},
+/// The last error that occurred on this client. This is not threadsafe; do not expect it to be completely accurate.
+last_error: ?ExtraError = null,
+
+pub const ExtraError = union(enum) {
+ fn impliedErrorSet(comptime f: anytype) type {
+ const set = @typeInfo(@typeInfo(@TypeOf(f)).Fn.return_type.?).ErrorUnion.error_set;
+ if (@typeName(set)[0] != '@') @compileError(@typeName(@TypeOf(f)) ++ " doesn't have an implied error set anymore.");
+ return set;
+ }
+
+ // There's apparently a dependency loop with using Client.DeflateDecompressor.
+ const FakeTransferError = proto.HeadersParser.ReadError || error{ReadFailed};
+ const FakeTransferReader = std.io.Reader(void, FakeTransferError, fakeRead);
+ fn fakeRead(ctx: void, buf: []u8) FakeTransferError!usize {
+ _ = .{ buf, ctx };
+ return 0;
+ }
+
+ const FakeDeflateDecompressor = std.compress.zlib.ZlibStream(FakeTransferReader);
+ const FakeGzipDecompressor = std.compress.gzip.Decompress(FakeTransferReader);
+ const FakeZstdDecompressor = std.compress.zstd.DecompressStream(FakeTransferReader, .{});
+
+ pub const TcpConnectError = std.net.TcpConnectToHostError;
+ pub const TlsError = std.crypto.tls.Client.InitError(net.Stream);
+ pub const WriteError = BufferedConnection.WriteError;
+ pub const ReadError = BufferedConnection.ReadError || error{HttpChunkInvalid};
+ pub const CaBundleError = impliedErrorSet(std.crypto.Certificate.Bundle.rescan);
+
+ pub const ZlibInitError = error{ BadHeader, InvalidCompression, InvalidWindowSize, Unsupported, EndOfStream, OutOfMemory } || Request.TransferReadError;
+ pub const GzipInitError = error{ BadHeader, InvalidCompression, OutOfMemory, WrongChecksum, EndOfStream, StreamTooLong } || Request.TransferReadError;
+ // pub const DecompressError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error;
+ pub const DecompressError = FakeDeflateDecompressor.Error || FakeGzipDecompressor.Error || FakeZstdDecompressor.Error;
+
+ zlib_init: ZlibInitError, // error.CompressionInitializationFailed
+ gzip_init: GzipInitError, // error.CompressionInitializationFailed
+ connect: TcpConnectError, // error.ConnectionFailed
+ ca_bundle: CaBundleError, // error.CertificateAuthorityBundleFailed
+ tls: TlsError, // error.TlsInitializationFailed
+ write: WriteError, // error.WriteFailed
+ read: ReadError, // error.ReadFailed
+ decompress: DecompressError, // error.ReadFailed
+};
+
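Each variant above maps to one of the coarse error codes noted in the trailing comments: the detailed cause is parked on the client just before the coarse error is returned. An illustrative sketch of recovering it at a call site (`req` and `buf` are assumed to exist; the logging call is a placeholder, not part of this API):

    const n = req.readAll(&buf) catch |err| {
        if (err == error.ReadFailed) {
            // the precise cause was recorded on the owning client
            if (req.client.last_error) |extra| {
                std.log.err("read failed: {s}", .{@tagName(extra)});
            }
        }
        return err;
    };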
+/// A set of linked lists of connections that can be reused.
pub const ConnectionPool = struct {
+ /// The criteria for a connection to be considered a match.
pub const Criteria = struct {
host: []const u8,
port: u16,
is_tls: bool,
};
- const Queue = std.TailQueue(Connection);
+ pub const StoredConnection = struct {
+ buffered: BufferedConnection,
+ host: []u8,
+ port: u16,
+
+ closing: bool = false,
+
+ pub fn deinit(self: *StoredConnection, client: *Client) void {
+ self.buffered.close(client);
+ client.allocator.free(self.host);
+ }
+ };
+
+ const Queue = std.TailQueue(StoredConnection);
pub const Node = Queue.Node;
mutex: std.Thread.Mutex = .{},
+ /// Open connections that are currently in use.
used: Queue = .{},
+ /// Open connections that are not currently in use.
free: Queue = .{},
free_len: usize = 0,
- free_size: usize = default_connection_pool_size,
+ free_size: usize = connection_pool_size,
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
/// If no connection is found, null is returned.
@@ -53,9 +109,9 @@ pub const ConnectionPool = struct {
var next = pool.free.last;
while (next) |node| : (next = node.prev) {
- if ((node.data.protocol == .tls) != criteria.is_tls) continue;
+ if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
if (node.data.port != criteria.port) continue;
- if (!std.mem.eql(u8, node.data.host, criteria.host)) continue;
+ if (!mem.eql(u8, node.data.host, criteria.host)) continue;
pool.acquireUnsafe(node);
return node;
@@ -89,7 +145,7 @@ pub const ConnectionPool = struct {
pool.used.remove(node);
if (node.data.closing) {
- node.data.close(client);
+ node.data.deinit(client);
return client.allocator.destroy(node);
}
@@ -97,7 +153,7 @@ pub const ConnectionPool = struct {
if (pool.free_len + 1 >= pool.free_size) {
const popped = pool.free.popFirst() orelse unreachable;
- popped.data.close(client);
+ popped.data.deinit(client);
return client.allocator.destroy(popped);
}
@@ -122,7 +178,7 @@ pub const ConnectionPool = struct {
defer client.allocator.destroy(node);
next = node.next;
- node.data.close(client);
+ node.data.deinit(client);
}
next = pool.used.first;
@@ -130,27 +186,19 @@ pub const ConnectionPool = struct {
defer client.allocator.destroy(node);
next = node.next;
- node.data.close(client);
+ node.data.deinit(client);
}
pool.* = undefined;
}
};
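Taken together, the pool is consulted before dialing and repopulated on release; a compressed sketch of that pairing as `connect` and `Request.deinit` below perform it (the fallback `dialNewConnection` is hypothetical):

    const node = client.connection_pool.findConnection(.{
        .host = "example.com",
        .port = 443,
        .is_tls = true,
    }) orelse try dialNewConnection(client); // hypothetical; connect() does the real dialing
    // ... issue a request over node.data.buffered ...
    client.connection_pool.release(client, node); // recycled, or destroyed if node.data.closing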
-pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
-pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
-pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{});
-
+/// An interface to either a plain or TLS connection.
pub const Connection = struct {
stream: net.Stream,
/// undefined unless protocol is tls.
- tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
+ tls_client: *std.crypto.tls.Client,
protocol: Protocol,
- host: []u8,
- port: u16,
-
- // This connection has been part of a non keepalive request and cannot be added to the pool.
- closing: bool = false,
pub const Protocol = enum { plain, tls };
@@ -215,11 +263,611 @@ pub const Connection = struct {
}
conn.stream.close();
+ }
+};
+
+/// A buffered (and peekable) Connection.
+pub const BufferedConnection = struct {
+ pub const buffer_size = 0x2000;
+
+ conn: Connection,
+ buf: [buffer_size]u8 = undefined,
+ start: u16 = 0,
+ end: u16 = 0,
+
+ pub fn fill(bconn: *BufferedConnection) ReadError!void {
+ if (bconn.end != bconn.start) return;
+
+ const nread = try bconn.conn.read(bconn.buf[0..]);
+ if (nread == 0) return error.EndOfStream;
+ bconn.start = 0;
+ bconn.end = @truncate(u16, nread);
+ }
+
+ pub fn peek(bconn: *BufferedConnection) []const u8 {
+ return bconn.buf[bconn.start..bconn.end];
+ }
+
+ pub fn clear(bconn: *BufferedConnection, num: u16) void {
+ bconn.start += num;
+ }
+
+ pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
+ var out_index: u16 = 0;
+ while (out_index < len) {
+ const available = bconn.end - bconn.start;
+ const left = buffer.len - out_index;
+
+ if (available > 0) {
+ const can_read = @truncate(u16, @min(available, left));
+
+ std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
+ out_index += can_read;
+ bconn.start += can_read;
+
+ continue;
+ }
+
+ if (left > bconn.buf.len) {
+ // skip the buffer if the output is large enough
+ return out_index + try bconn.conn.read(buffer[out_index..]);
+ }
+
+ try bconn.fill();
+ }
+
+ return out_index;
+ }
+
+ pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
+ return bconn.readAtLeast(buffer, 1);
+ }
+
+ pub const ReadError = Connection.ReadError || error{EndOfStream};
+ pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
+
+ pub fn reader(bconn: *BufferedConnection) Reader {
+ return Reader{ .context = bconn };
+ }
+
+ pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
+ return bconn.conn.writeAll(buffer);
+ }
+
+ pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
+ return bconn.conn.write(buffer);
+ }
+
+ pub const WriteError = Connection.WriteError;
+ pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
+
+ pub fn writer(bconn: *BufferedConnection) Writer {
+ return Writer{ .context = bconn };
+ }
+
+ pub fn close(bconn: *BufferedConnection, client: *const Client) void {
+ bconn.conn.close(client);
+ }
+};
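The fill/peek/clear triple is the parser-facing side of this type: fill blocks until at least one byte is buffered (or fails with error.EndOfStream), peek exposes the buffered window without consuming it, and clear consumes the first `num` bytes. A minimal sketch of the consumption loop that `Request.do` below builds on (`parser` and `allocator` are assumed):

    try bconn.fill();
    const nchecked = try parser.checkCompleteHead(allocator, bconn.peek());
    bconn.clear(@intCast(u16, nchecked));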
+
+/// The mode of transport for requests.
+pub const RequestTransfer = union(enum) {
+ content_length: u64,
+ chunked: void,
+ none: void,
+};
+
+/// The decompressor for response messages.
+pub const Compression = union(enum) {
+ pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.TransferReader);
+ pub const GzipDecompressor = std.compress.gzip.Decompress(Request.TransferReader);
+ pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{});
+
+ deflate: DeflateDecompressor,
+ gzip: GzipDecompressor,
+ zstd: ZstdDecompressor,
+ none: void,
+};
+
+/// An HTTP response originating from a server.
+pub const Response = struct {
+ pub const Headers = struct {
+ status: http.Status,
+ version: http.Version,
+ location: ?[]const u8 = null,
+ content_length: ?u64 = null,
+ transfer_encoding: ?http.TransferEncoding = null,
+ transfer_compression: ?http.ContentEncoding = null,
+ connection: http.Connection = .close,
+ upgrade: ?[]const u8 = null,
+
+ pub const ParseError = error{
+ ShortHttpStatusLine,
+ BadHttpVersion,
+ HttpHeadersInvalid,
+ HttpHeaderContinuationsUnsupported,
+ HttpTransferEncodingUnsupported,
+ HttpConnectionHeaderUnsupported,
+ InvalidContentLength,
+ CompressionNotSupported,
+ };
+
+ pub fn parse(bytes: []const u8) ParseError!Headers {
+ var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
+
+ const first_line = it.next() orelse return error.HttpHeadersInvalid;
+ if (first_line.len < 12)
+ return error.ShortHttpStatusLine;
+
+ const version: http.Version = switch (int64(first_line[0..8])) {
+ int64("HTTP/1.0") => .@"HTTP/1.0",
+ int64("HTTP/1.1") => .@"HTTP/1.1",
+ else => return error.BadHttpVersion,
+ };
+ if (first_line[8] != ' ') return error.HttpHeadersInvalid;
+ const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
+
+ var headers: Headers = .{
+ .version = version,
+ .status = status,
+ };
+
+ while (it.next()) |line| {
+ if (line.len == 0) return error.HttpHeadersInvalid;
+ switch (line[0]) {
+ ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+ else => {},
+ }
+
+ var line_it = mem.tokenize(u8, line, ": ");
+ const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
+ const header_value = line_it.rest();
+ if (std.ascii.eqlIgnoreCase(header_name, "location")) {
+ if (headers.location != null) return error.HttpHeadersInvalid;
+ headers.location = header_value;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+ if (headers.content_length != null) return error.HttpHeadersInvalid;
+ headers.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+ // Transfer-Encoding: second, first
+ // Transfer-Encoding: deflate, chunked
+ var iter = mem.splitBackwards(u8, header_value, ",");
+
+ if (iter.next()) |first| {
+ const trimmed = mem.trim(u8, first, " ");
+
+ if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
+ if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
+ headers.transfer_encoding = te;
+ } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |second| {
+ if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
+
+ const trimmed = mem.trim(u8, second, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+
+ const trimmed = mem.trim(u8, header_value, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+ if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
+ headers.connection = .keep_alive;
+ } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
+ headers.connection = .close;
+ } else {
+ return error.HttpConnectionHeaderUnsupported;
+ }
+ } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
+ headers.upgrade = header_value;
+ }
+ }
+
+ return headers;
+ }
+
+ inline fn int64(array: *const [8]u8) u64 {
+ return @bitCast(u64, array.*);
+ }
+
+ fn parseInt3(nnn: @Vector(3, u8)) u10 {
+ const zero: @Vector(3, u8) = .{ '0', '0', '0' };
+ const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
+ return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
+ }
+
+ test parseInt3 {
+ const expectEqual = testing.expectEqual;
+ try expectEqual(@as(u10, 0), parseInt3("000".*));
+ try expectEqual(@as(u10, 418), parseInt3("418".*));
+ try expectEqual(@as(u10, 999), parseInt3("999".*));
+ }
+ };
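As a worked example of the rules above, a hypothetical test against `parse` (note that the terminating blank line must still be present, because `parse` slices off the final `\r\n\r\n` before tokenizing):

    test "Response.Headers.parse" {
        const head = "HTTP/1.1 200 OK\r\n" ++
            "Content-Length: 5\r\n" ++
            "Connection: keep-alive\r\n" ++
            "\r\n";
        const h = try Headers.parse(head);
        try testing.expectEqual(http.Version.@"HTTP/1.1", h.version);
        try testing.expectEqual(http.Status.ok, h.status);
        try testing.expectEqual(@as(?u64, 5), h.content_length);
        try testing.expectEqual(http.Connection.keep_alive, h.connection);
    }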
+
+ headers: Headers = undefined,
+ parser: proto.HeadersParser,
+ compression: Compression = .none,
+ skip: bool = false,
+};
+
+/// An HTTP request that has been sent.
+///
+/// Order of operations: request[ -> write -> finish] -> do -> read
+pub const Request = struct {
+ pub const Headers = struct {
+ version: http.Version = .@"HTTP/1.1",
+ method: http.Method = .GET,
+ user_agent: []const u8 = "zig (std.http)",
+ connection: http.Connection = .keep_alive,
+ transfer_encoding: RequestTransfer = .none,
+
+ custom: []const http.CustomHeader = &[_]http.CustomHeader{},
+ };
+
+ uri: Uri,
+ client: *Client,
+ connection: *ConnectionPool.Node,
+ /// These are stored in Request so that they are available when following
+ /// redirects.
+ headers: Headers,
+
+ redirects_left: u32,
+ handle_redirects: bool,
+
+ response: Response,
+
+ /// Used as an allocator for resolving redirect locations.
+ arena: std.heap.ArenaAllocator,
+
+ /// Frees all resources associated with the request.
+ pub fn deinit(req: *Request) void {
+ switch (req.response.compression) {
+ .none => {},
+ .deflate => |*deflate| deflate.deinit(),
+ .gzip => |*gzip| gzip.deinit(),
+ .zstd => |*zstd| zstd.deinit(),
+ }
+
+ if (req.response.parser.header_bytes_owned) {
+ req.response.parser.header_bytes.deinit(req.client.allocator);
+ }
+
+ if (!req.response.parser.done) {
+ // If the response wasn't fully read, then we need to close the connection.
+ req.connection.data.closing = true;
+ req.client.connection_pool.release(req.client, req.connection);
+ }
+
+ req.arena.deinit();
+ req.* = undefined;
+ }
- client.allocator.free(conn.host);
+ pub fn start(req: *Request, uri: Uri, headers: Headers) !void {
+ var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
+ const w = buffered.writer();
+
+ const escaped_path = try Uri.escapePath(req.client.allocator, uri.path);
+ defer req.client.allocator.free(escaped_path);
+
+ const escaped_query = if (uri.query) |q| try Uri.escapeQuery(req.client.allocator, q) else null;
+ defer if (escaped_query) |q| req.client.allocator.free(q);
+
+ const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(req.client.allocator, f) else null;
+ defer if (escaped_fragment) |f| req.client.allocator.free(f);
+
+ try w.writeAll(@tagName(headers.method));
+ try w.writeByte(' ');
+ if (escaped_path.len == 0) {
+ try w.writeByte('/');
+ } else {
+ try w.writeAll(escaped_path);
+ }
+ if (escaped_query) |q| {
+ try w.writeByte('?');
+ try w.writeAll(q);
+ }
+ if (escaped_fragment) |f| {
+ try w.writeByte('#');
+ try w.writeAll(f);
+ }
+ try w.writeByte(' ');
+ try w.writeAll(@tagName(headers.version));
+ try w.writeAll("\r\nHost: ");
+ try w.writeAll(uri.host.?);
+ try w.writeAll("\r\nUser-Agent: ");
+ try w.writeAll(headers.user_agent);
+ if (headers.connection == .close) {
+ try w.writeAll("\r\nConnection: close");
+ } else {
+ try w.writeAll("\r\nConnection: keep-alive");
+ }
+ try w.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd");
+ try w.writeAll("\r\nTE: gzip, deflate"); // TODO: add trailers when someone finds a nice way to integrate them without completely invalidating all pointers to headers.
+
+ switch (headers.transfer_encoding) {
+ .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"),
+ .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}),
+ .none => {},
+ }
+
+ for (headers.custom) |header| {
+ try w.writeAll("\r\n");
+ try w.writeAll(header.name);
+ try w.writeAll(": ");
+ try w.writeAll(header.value);
+ }
+
+ try w.writeAll("\r\n\r\n");
+
+ try buffered.flush();
+ }
+
+ pub const TransferReadError = proto.HeadersParser.ReadError || error{ReadFailed};
+
+ pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
+
+ pub fn transferReader(req: *Request) TransferReader {
+ return .{ .context = req };
+ }
+
+ pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
+ if (req.response.parser.done) return 0;
+
+ var index: usize = 0;
+ while (index == 0) {
+ const amt = req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip) catch |err| {
+ req.client.last_error = .{ .read = err };
+ return error.ReadFailed;
+ };
+ if (amt == 0 and req.response.parser.done) break;
+ index += amt;
+ }
+
+ return index;
+ }
+
+ pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.Headers.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed };
+
+ /// Waits for a response from the server and parses any headers that are sent.
+ /// This function will block until the final response is received.
+ ///
+ /// If `handle_redirects` is true, then this function will automatically follow
+ /// redirects.
+ pub fn do(req: *Request) DoError!void {
+ while (true) { // handle redirects
+ while (true) { // read headers
+ req.connection.data.buffered.fill() catch |err| {
+ req.client.last_error = .{ .read = err };
+ return error.ReadFailed;
+ };
+
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
+ req.connection.data.buffered.clear(@intCast(u16, nchecked));
+
+ if (req.response.parser.state.isContent()) break;
+ }
+
+ req.response.headers = try Response.Headers.parse(req.response.parser.header_bytes.items);
+
+ if (req.response.headers.status == .switching_protocols) {
+ req.connection.data.closing = false;
+ req.response.parser.done = true;
+ }
+
+ if (req.headers.connection == .keep_alive and req.response.headers.connection == .keep_alive) {
+ req.connection.data.closing = false;
+ } else {
+ req.connection.data.closing = true;
+ }
+
+ if (req.response.headers.transfer_encoding) |te| {
+ switch (te) {
+ .chunked => {
+ req.response.parser.next_chunk_length = 0;
+ req.response.parser.state = .chunk_head_size;
+ },
+ }
+ } else if (req.response.headers.content_length) |cl| {
+ req.response.parser.next_chunk_length = cl;
+
+ if (cl == 0) req.response.parser.done = true;
+ } else {
+ req.response.parser.done = true;
+ }
+
+ if (req.response.headers.status.class() == .redirect and req.handle_redirects) {
+ req.response.skip = true;
+
+ const empty = @as([*]u8, undefined)[0..0];
+ assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary
+
+ if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+
+ const location = req.response.headers.location orelse
+ return error.HttpRedirectMissingLocation;
+ const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
+
+ var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
+ const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
+ errdefer new_arena.deinit();
+
+ req.arena.deinit();
+ req.arena = new_arena;
+
+ const new_req = try req.client.request(resolved_url, req.headers, .{
+ .max_redirects = req.redirects_left - 1,
+ .header_strategy = if (req.response.parser.header_bytes_owned) .{
+ .dynamic = req.response.parser.max_header_bytes,
+ } else .{
+ .static = req.response.parser.header_bytes.items.ptr[0..req.response.parser.max_header_bytes],
+ },
+ });
+ req.deinit();
+ req.* = new_req;
+ } else {
+ req.response.skip = false;
+ if (!req.response.parser.done) {
+ if (req.response.headers.transfer_compression) |tc| switch (tc) {
+ .compress => return error.CompressionNotSupported,
+ .deflate => req.response.compression = .{
+ .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch |err| {
+ req.client.last_error = .{ .zlib_init = err };
+ return error.CompressionInitializationFailed;
+ },
+ },
+ .gzip => req.response.compression = .{
+ .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch |err| {
+ req.client.last_error = .{ .gzip_init = err };
+ return error.CompressionInitializationFailed;
+ },
+ },
+ .zstd => req.response.compression = .{
+ .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
+ },
+ };
+ }
+
+ break;
+ }
+ }
+ }
+
+ pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError;
+
+ pub const Reader = std.io.Reader(*Request, ReadError, read);
+
+ pub fn reader(req: *Request) Reader {
+ return .{ .context = req };
+ }
+
+ /// Reads data from the response body. Must be called after `do`.
+ pub fn read(req: *Request, buffer: []u8) ReadError!usize {
+ while (true) {
+ const out_index = switch (req.response.compression) {
+ .deflate => |*deflate| deflate.read(buffer) catch |err| {
+ req.client.last_error = .{ .decompress = err };
+ return error.ReadFailed;
+ },
+ .gzip => |*gzip| gzip.read(buffer) catch |err| {
+ req.client.last_error = .{ .decompress = err };
+ return error.ReadFailed;
+ },
+ .zstd => |*zstd| zstd.read(buffer) catch |err| {
+ req.client.last_error = .{ .decompress = err };
+ return error.ReadFailed;
+ },
+ else => try req.transferRead(buffer),
+ };
+
+ if (out_index == 0) {
+ while (!req.response.parser.state.isContent()) { // read trailing headers
+ req.connection.data.buffered.fill() catch |err| {
+ req.client.last_error = .{ .read = err };
+ return error.ReadFailed;
+ };
+
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
+ req.connection.data.buffered.clear(@intCast(u16, nchecked));
+ }
+ }
+
+ return out_index;
+ }
+ }
+
+ /// Reads data from the response body. Must be called after `do`.
+ pub fn readAll(req: *Request, buffer: []u8) !usize {
+ var index: usize = 0;
+ while (index < buffer.len) {
+ const amt = try read(req, buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
+ }
+ return index;
+ }
+
+ pub const WriteError = error{ WriteFailed, NotWriteable, MessageTooLong };
+
+ pub const Writer = std.io.Writer(*Request, WriteError, write);
+
+ pub fn writer(req: *Request) Writer {
+ return .{ .context = req };
+ }
+
+ /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
+ pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
+ switch (req.headers.transfer_encoding) {
+ .chunked => {
+ req.connection.data.buffered.writer().print("{x}\r\n", .{bytes.len}) catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+ req.connection.data.buffered.writeAll(bytes) catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+ req.connection.data.conn.writeAll("\r\n") catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+
+ return bytes.len;
+ },
+ .content_length => |*len| {
+ if (len.* < bytes.len) return error.MessageTooLong;
+
+ const amt = req.connection.data.buffered.write(bytes) catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+ len.* -= amt;
+ return amt;
+ },
+ .none => return error.NotWriteable,
+ }
+ }
+
+ /// Finish the body of a request. This notifies the server that you have no more data to send.
+ pub fn finish(req: *Request) !void {
+ switch (req.headers.transfer_encoding) {
+ .chunked => req.connection.data.buffered.writeAll("0\r\n\r\n") catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ },
+ .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+ .none => {},
+ }
}
};
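Putting the order-of-operations comment above together, an end-to-end sketch of a GET (the allocator, URL, and buffer size are placeholders; a request body would additionally set `transfer_encoding` and use `writer()`/`finish()` before `do`):

    var client = Client{ .allocator = gpa }; // gpa: any std.mem.Allocator
    defer client.deinit();

    const uri = try Uri.parse("https://example.com/");
    var req = try client.request(uri, .{}, .{});
    defer req.deinit();

    try req.do(); // blocks until the final response head arrives

    var buf: [4096]u8 = undefined;
    const n = try req.readAll(&buf); // body bytes, decompressed if necessary
    std.debug.print("{s}\n", .{buf[0..n]});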
+/// Release all resources associated with the client.
+/// TODO: currently leaks all request allocated data
pub fn deinit(client: *Client) void {
client.connection_pool.deinit(client);
@@ -227,8 +875,10 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
-pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
+pub const ConnectError = Allocator.Error || error{ ConnectionFailed, TlsInitializationFailed };
+/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
+/// This function is threadsafe.
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
if (client.connection_pool.findConnection(.{
.host = host,
@@ -241,22 +891,36 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
+ const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| {
+ client.last_error = .{ .connect = err };
+ return error.ConnectionFailed;
+ };
+ errdefer stream.close();
+
conn.data = .{
- .stream = try net.tcpConnectToHost(client.allocator, host, port),
- .tls_client = undefined,
- .protocol = protocol,
+ .buffered = .{ .conn = .{
+ .stream = stream,
+ .tls_client = undefined,
+ .protocol = protocol,
+ } },
.host = try client.allocator.dupe(u8, host),
.port = port,
};
+ errdefer client.allocator.free(conn.data.host);
switch (protocol) {
.plain => {},
.tls => {
- conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
- conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
+ conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
+ errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client);
+
+ conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch |err| {
+ client.last_error = .{ .tls = err };
+ return error.TlsInitializationFailed;
+ };
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
- conn.data.tls_client.allow_truncation_attacks = true;
+ conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
},
}
@@ -265,24 +929,44 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
return conn;
}
-pub const RequestError = ConnectError || Connection.WriteError || error{
+pub const RequestError = ConnectError || error{
UnsupportedUrlScheme,
UriMissingHost,
- CertificateAuthorityBundleTooBig,
- InvalidPadding,
- MissingEndCertificateMarker,
- Unseekable,
- EndOfStream,
+ CertificateAuthorityBundleFailed,
+ WriteFailed,
+};
+
+pub const Options = struct {
+ handle_redirects: bool = true,
+ max_redirects: u32 = 3,
+ header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
+
+ pub const HeaderStrategy = union(enum) {
+ /// In this case, the client's Allocator will be used to store the
+ /// entire HTTP header. This value is the maximum total size of
+ /// HTTP headers allowed, otherwise
+ /// error.HttpHeadersExceededSizeLimit is returned from read().
+ dynamic: usize,
+ /// This is used to store the entire HTTP header. If the HTTP
+ /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
+ /// is returned from read(). When this is used, `error.OutOfMemory`
+ /// cannot be returned from `read()`.
+ static: []u8,
+ };
};
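The two strategies trade a heap ceiling against caller-owned storage; illustrative values for both (sizes are arbitrary):

    // heap-allocate the response head as needed, capped at 64 KiB
    const dynamic: Options = .{ .header_strategy = .{ .dynamic = 64 * 1024 } };

    // or parse the head into caller-owned memory; read() then cannot
    // return error.OutOfMemory
    var head_buf: [16 * 1024]u8 = undefined;
    const static: Options = .{ .header_strategy = .{ .static = &head_buf } };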
-pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
- const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http"))
- .plain
- else if (mem.eql(u8, uri.scheme, "https"))
- .tls
- else
- return error.UnsupportedUrlScheme;
+pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
+ .{ "http", .plain },
+ .{ "ws", .plain },
+ .{ "https", .tls },
+ .{ "wss", .tls },
+});
+
+/// Form and send an HTTP request to a server.
+/// This function is threadsafe.
+pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Options) RequestError!Request {
+ const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
const port: u16 = uri.port orelse switch (protocol) {
.plain => 80,
@@ -291,91 +975,45 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
const host = uri.host orelse return error.UriMissingHost;
- if (client.next_https_rescan_certs and protocol == .tls) {
- client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
- defer client.connection_pool.mutex.unlock();
+ if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) {
+ client.ca_bundle_mutex.lock();
+ defer client.ca_bundle_mutex.unlock();
if (client.next_https_rescan_certs) {
- try client.ca_bundle.rescan(client.allocator);
- client.next_https_rescan_certs = false;
+ client.ca_bundle.rescan(client.allocator) catch |err| {
+ client.last_error = .{ .ca_bundle = err };
+ return error.CertificateAuthorityBundleFailed;
+ };
+ @atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
}
}
var req: Request = .{
.uri = uri,
.client = client,
- .headers = headers,
.connection = try client.connect(host, port, protocol),
+ .headers = headers,
.redirects_left = options.max_redirects,
.handle_redirects = options.handle_redirects,
- .compression_init = false,
- .response = switch (options.header_strategy) {
- .dynamic => |max| Response.initDynamic(max),
- .static => |buf| Response.initStatic(buf),
+ .response = .{
+ .parser = switch (options.header_strategy) {
+ .dynamic => |max| proto.HeadersParser.initDynamic(max),
+ .static => |buf| proto.HeadersParser.initStatic(buf),
+ },
},
.arena = undefined,
};
+ errdefer req.deinit();
req.arena = std.heap.ArenaAllocator.init(client.allocator);
- {
- var buffered = std.io.bufferedWriter(req.connection.data.writer());
- const writer = buffered.writer();
-
- const escaped_path = try Uri.escapePath(client.allocator, uri.path);
- defer client.allocator.free(escaped_path);
-
- const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
- defer if (escaped_query) |q| client.allocator.free(q);
+ req.start(uri, headers) catch |err| {
+ if (err == error.OutOfMemory) return error.OutOfMemory;
+ const err_casted = @errSetCast(BufferedConnection.WriteError, err);
- const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
- defer if (escaped_fragment) |f| client.allocator.free(f);
-
- try writer.writeAll(@tagName(headers.method));
- try writer.writeByte(' ');
- if (escaped_path.len == 0) {
- try writer.writeByte('/');
- } else {
- try writer.writeAll(escaped_path);
- }
- if (escaped_query) |q| {
- try writer.writeByte('?');
- try writer.writeAll(q);
- }
- if (escaped_fragment) |f| {
- try writer.writeByte('#');
- try writer.writeAll(f);
- }
- try writer.writeByte(' ');
- try writer.writeAll(@tagName(headers.version));
- try writer.writeAll("\r\nHost: ");
- try writer.writeAll(host);
- try writer.writeAll("\r\nUser-Agent: ");
- try writer.writeAll(headers.user_agent);
- if (headers.connection == .close) {
- try writer.writeAll("\r\nConnection: close");
- } else {
- try writer.writeAll("\r\nConnection: keep-alive");
- }
- try writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd");
-
- switch (headers.transfer_encoding) {
- .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"),
- .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}),
- .none => {},
- }
-
- for (headers.custom) |header| {
- try writer.writeAll("\r\n");
- try writer.writeAll(header.name);
- try writer.writeAll(": ");
- try writer.writeAll(header.value);
- }
-
- try writer.writeAll("\r\n\r\n");
-
- try buffered.flush();
- }
+ client.last_error = .{ .write = err_casted };
+ return error.WriteFailed;
+ };
return req;
}
@@ -390,5 +1028,5 @@ test {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
- _ = Request;
+ std.testing.refAllDecls(@This());
}