aboutsummaryrefslogtreecommitdiff
path: root/lib/std/http/Client.zig
diff options
context:
space:
mode:
Diffstat (limited to 'lib/std/http/Client.zig')
-rw-r--r--lib/std/http/Client.zig1033
1 files changed, 247 insertions, 786 deletions
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index 7f62b2d597..baf0239388 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -13,6 +13,12 @@ const Uri = std.Uri;
const Allocator = std.mem.Allocator;
const testing = std.testing;
+pub const Request = @import("Client/Request.zig");
+pub const Response = @import("Client/Response.zig");
+
+pub const default_connection_pool_size = 32;
+const connection_pool_size = std.options.http_connection_pool_size;
+
/// Used for tcpConnectToHost and storing HTTP headers when an externally
/// managed buffer is not provided.
allocator: Allocator,
@@ -21,854 +27,256 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,
-pub const Connection = struct {
- stream: net.Stream,
- /// undefined unless protocol is tls.
- tls_client: std.crypto.tls.Client,
- protocol: Protocol,
+connection_pool: ConnectionPool = .{},
- pub const Protocol = enum { plain, tls };
+pub const ConnectionPool = struct {
+ pub const Criteria = struct {
+ host: []const u8,
+ port: u16,
+ is_tls: bool,
+ };
- pub fn read(conn: *Connection, buffer: []u8) !usize {
- switch (conn.protocol) {
- .plain => return conn.stream.read(buffer),
- .tls => return conn.tls_client.read(conn.stream, buffer),
+ const Queue = std.TailQueue(Connection);
+ pub const Node = Queue.Node;
+
+ mutex: std.Thread.Mutex = .{},
+ used: Queue = .{},
+ free: Queue = .{},
+ free_len: usize = 0,
+ free_size: usize = default_connection_pool_size,
+
+ /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
+ /// If no connection is found, null is returned.
+ pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ var next = pool.free.last;
+ while (next) |node| : (next = node.prev) {
+ if ((node.data.protocol == .tls) != criteria.is_tls) continue;
+ if (node.data.port != criteria.port) continue;
+ if (std.mem.eql(u8, node.data.host, criteria.host)) continue;
+
+ pool.acquireUnsafe(node);
+ return node;
}
- }
- pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
- switch (conn.protocol) {
- .plain => return conn.stream.readAtLeast(buffer, len),
- .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
- }
+ return null;
}
- pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
- switch (conn.protocol) {
- .plain => return conn.stream.writeAll(buffer),
- .tls => return conn.tls_client.writeAll(conn.stream, buffer),
- }
+ /// Acquires an existing connection from the connection pool. This function is not threadsafe.
+ pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
+ pool.free.remove(node);
+ pool.free_len -= 1;
+
+ pool.used.append(node);
}
- pub fn write(conn: *Connection, buffer: []const u8) !usize {
- switch (conn.protocol) {
- .plain => return conn.stream.write(buffer),
- .tls => return conn.tls_client.write(conn.stream, buffer),
- }
+ /// Acquires an existing connection from the connection pool. This function is threadsafe.
+ pub fn acquire(pool: *ConnectionPool, node: *Node) void {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ return pool.acquireUnsafe(node);
}
-};
-/// TODO: emit error.UnexpectedEndOfStream or something like that when the read
-/// data does not match the content length. This is necessary since HTTPS disables
-/// close_notify protection on underlying TLS streams.
-pub const Request = struct {
- client: *Client,
- connection: Connection,
- redirects_left: u32,
- response: Response,
- /// These are stored in Request so that they are available when following
- /// redirects.
- headers: Headers,
-
- pub const Response = struct {
- headers: Response.Headers,
- state: State,
- header_bytes_owned: bool,
- /// This could either be a fixed buffer provided by the API user or it
- /// could be our own array list.
- header_bytes: std.ArrayListUnmanaged(u8),
- max_header_bytes: usize,
- next_chunk_length: u64,
-
- pub const Headers = struct {
- status: http.Status,
- version: http.Version,
- location: ?[]const u8 = null,
- content_length: ?u64 = null,
- transfer_encoding: ?http.TransferEncoding = null,
-
- pub fn parse(bytes: []const u8) !Response.Headers {
- var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
-
- const first_line = it.first();
- if (first_line.len < 12)
- return error.ShortHttpStatusLine;
-
- const version: http.Version = switch (int64(first_line[0..8])) {
- int64("HTTP/1.0") => .@"HTTP/1.0",
- int64("HTTP/1.1") => .@"HTTP/1.1",
- else => return error.BadHttpVersion,
- };
- if (first_line[8] != ' ') return error.HttpHeadersInvalid;
- const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
-
- var headers: Response.Headers = .{
- .version = version,
- .status = status,
- };
-
- while (it.next()) |line| {
- if (line.len == 0) return error.HttpHeadersInvalid;
- switch (line[0]) {
- ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
- else => {},
- }
- var line_it = mem.split(u8, line, ": ");
- const header_name = line_it.first();
- const header_value = line_it.rest();
- if (std.ascii.eqlIgnoreCase(header_name, "location")) {
- if (headers.location != null) return error.HttpHeadersInvalid;
- headers.location = header_value;
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
- if (headers.content_length != null) return error.HttpHeadersInvalid;
- headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
- } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
- if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
- headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse
- return error.HttpTransferEncodingUnsupported;
- }
- }
-
- return headers;
- }
-
- test "parse headers" {
- const example =
- "HTTP/1.1 301 Moved Permanently\r\n" ++
- "Location: https://www.example.com/\r\n" ++
- "Content-Type: text/html; charset=UTF-8\r\n" ++
- "Content-Length: 220\r\n\r\n";
- const parsed = try Response.Headers.parse(example);
- try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
- try testing.expectEqual(http.Status.moved_permanently, parsed.status);
- try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse
- return error.TestFailed);
- try testing.expectEqual(@as(?u64, 220), parsed.content_length);
- }
-
- test "header continuation" {
- const example =
- "HTTP/1.0 200 OK\r\n" ++
- "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++
- "Content-Length: 220\r\n\r\n";
- try testing.expectError(
- error.HttpHeaderContinuationsUnsupported,
- Response.Headers.parse(example),
- );
- }
-
- test "extra content length" {
- const example =
- "HTTP/1.0 200 OK\r\n" ++
- "Content-Length: 220\r\n" ++
- "Content-Type: text/html; charset=UTF-8\r\n" ++
- "content-length: 220\r\n\r\n";
- try testing.expectError(
- error.HttpHeadersInvalid,
- Response.Headers.parse(example),
- );
- }
- };
-
- pub const State = enum {
- /// Begin header parsing states.
- invalid,
- start,
- seen_r,
- seen_rn,
- seen_rnr,
- finished,
- /// Begin transfer-encoding: chunked parsing states.
- chunk_size_prefix_r,
- chunk_size_prefix_n,
- chunk_size,
- chunk_r,
- chunk_data,
-
- pub fn zeroMeansEnd(state: State) bool {
- return switch (state) {
- .finished, .chunk_data => true,
- else => false,
- };
- }
- };
-
- pub fn initDynamic(max: usize) Response {
- return .{
- .state = .start,
- .headers = undefined,
- .header_bytes = .{},
- .max_header_bytes = max,
- .header_bytes_owned = true,
- .next_chunk_length = undefined,
- };
- }
+ /// Tries to release a connection back to the connection pool. This function is threadsafe.
+ /// If the connection is marked as closing, it will be closed instead.
+ pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
- pub fn initStatic(buf: []u8) Response {
- return .{
- .state = .start,
- .headers = undefined,
- .header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
- .max_header_bytes = buf.len,
- .header_bytes_owned = false,
- .next_chunk_length = undefined,
- };
- }
+ pool.used.remove(node);
- /// Returns how many bytes are part of HTTP headers. Always less than or
- /// equal to bytes.len. If the amount returned is less than bytes.len, it
- /// means the headers ended and the first byte after the double \r\n\r\n is
- /// located at `bytes[result]`.
- pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize {
- var index: usize = 0;
-
- // TODO: https://github.com/ziglang/zig/issues/8220
- state: while (true) {
- switch (r.state) {
- .invalid => unreachable,
- .finished => unreachable,
- .start => while (true) {
- switch (bytes.len - index) {
- 0 => return index,
- 1 => {
- if (bytes[index] == '\r')
- r.state = .seen_r;
- return index + 1;
- },
- 2 => {
- if (int16(bytes[index..][0..2]) == int16("\r\n")) {
- r.state = .seen_rn;
- } else if (bytes[index + 1] == '\r') {
- r.state = .seen_r;
- }
- return index + 2;
- },
- 3 => {
- if (int16(bytes[index..][0..2]) == int16("\r\n") and
- bytes[index + 2] == '\r')
- {
- r.state = .seen_rnr;
- } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) {
- r.state = .seen_rn;
- } else if (bytes[index + 2] == '\r') {
- r.state = .seen_r;
- }
- return index + 3;
- },
- 4...15 => {
- if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) {
- r.state = .finished;
- return index + 4;
- } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and
- bytes[index + 3] == '\r')
- {
- r.state = .seen_rnr;
- index += 4;
- continue :state;
- } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) {
- r.state = .seen_rn;
- index += 4;
- continue :state;
- } else if (bytes[index + 3] == '\r') {
- r.state = .seen_r;
- index += 4;
- continue :state;
- }
- index += 4;
- continue;
- },
- else => {
- const chunk = bytes[index..][0..16];
- const v: @Vector(16, u8) = chunk.*;
- const matches_r = v == @splat(16, @as(u8, '\r'));
- const iota = std.simd.iota(u8, 16);
- const default = @splat(16, @as(u8, 16));
- const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default));
- switch (sub_index) {
- 0...12 => {
- index += sub_index + 4;
- if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) {
- r.state = .finished;
- return index;
- }
- continue;
- },
- 13 => {
- index += 16;
- if (int16(chunk[14..][0..2]) == int16("\n\r")) {
- r.state = .seen_rnr;
- continue :state;
- }
- continue;
- },
- 14 => {
- index += 16;
- if (chunk[15] == '\n') {
- r.state = .seen_rn;
- continue :state;
- }
- continue;
- },
- 15 => {
- r.state = .seen_r;
- index += 16;
- continue :state;
- },
- 16 => {
- index += 16;
- continue;
- },
- else => unreachable,
- }
- },
- }
- },
-
- .seen_r => switch (bytes.len - index) {
- 0 => return index,
- 1 => {
- switch (bytes[index]) {
- '\n' => r.state = .seen_rn,
- '\r' => r.state = .seen_r,
- else => r.state = .start,
- }
- return index + 1;
- },
- 2 => {
- if (int16(bytes[index..][0..2]) == int16("\n\r")) {
- r.state = .seen_rnr;
- return index + 2;
- }
- r.state = .start;
- return index + 2;
- },
- else => {
- if (int16(bytes[index..][0..2]) == int16("\n\r") and
- bytes[index + 2] == '\n')
- {
- r.state = .finished;
- return index + 3;
- }
- index += 3;
- r.state = .start;
- continue :state;
- },
- },
- .seen_rn => switch (bytes.len - index) {
- 0 => return index,
- 1 => {
- switch (bytes[index]) {
- '\r' => r.state = .seen_rnr,
- else => r.state = .start,
- }
- return index + 1;
- },
- else => {
- if (int16(bytes[index..][0..2]) == int16("\r\n")) {
- r.state = .finished;
- return index + 2;
- }
- index += 2;
- r.state = .start;
- continue :state;
- },
- },
- .seen_rnr => switch (bytes.len - index) {
- 0 => return index,
- else => {
- if (bytes[index] == '\n') {
- r.state = .finished;
- return index + 1;
- }
- index += 1;
- r.state = .start;
- continue :state;
- },
- },
- .chunk_size_prefix_r => unreachable,
- .chunk_size_prefix_n => unreachable,
- .chunk_size => unreachable,
- .chunk_r => unreachable,
- .chunk_data => unreachable,
- }
-
- return index;
- }
- }
+ if (node.data.closing) {
+ node.data.close(client);
- pub fn findChunkedLen(r: *Response, bytes: []const u8) usize {
- var i: usize = 0;
- if (r.state == .chunk_size) {
- while (i < bytes.len) : (i += 1) {
- const digit = switch (bytes[i]) {
- '0'...'9' => |b| b - '0',
- 'A'...'Z' => |b| b - 'A' + 10,
- 'a'...'z' => |b| b - 'a' + 10,
- '\r' => {
- r.state = .chunk_r;
- i += 1;
- break;
- },
- else => {
- r.state = .invalid;
- return i;
- },
- };
- const mul = @mulWithOverflow(r.next_chunk_length, 16);
- if (mul[1] != 0) {
- r.state = .invalid;
- return i;
- }
- const add = @addWithOverflow(mul[0], digit);
- if (add[1] != 0) {
- r.state = .invalid;
- return i;
- }
- r.next_chunk_length = add[0];
- } else {
- return i;
- }
- }
- assert(r.state == .chunk_r);
- if (i == bytes.len) return i;
-
- if (bytes[i] == '\n') {
- r.state = .chunk_data;
- return i + 1;
- } else {
- r.state = .invalid;
- return i;
- }
+ return client.allocator.destroy(node);
}
- fn parseInt3(nnn: @Vector(3, u8)) u10 {
- const zero: @Vector(3, u8) = .{ '0', '0', '0' };
- const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
- return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
- }
+ if (pool.free_len + 1 >= pool.free_size) {
+ const popped = pool.free.popFirst() orelse unreachable;
- test parseInt3 {
- const expectEqual = std.testing.expectEqual;
- try expectEqual(@as(u10, 0), parseInt3("000".*));
- try expectEqual(@as(u10, 418), parseInt3("418".*));
- try expectEqual(@as(u10, 999), parseInt3("999".*));
- }
+ popped.data.close(client);
- test "find headers end basic" {
- var buffer: [1]u8 = undefined;
- var r = Response.initStatic(&buffer);
- try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4"));
- try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18"));
- try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah"));
+ return client.allocator.destroy(popped);
}
- test "find headers end vectorized" {
- var buffer: [1]u8 = undefined;
- var r = Response.initStatic(&buffer);
- const example =
- "HTTP/1.1 301 Moved Permanently\r\n" ++
- "Location: https://www.example.com/\r\n" ++
- "Content-Type: text/html; charset=UTF-8\r\n" ++
- "Content-Length: 220\r\n" ++
- "\r\ncontent";
- try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example));
- }
+ pool.free.append(node);
+ pool.free_len += 1;
+ }
- test "find headers end bug" {
- var buffer: [1]u8 = undefined;
- var r = Response.initStatic(&buffer);
- const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
- const example =
- "HTTP/1.1 200 OK\r\n" ++
- "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++
- "content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++
- "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++
- "Content-Type: application/x-gzip\r\n" ++
- "ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++
- "Strict-Transport-Security: max-age=31536000\r\n" ++
- "Vary: Authorization,Accept-Encoding,Origin\r\n" ++
- "X-Content-Type-Options: nosniff\r\n" ++
- "X-Frame-Options: deny\r\n" ++
- "X-XSS-Protection: 1; mode=block\r\n" ++
- "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++
- "Transfer-Encoding: chunked\r\n" ++
- "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++
- "connection: close\r\n\r\n" ++ trail;
- try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example));
- }
- };
+ /// Adds a newly created node to the pool of used connections. This function is threadsafe.
+ pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
- pub const Headers = struct {
- version: http.Version = .@"HTTP/1.1",
- method: http.Method = .GET,
- };
+ pool.used.append(node);
+ }
- pub const Options = struct {
- max_redirects: u32 = 3,
- header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
-
- pub const HeaderStrategy = union(enum) {
- /// In this case, the client's Allocator will be used to store the
- /// entire HTTP header. This value is the maximum total size of
- /// HTTP headers allowed, otherwise
- /// error.HttpHeadersExceededSizeLimit is returned from read().
- dynamic: usize,
- /// This is used to store the entire HTTP header. If the HTTP
- /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
- /// is returned from read(). When this is used, `error.OutOfMemory`
- /// cannot be returned from `read()`.
- static: []u8,
- };
- };
+ pub fn deinit(pool: *ConnectionPool, client: *Client) void {
+ pool.mutex.lock();
- /// May be skipped if header strategy is buffer.
- pub fn deinit(req: *Request) void {
- if (req.response.header_bytes_owned) {
- req.response.header_bytes.deinit(req.client.allocator);
+ var next = pool.free.first;
+ while (next) |node| {
+ defer client.allocator.destroy(node);
+ next = node.next;
+
+ node.data.close(client);
}
- req.* = undefined;
+
+ next = pool.used.first;
+ while (next) |node| {
+ defer client.allocator.destroy(node);
+ next = node.next;
+
+ node.data.close(client);
+ }
+
+ pool.* = undefined;
}
+};
+
+pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
+pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
+pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{});
- pub const Reader = std.io.Reader(*Request, ReadError, read);
+pub const Connection = struct {
+ stream: net.Stream,
+ /// undefined unless protocol is tls.
+ tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
+ protocol: Protocol,
+ host: []u8,
+ port: u16,
- pub fn reader(req: *Request) Reader {
- return .{ .context = req };
+ // This connection has been part of a non keepalive request and cannot be added to the pool.
+ closing: bool = false,
+
+ pub const Protocol = enum { plain, tls };
+
+ pub fn read(conn: *Connection, buffer: []u8) !usize {
+ switch (conn.protocol) {
+ .plain => return conn.stream.read(buffer),
+ .tls => return conn.tls_client.read(conn.stream, buffer),
+ }
}
- pub fn readAll(req: *Request, buffer: []u8) !usize {
- return readAtLeast(req, buffer, buffer.len);
+ pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
+ switch (conn.protocol) {
+ .plain => return conn.stream.readAtLeast(buffer, len),
+ .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
+ }
}
pub const ReadError = net.Stream.ReadError || error{
- // From HTTP protocol
- HttpHeadersInvalid,
- HttpHeadersExceededSizeLimit,
- HttpRedirectMissingLocation,
- HttpTransferEncodingUnsupported,
- HttpContentLengthUnknown,
- TooManyHttpRedirects,
- ShortHttpStatusLine,
- BadHttpVersion,
- HttpHeaderContinuationsUnsupported,
- UnsupportedUrlScheme,
- UriMissingHost,
- UnknownHostName,
-
- // Network problems
- NetworkUnreachable,
- HostLacksNetworkAddresses,
- TemporaryNameServerFailure,
- NameServerFailure,
- ProtocolFamilyNotAvailable,
- ProtocolNotSupported,
-
- // System resource problems
- ProcessFdQuotaExceeded,
- SystemFdQuotaExceeded,
- OutOfMemory,
-
- // TLS problems
- InsufficientEntropy,
TlsConnectionTruncated,
TlsRecordOverflow,
TlsDecodeError,
TlsAlert,
TlsBadRecordMac,
+ Overflow,
TlsBadLength,
TlsIllegalParameter,
TlsUnexpectedMessage,
- TlsDecryptFailure,
- CertificateFieldHasInvalidLength,
- CertificateHostMismatch,
- CertificatePublicKeyInvalid,
- CertificateExpired,
- CertificateFieldHasWrongDataType,
- CertificateIssuerMismatch,
- CertificateNotYetValid,
- CertificateSignatureAlgorithmMismatch,
- CertificateSignatureAlgorithmUnsupported,
- CertificateSignatureInvalid,
- CertificateSignatureInvalidLength,
- CertificateSignatureNamedCurveUnsupported,
- CertificateSignatureUnsupportedBitCount,
- TlsCertificateNotVerified,
- TlsBadSignatureScheme,
- TlsBadRsaSignatureBitCount,
- TlsDecryptError,
- UnsupportedCertificateVersion,
- CertificateTimeInvalid,
- CertificateHasUnrecognizedObjectId,
- CertificateHasInvalidBitString,
- CertificateAuthorityBundleTooBig,
-
- // TODO: convert to higher level errors
- InvalidFormat,
- InvalidPort,
- UnexpectedCharacter,
- Overflow,
- InvalidCharacter,
- AddressFamilyNotSupported,
- AddressInUse,
- AddressNotAvailable,
- ConnectionPending,
- ConnectionRefused,
- FileNotFound,
- PermissionDenied,
- ServiceUnavailable,
- SocketTypeNotSupported,
- FileTooBig,
- LockViolation,
- NoSpaceLeft,
- NotOpenForWriting,
- InvalidEncoding,
- IdentityElement,
- NonCanonical,
- SignatureVerificationFailed,
- MessageTooLong,
- NegativeIntoUnsigned,
- TargetTooSmall,
- BufferTooSmall,
- InvalidSignature,
- NotSquare,
- DiskQuota,
- InvalidEnd,
- Incomplete,
- InvalidIpv4Mapping,
- InvalidIPAddressFormat,
- BadPathName,
- DeviceBusy,
- FileBusy,
- FileLocksNotSupported,
- InvalidHandle,
- InvalidUtf8,
- NameTooLong,
- NoDevice,
- PathAlreadyExists,
- PipeBusy,
- SharingViolation,
- SymLinkLoop,
- FileSystem,
- InterfaceNotFound,
- AlreadyBound,
- FileDescriptorNotASocket,
- NetworkSubsystemFailed,
- NotDir,
- ReadOnlyFileSystem,
- Unseekable,
- MissingEndCertificateMarker,
- InvalidPadding,
- EndOfStream,
- InvalidArgument,
};
- pub fn read(req: *Request, buffer: []u8) ReadError!usize {
- return readAtLeast(req, buffer, 1);
+ pub const Reader = std.io.Reader(*Connection, ReadError, read);
+
+ pub fn reader(conn: *Connection) Reader {
+ return Reader{ .context = conn };
}
- pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
- assert(len <= buffer.len);
- var index: usize = 0;
- while (index < len) {
- const zero_means_end = req.response.state.zeroMeansEnd();
- const amt = try readAdvanced(req, buffer[index..]);
- if (amt == 0 and zero_means_end) break;
- index += amt;
+ pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
+ switch (conn.protocol) {
+ .plain => return conn.stream.writeAll(buffer),
+ .tls => return conn.tls_client.writeAll(conn.stream, buffer),
}
- return index;
}
- /// This one can return 0 without meaning EOF.
- /// TODO change to readvAdvanced
- pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
- var in = buffer[0..try req.connection.read(buffer)];
- var out_index: usize = 0;
- while (true) {
- switch (req.response.state) {
- .invalid => unreachable,
- .start, .seen_r, .seen_rn, .seen_rnr => {
- const i = req.response.findHeadersEnd(in);
- if (req.response.state == .invalid) return error.HttpHeadersInvalid;
-
- const headers_data = in[0..i];
- if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
- return error.HttpHeadersExceededSizeLimit;
- }
- try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
-
- if (req.response.state == .finished) {
- req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
-
- if (req.response.headers.status.class() == .redirect) {
- if (req.redirects_left == 0) return error.TooManyHttpRedirects;
- const location = req.response.headers.location orelse
- return error.HttpRedirectMissingLocation;
- const new_url = try std.Uri.parse(location);
- const new_req = try req.client.request(new_url, req.headers, .{
- .max_redirects = req.redirects_left - 1,
- .header_strategy = if (req.response.header_bytes_owned) .{
- .dynamic = req.response.max_header_bytes,
- } else .{
- .static = req.response.header_bytes.unusedCapacitySlice(),
- },
- });
- req.deinit();
- req.* = new_req;
- assert(out_index == 0);
- in = buffer[0..try req.connection.read(buffer)];
- continue;
- }
-
- if (req.response.headers.transfer_encoding) |transfer_encoding| {
- switch (transfer_encoding) {
- .chunked => {
- req.response.next_chunk_length = 0;
- req.response.state = .chunk_size;
- },
- .compress => return error.HttpTransferEncodingUnsupported,
- .deflate => return error.HttpTransferEncodingUnsupported,
- .gzip => return error.HttpTransferEncodingUnsupported,
- }
- } else if (req.response.headers.content_length) |content_length| {
- req.response.next_chunk_length = content_length;
- } else {
- return error.HttpContentLengthUnknown;
- }
-
- in = in[i..];
- continue;
- }
-
- assert(out_index == 0);
- return 0;
- },
- .finished => {
- if (in.ptr == buffer.ptr) {
- return in.len;
- } else {
- mem.copy(u8, buffer[out_index..], in);
- return out_index + in.len;
- }
- },
- .chunk_size_prefix_r => switch (in.len) {
- 0 => return out_index,
- 1 => switch (in[0]) {
- '\r' => {
- req.response.state = .chunk_size_prefix_n;
- return out_index;
- },
- else => {
- req.response.state = .invalid;
- return error.HttpHeadersInvalid;
- },
- },
- else => switch (int16(in[0..2])) {
- int16("\r\n") => {
- in = in[2..];
- req.response.state = .chunk_size;
- continue;
- },
- else => {
- req.response.state = .invalid;
- return error.HttpHeadersInvalid;
- },
- },
- },
- .chunk_size_prefix_n => switch (in.len) {
- 0 => return out_index,
- else => switch (in[0]) {
- '\n' => {
- in = in[1..];
- req.response.state = .chunk_size;
- continue;
- },
- else => {
- req.response.state = .invalid;
- return error.HttpHeadersInvalid;
- },
- },
- },
- .chunk_size, .chunk_r => {
- const i = req.response.findChunkedLen(in);
- switch (req.response.state) {
- .invalid => return error.HttpHeadersInvalid,
- .chunk_data => {
- if (req.response.next_chunk_length == 0) {
- req.response.state = .start;
- return out_index;
- }
- in = in[i..];
- continue;
- },
- .chunk_size => return out_index,
- else => unreachable,
- }
- },
- .chunk_data => {
- // TODO https://github.com/ziglang/zig/issues/14039
- const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
- req.response.next_chunk_length -= sub_amt;
- if (req.response.next_chunk_length > 0) {
- if (in.ptr == buffer.ptr) {
- return sub_amt;
- } else {
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
- out_index += sub_amt;
- return out_index;
- }
- }
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
- out_index += sub_amt;
- req.response.state = .chunk_size_prefix_r;
- in = in[sub_amt..];
- continue;
- },
- }
+ pub fn write(conn: *Connection, buffer: []const u8) !usize {
+ switch (conn.protocol) {
+ .plain => return conn.stream.write(buffer),
+ .tls => return conn.tls_client.write(conn.stream, buffer),
}
}
- inline fn int16(array: *const [2]u8) u16 {
- return @bitCast(u16, array.*);
- }
+ pub const WriteError = net.Stream.WriteError || error{};
+ pub const Writer = std.io.Writer(*Connection, WriteError, write);
- inline fn int32(array: *const [4]u8) u32 {
- return @bitCast(u32, array.*);
+ pub fn writer(conn: *Connection) Writer {
+ return Writer{ .context = conn };
}
- inline fn int64(array: *const [8]u8) u64 {
- return @bitCast(u64, array.*);
- }
+ pub fn close(conn: *Connection, client: *const Client) void {
+ if (conn.protocol == .tls) {
+ // try to cleanly close the TLS connection, for any server that cares.
+ _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
+ client.allocator.destroy(conn.tls_client);
+ }
- test {
- _ = Response;
+ conn.stream.close();
+
+ client.allocator.free(conn.host);
}
};
pub fn deinit(client: *Client) void {
+ client.connection_pool.deinit(client);
+
client.ca_bundle.deinit(client.allocator);
client.* = undefined;
}
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection {
- var conn: Connection = .{
+pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
+
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+ if (client.connection_pool.findConnection(.{
+ .host = host,
+ .port = port,
+ .is_tls = protocol == .tls,
+ })) |node|
+ return node;
+
+ const conn = try client.allocator.create(ConnectionPool.Node);
+ errdefer client.allocator.destroy(conn);
+ conn.* = .{ .data = undefined };
+
+ conn.data = .{
.stream = try net.tcpConnectToHost(client.allocator, host, port),
.tls_client = undefined,
.protocol = protocol,
+ .host = try client.allocator.dupe(u8, host),
+ .port = port,
};
switch (protocol) {
.plain => {},
.tls => {
- conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host);
+ conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
+ conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
- conn.tls_client.allow_truncation_attacks = true;
+ conn.data.tls_client.allow_truncation_attacks = true;
},
}
+ client.connection_pool.addUsed(conn);
+
return conn;
}
-pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) !Request {
+pub const RequestError = ConnectError || Connection.WriteError || error{
+ UnsupportedUrlScheme,
+ UriMissingHost,
+
+ CertificateAuthorityBundleTooBig,
+ InvalidPadding,
+ MissingEndCertificateMarker,
+ Unseekable,
+ EndOfStream,
+};
+
+pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http"))
.plain
else if (mem.eql(u8, uri.scheme, "https"))
@@ -884,34 +292,85 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
const host = uri.host orelse return error.UriMissingHost;
if (client.next_https_rescan_certs and protocol == .tls) {
- try client.ca_bundle.rescan(client.allocator);
- client.next_https_rescan_certs = false;
+ client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
+ defer client.connection_pool.mutex.unlock();
+
+ if (client.next_https_rescan_certs) {
+ try client.ca_bundle.rescan(client.allocator);
+ client.next_https_rescan_certs = false;
+ }
}
var req: Request = .{
+ .uri = uri,
.client = client,
.headers = headers,
.connection = try client.connect(host, port, protocol),
.redirects_left = options.max_redirects,
+ .handle_redirects = options.handle_redirects,
+ .compression_init = false,
.response = switch (options.header_strategy) {
- .dynamic => |max| Request.Response.initDynamic(max),
- .static => |buf| Request.Response.initStatic(buf),
+ .dynamic => |max| Response.initDynamic(max),
+ .static => |buf| Response.initStatic(buf),
},
+ .arena = undefined,
};
+ req.arena = std.heap.ArenaAllocator.init(client.allocator);
+
{
- var h = try std.BoundedArray(u8, 1000).init(0);
- try h.appendSlice(@tagName(headers.method));
- try h.appendSlice(" ");
- try h.appendSlice(uri.path);
- try h.appendSlice(" ");
- try h.appendSlice(@tagName(headers.version));
- try h.appendSlice("\r\nHost: ");
- try h.appendSlice(host);
- try h.appendSlice("\r\nConnection: close\r\n\r\n");
-
- const header_bytes = h.slice();
- try req.connection.writeAll(header_bytes);
+ var buffered = std.io.bufferedWriter(req.connection.data.writer());
+ const writer = buffered.writer();
+
+ const escaped_path = try Uri.escapePath(client.allocator, uri.path);
+ defer client.allocator.free(escaped_path);
+
+ const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
+ defer if (escaped_query) |q| client.allocator.free(q);
+
+ const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
+ defer if (escaped_fragment) |f| client.allocator.free(f);
+
+ try writer.writeAll(@tagName(headers.method));
+ try writer.writeByte(' ');
+ try writer.writeAll(escaped_path);
+ if (escaped_query) |q| {
+ try writer.writeByte('?');
+ try writer.writeAll(q);
+ }
+ if (escaped_fragment) |f| {
+ try writer.writeByte('#');
+ try writer.writeAll(f);
+ }
+ try writer.writeByte(' ');
+ try writer.writeAll(@tagName(headers.version));
+ try writer.writeAll("\r\nHost: ");
+ try writer.writeAll(host);
+ try writer.writeAll("\r\nUser-Agent: ");
+ try writer.writeAll(headers.user_agent);
+ if (headers.connection == .close) {
+ try writer.writeAll("\r\nConnection: close");
+ } else {
+ try writer.writeAll("\r\nConnection: keep-alive");
+ }
+ try writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd");
+
+ switch (headers.transfer_encoding) {
+ .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"),
+ .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}),
+ .none => {},
+ }
+
+ for (headers.custom) |header| {
+ try writer.writeAll("\r\n");
+ try writer.writeAll(header.name);
+ try writer.writeAll(": ");
+ try writer.writeAll(header.value);
+ }
+
+ try writer.writeAll("\r\n\r\n");
+
+ try buffered.flush();
}
return req;
@@ -925,5 +384,7 @@ test {
return error.SkipZigTest;
}
+ if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
_ = Request;
}