diff options
Diffstat (limited to 'lib/std/http/Client.zig')
| -rw-r--r-- | lib/std/http/Client.zig | 1033 |
1 files changed, 247 insertions, 786 deletions
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 7f62b2d597..baf0239388 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -13,6 +13,12 @@ const Uri = std.Uri; const Allocator = std.mem.Allocator; const testing = std.testing; +pub const Request = @import("Client/Request.zig"); +pub const Response = @import("Client/Response.zig"); + +pub const default_connection_pool_size = 32; +const connection_pool_size = std.options.http_connection_pool_size; + /// Used for tcpConnectToHost and storing HTTP headers when an externally /// managed buffer is not provided. allocator: Allocator, @@ -21,854 +27,256 @@ ca_bundle: std.crypto.Certificate.Bundle = .{}, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, -pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. - tls_client: std.crypto.tls.Client, - protocol: Protocol, +connection_pool: ConnectionPool = .{}, - pub const Protocol = enum { plain, tls }; +pub const ConnectionPool = struct { + pub const Criteria = struct { + host: []const u8, + port: u16, + is_tls: bool, + }; - pub fn read(conn: *Connection, buffer: []u8) !usize { - switch (conn.protocol) { - .plain => return conn.stream.read(buffer), - .tls => return conn.tls_client.read(conn.stream, buffer), + const Queue = std.TailQueue(Connection); + pub const Node = Queue.Node; + + mutex: std.Thread.Mutex = .{}, + used: Queue = .{}, + free: Queue = .{}, + free_len: usize = 0, + free_size: usize = default_connection_pool_size, + + /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// If no connection is found, null is returned. + pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + var next = pool.free.last; + while (next) |node| : (next = node.prev) { + if ((node.data.protocol == .tls) != criteria.is_tls) continue; + if (node.data.port != criteria.port) continue; + if (std.mem.eql(u8, node.data.host, criteria.host)) continue; + + pool.acquireUnsafe(node); + return node; } - } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize { - switch (conn.protocol) { - .plain => return conn.stream.readAtLeast(buffer, len), - .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), - } + return null; } - pub fn writeAll(conn: *Connection, buffer: []const u8) !void { - switch (conn.protocol) { - .plain => return conn.stream.writeAll(buffer), - .tls => return conn.tls_client.writeAll(conn.stream, buffer), - } + /// Acquires an existing connection from the connection pool. This function is not threadsafe. + pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void { + pool.free.remove(node); + pool.free_len -= 1; + + pool.used.append(node); } - pub fn write(conn: *Connection, buffer: []const u8) !usize { - switch (conn.protocol) { - .plain => return conn.stream.write(buffer), - .tls => return conn.tls_client.write(conn.stream, buffer), - } + /// Acquires an existing connection from the connection pool. This function is threadsafe. + pub fn acquire(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); + + return pool.acquireUnsafe(node); } -}; -/// TODO: emit error.UnexpectedEndOfStream or something like that when the read -/// data does not match the content length. This is necessary since HTTPS disables -/// close_notify protection on underlying TLS streams. -pub const Request = struct { - client: *Client, - connection: Connection, - redirects_left: u32, - response: Response, - /// These are stored in Request so that they are available when following - /// redirects. - headers: Headers, - - pub const Response = struct { - headers: Response.Headers, - state: State, - header_bytes_owned: bool, - /// This could either be a fixed buffer provided by the API user or it - /// could be our own array list. - header_bytes: std.ArrayListUnmanaged(u8), - max_header_bytes: usize, - next_chunk_length: u64, - - pub const Headers = struct { - status: http.Status, - version: http.Version, - location: ?[]const u8 = null, - content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - - pub fn parse(bytes: []const u8) !Response.Headers { - var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n"); - - const first_line = it.first(); - if (first_line.len < 12) - return error.ShortHttpStatusLine; - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); - - var headers: Response.Headers = .{ - .version = version, - .status = status, - }; - - while (it.next()) |line| { - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - var line_it = mem.split(u8, line, ": "); - const header_name = line_it.first(); - const header_value = line_it.rest(); - if (std.ascii.eqlIgnoreCase(header_name, "location")) { - if (headers.location != null) return error.HttpHeadersInvalid; - headers.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (headers.content_length != null) return error.HttpHeadersInvalid; - headers.content_length = try std.fmt.parseInt(u64, header_value, 10); - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; - headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse - return error.HttpTransferEncodingUnsupported; - } - } - - return headers; - } - - test "parse headers" { - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - const parsed = try Response.Headers.parse(example); - try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version); - try testing.expectEqual(http.Status.moved_permanently, parsed.status); - try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse - return error.TestFailed); - try testing.expectEqual(@as(?u64, 220), parsed.content_length); - } - - test "header continuation" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeaderContinuationsUnsupported, - Response.Headers.parse(example), - ); - } - - test "extra content length" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Length: 220\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "content-length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeadersInvalid, - Response.Headers.parse(example), - ); - } - }; - - pub const State = enum { - /// Begin header parsing states. - invalid, - start, - seen_r, - seen_rn, - seen_rnr, - finished, - /// Begin transfer-encoding: chunked parsing states. - chunk_size_prefix_r, - chunk_size_prefix_n, - chunk_size, - chunk_r, - chunk_data, - - pub fn zeroMeansEnd(state: State) bool { - return switch (state) { - .finished, .chunk_data => true, - else => false, - }; - } - }; - - pub fn initDynamic(max: usize) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{}, - .max_header_bytes = max, - .header_bytes_owned = true, - .next_chunk_length = undefined, - }; - } + /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// If the connection is marked as closing, it will be closed instead. + pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); - pub fn initStatic(buf: []u8) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, - .max_header_bytes = buf.len, - .header_bytes_owned = false, - .next_chunk_length = undefined, - }; - } + pool.used.remove(node); - /// Returns how many bytes are part of HTTP headers. Always less than or - /// equal to bytes.len. If the amount returned is less than bytes.len, it - /// means the headers ended and the first byte after the double \r\n\r\n is - /// located at `bytes[result]`. - pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize { - var index: usize = 0; - - // TODO: https://github.com/ziglang/zig/issues/8220 - state: while (true) { - switch (r.state) { - .invalid => unreachable, - .finished => unreachable, - .start => while (true) { - switch (bytes.len - index) { - 0 => return index, - 1 => { - if (bytes[index] == '\r') - r.state = .seen_r; - return index + 1; - }, - 2 => { - if (int16(bytes[index..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - } else if (bytes[index + 1] == '\r') { - r.state = .seen_r; - } - return index + 2; - }, - 3 => { - if (int16(bytes[index..][0..2]) == int16("\r\n") and - bytes[index + 2] == '\r') - { - r.state = .seen_rnr; - } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - } else if (bytes[index + 2] == '\r') { - r.state = .seen_r; - } - return index + 3; - }, - 4...15 => { - if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) { - r.state = .finished; - return index + 4; - } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and - bytes[index + 3] == '\r') - { - r.state = .seen_rnr; - index += 4; - continue :state; - } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - index += 4; - continue :state; - } else if (bytes[index + 3] == '\r') { - r.state = .seen_r; - index += 4; - continue :state; - } - index += 4; - continue; - }, - else => { - const chunk = bytes[index..][0..16]; - const v: @Vector(16, u8) = chunk.*; - const matches_r = v == @splat(16, @as(u8, '\r')); - const iota = std.simd.iota(u8, 16); - const default = @splat(16, @as(u8, 16)); - const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default)); - switch (sub_index) { - 0...12 => { - index += sub_index + 4; - if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) { - r.state = .finished; - return index; - } - continue; - }, - 13 => { - index += 16; - if (int16(chunk[14..][0..2]) == int16("\n\r")) { - r.state = .seen_rnr; - continue :state; - } - continue; - }, - 14 => { - index += 16; - if (chunk[15] == '\n') { - r.state = .seen_rn; - continue :state; - } - continue; - }, - 15 => { - r.state = .seen_r; - index += 16; - continue :state; - }, - 16 => { - index += 16; - continue; - }, - else => unreachable, - } - }, - } - }, - - .seen_r => switch (bytes.len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\n' => r.state = .seen_rn, - '\r' => r.state = .seen_r, - else => r.state = .start, - } - return index + 1; - }, - 2 => { - if (int16(bytes[index..][0..2]) == int16("\n\r")) { - r.state = .seen_rnr; - return index + 2; - } - r.state = .start; - return index + 2; - }, - else => { - if (int16(bytes[index..][0..2]) == int16("\n\r") and - bytes[index + 2] == '\n') - { - r.state = .finished; - return index + 3; - } - index += 3; - r.state = .start; - continue :state; - }, - }, - .seen_rn => switch (bytes.len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_rnr, - else => r.state = .start, - } - return index + 1; - }, - else => { - if (int16(bytes[index..][0..2]) == int16("\r\n")) { - r.state = .finished; - return index + 2; - } - index += 2; - r.state = .start; - continue :state; - }, - }, - .seen_rnr => switch (bytes.len - index) { - 0 => return index, - else => { - if (bytes[index] == '\n') { - r.state = .finished; - return index + 1; - } - index += 1; - r.state = .start; - continue :state; - }, - }, - .chunk_size_prefix_r => unreachable, - .chunk_size_prefix_n => unreachable, - .chunk_size => unreachable, - .chunk_r => unreachable, - .chunk_data => unreachable, - } - - return index; - } - } + if (node.data.closing) { + node.data.close(client); - pub fn findChunkedLen(r: *Response, bytes: []const u8) usize { - var i: usize = 0; - if (r.state == .chunk_size) { - while (i < bytes.len) : (i += 1) { - const digit = switch (bytes[i]) { - '0'...'9' => |b| b - '0', - 'A'...'Z' => |b| b - 'A' + 10, - 'a'...'z' => |b| b - 'a' + 10, - '\r' => { - r.state = .chunk_r; - i += 1; - break; - }, - else => { - r.state = .invalid; - return i; - }, - }; - const mul = @mulWithOverflow(r.next_chunk_length, 16); - if (mul[1] != 0) { - r.state = .invalid; - return i; - } - const add = @addWithOverflow(mul[0], digit); - if (add[1] != 0) { - r.state = .invalid; - return i; - } - r.next_chunk_length = add[0]; - } else { - return i; - } - } - assert(r.state == .chunk_r); - if (i == bytes.len) return i; - - if (bytes[i] == '\n') { - r.state = .chunk_data; - return i + 1; - } else { - r.state = .invalid; - return i; - } + return client.allocator.destroy(node); } - fn parseInt3(nnn: @Vector(3, u8)) u10 { - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); - } + if (pool.free_len + 1 >= pool.free_size) { + const popped = pool.free.popFirst() orelse unreachable; - test parseInt3 { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000".*)); - try expectEqual(@as(u10, 418), parseInt3("418".*)); - try expectEqual(@as(u10, 999), parseInt3("999".*)); - } + popped.data.close(client); - test "find headers end basic" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4")); - try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18")); - try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah")); + return client.allocator.destroy(popped); } - test "find headers end vectorized" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n" ++ - "\r\ncontent"; - try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example)); - } + pool.free.append(node); + pool.free_len += 1; + } - test "find headers end bug" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - const example = - "HTTP/1.1 200 OK\r\n" ++ - "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++ - "content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++ - "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++ - "Content-Type: application/x-gzip\r\n" ++ - "ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++ - "Strict-Transport-Security: max-age=31536000\r\n" ++ - "Vary: Authorization,Accept-Encoding,Origin\r\n" ++ - "X-Content-Type-Options: nosniff\r\n" ++ - "X-Frame-Options: deny\r\n" ++ - "X-XSS-Protection: 1; mode=block\r\n" ++ - "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++ - "connection: close\r\n\r\n" ++ trail; - try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example)); - } - }; + /// Adds a newly created node to the pool of used connections. This function is threadsafe. + pub fn addUsed(pool: *ConnectionPool, node: *Node) void { + pool.mutex.lock(); + defer pool.mutex.unlock(); - pub const Headers = struct { - version: http.Version = .@"HTTP/1.1", - method: http.Method = .GET, - }; + pool.used.append(node); + } - pub const Options = struct { - max_redirects: u32 = 3, - header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, - - pub const HeaderStrategy = union(enum) { - /// In this case, the client's Allocator will be used to store the - /// entire HTTP header. This value is the maximum total size of - /// HTTP headers allowed, otherwise - /// error.HttpHeadersExceededSizeLimit is returned from read(). - dynamic: usize, - /// This is used to store the entire HTTP header. If the HTTP - /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` - /// is returned from read(). When this is used, `error.OutOfMemory` - /// cannot be returned from `read()`. - static: []u8, - }; - }; + pub fn deinit(pool: *ConnectionPool, client: *Client) void { + pool.mutex.lock(); - /// May be skipped if header strategy is buffer. - pub fn deinit(req: *Request) void { - if (req.response.header_bytes_owned) { - req.response.header_bytes.deinit(req.client.allocator); + var next = pool.free.first; + while (next) |node| { + defer client.allocator.destroy(node); + next = node.next; + + node.data.close(client); } - req.* = undefined; + + next = pool.used.first; + while (next) |node| { + defer client.allocator.destroy(node); + next = node.next; + + node.data.close(client); + } + + pool.* = undefined; } +}; + +pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); +pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); +pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{}); - pub const Reader = std.io.Reader(*Request, ReadError, read); +pub const Connection = struct { + stream: net.Stream, + /// undefined unless protocol is tls. + tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. + protocol: Protocol, + host: []u8, + port: u16, - pub fn reader(req: *Request) Reader { - return .{ .context = req }; + // This connection has been part of a non keepalive request and cannot be added to the pool. + closing: bool = false, + + pub const Protocol = enum { plain, tls }; + + pub fn read(conn: *Connection, buffer: []u8) !usize { + switch (conn.protocol) { + .plain => return conn.stream.read(buffer), + .tls => return conn.tls_client.read(conn.stream, buffer), + } } - pub fn readAll(req: *Request, buffer: []u8) !usize { - return readAtLeast(req, buffer, buffer.len); + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize { + switch (conn.protocol) { + .plain => return conn.stream.readAtLeast(buffer, len), + .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), + } } pub const ReadError = net.Stream.ReadError || error{ - // From HTTP protocol - HttpHeadersInvalid, - HttpHeadersExceededSizeLimit, - HttpRedirectMissingLocation, - HttpTransferEncodingUnsupported, - HttpContentLengthUnknown, - TooManyHttpRedirects, - ShortHttpStatusLine, - BadHttpVersion, - HttpHeaderContinuationsUnsupported, - UnsupportedUrlScheme, - UriMissingHost, - UnknownHostName, - - // Network problems - NetworkUnreachable, - HostLacksNetworkAddresses, - TemporaryNameServerFailure, - NameServerFailure, - ProtocolFamilyNotAvailable, - ProtocolNotSupported, - - // System resource problems - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - OutOfMemory, - - // TLS problems - InsufficientEntropy, TlsConnectionTruncated, TlsRecordOverflow, TlsDecodeError, TlsAlert, TlsBadRecordMac, + Overflow, TlsBadLength, TlsIllegalParameter, TlsUnexpectedMessage, - TlsDecryptFailure, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - TlsDecryptError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - CertificateAuthorityBundleTooBig, - - // TODO: convert to higher level errors - InvalidFormat, - InvalidPort, - UnexpectedCharacter, - Overflow, - InvalidCharacter, - AddressFamilyNotSupported, - AddressInUse, - AddressNotAvailable, - ConnectionPending, - ConnectionRefused, - FileNotFound, - PermissionDenied, - ServiceUnavailable, - SocketTypeNotSupported, - FileTooBig, - LockViolation, - NoSpaceLeft, - NotOpenForWriting, - InvalidEncoding, - IdentityElement, - NonCanonical, - SignatureVerificationFailed, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - DiskQuota, - InvalidEnd, - Incomplete, - InvalidIpv4Mapping, - InvalidIPAddressFormat, - BadPathName, - DeviceBusy, - FileBusy, - FileLocksNotSupported, - InvalidHandle, - InvalidUtf8, - NameTooLong, - NoDevice, - PathAlreadyExists, - PipeBusy, - SharingViolation, - SymLinkLoop, - FileSystem, - InterfaceNotFound, - AlreadyBound, - FileDescriptorNotASocket, - NetworkSubsystemFailed, - NotDir, - ReadOnlyFileSystem, - Unseekable, - MissingEndCertificateMarker, - InvalidPadding, - EndOfStream, - InvalidArgument, }; - pub fn read(req: *Request, buffer: []u8) ReadError!usize { - return readAtLeast(req, buffer, 1); + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; } - pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const zero_means_end = req.response.state.zeroMeansEnd(); - const amt = try readAdvanced(req, buffer[index..]); - if (amt == 0 and zero_means_end) break; - index += amt; + pub fn writeAll(conn: *Connection, buffer: []const u8) !void { + switch (conn.protocol) { + .plain => return conn.stream.writeAll(buffer), + .tls => return conn.tls_client.writeAll(conn.stream, buffer), } - return index; } - /// This one can return 0 without meaning EOF. - /// TODO change to readvAdvanced - pub fn readAdvanced(req: *Request, buffer: []u8) !usize { - var in = buffer[0..try req.connection.read(buffer)]; - var out_index: usize = 0; - while (true) { - switch (req.response.state) { - .invalid => unreachable, - .start, .seen_r, .seen_rn, .seen_rnr => { - const i = req.response.findHeadersEnd(in); - if (req.response.state == .invalid) return error.HttpHeadersInvalid; - - const headers_data = in[0..i]; - if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { - return error.HttpHeadersExceededSizeLimit; - } - try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); - - if (req.response.state == .finished) { - req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - - if (req.response.headers.status.class() == .redirect) { - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = try std.Uri.parse(location); - const new_req = try req.client.request(new_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); - req.* = new_req; - assert(out_index == 0); - in = buffer[0..try req.connection.read(buffer)]; - continue; - } - - if (req.response.headers.transfer_encoding) |transfer_encoding| { - switch (transfer_encoding) { - .chunked => { - req.response.next_chunk_length = 0; - req.response.state = .chunk_size; - }, - .compress => return error.HttpTransferEncodingUnsupported, - .deflate => return error.HttpTransferEncodingUnsupported, - .gzip => return error.HttpTransferEncodingUnsupported, - } - } else if (req.response.headers.content_length) |content_length| { - req.response.next_chunk_length = content_length; - } else { - return error.HttpContentLengthUnknown; - } - - in = in[i..]; - continue; - } - - assert(out_index == 0); - return 0; - }, - .finished => { - if (in.ptr == buffer.ptr) { - return in.len; - } else { - mem.copy(u8, buffer[out_index..], in); - return out_index + in.len; - } - }, - .chunk_size_prefix_r => switch (in.len) { - 0 => return out_index, - 1 => switch (in[0]) { - '\r' => { - req.response.state = .chunk_size_prefix_n; - return out_index; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - else => switch (int16(in[0..2])) { - int16("\r\n") => { - in = in[2..]; - req.response.state = .chunk_size; - continue; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - }, - .chunk_size_prefix_n => switch (in.len) { - 0 => return out_index, - else => switch (in[0]) { - '\n' => { - in = in[1..]; - req.response.state = .chunk_size; - continue; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - }, - .chunk_size, .chunk_r => { - const i = req.response.findChunkedLen(in); - switch (req.response.state) { - .invalid => return error.HttpHeadersInvalid, - .chunk_data => { - if (req.response.next_chunk_length == 0) { - req.response.state = .start; - return out_index; - } - in = in[i..]; - continue; - }, - .chunk_size => return out_index, - else => unreachable, - } - }, - .chunk_data => { - // TODO https://github.com/ziglang/zig/issues/14039 - const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len)); - req.response.next_chunk_length -= sub_amt; - if (req.response.next_chunk_length > 0) { - if (in.ptr == buffer.ptr) { - return sub_amt; - } else { - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - return out_index; - } - } - mem.copy(u8, buffer[out_index..], in[0..sub_amt]); - out_index += sub_amt; - req.response.state = .chunk_size_prefix_r; - in = in[sub_amt..]; - continue; - }, - } + pub fn write(conn: *Connection, buffer: []const u8) !usize { + switch (conn.protocol) { + .plain => return conn.stream.write(buffer), + .tls => return conn.tls_client.write(conn.stream, buffer), } } - inline fn int16(array: *const [2]u8) u16 { - return @bitCast(u16, array.*); - } + pub const WriteError = net.Stream.WriteError || error{}; + pub const Writer = std.io.Writer(*Connection, WriteError, write); - inline fn int32(array: *const [4]u8) u32 { - return @bitCast(u32, array.*); + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; } - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(u64, array.*); - } + pub fn close(conn: *Connection, client: *const Client) void { + if (conn.protocol == .tls) { + // try to cleanly close the TLS connection, for any server that cares. + _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + client.allocator.destroy(conn.tls_client); + } - test { - _ = Response; + conn.stream.close(); + + client.allocator.free(conn.host); } }; pub fn deinit(client: *Client) void { + client.connection_pool.deinit(client); + client.ca_bundle.deinit(client.allocator); client.* = undefined; } -pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection { - var conn: Connection = .{ +pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream); + +pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { + if (client.connection_pool.findConnection(.{ + .host = host, + .port = port, + .is_tls = protocol == .tls, + })) |node| + return node; + + const conn = try client.allocator.create(ConnectionPool.Node); + errdefer client.allocator.destroy(conn); + conn.* = .{ .data = undefined }; + + conn.data = .{ .stream = try net.tcpConnectToHost(client.allocator, host, port), .tls_client = undefined, .protocol = protocol, + .host = try client.allocator.dupe(u8, host), + .port = port, }; switch (protocol) { .plain => {}, .tls => { - conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host); + conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. - conn.tls_client.allow_truncation_attacks = true; + conn.data.tls_client.allow_truncation_attacks = true; }, } + client.connection_pool.addUsed(conn); + return conn; } -pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) !Request { +pub const RequestError = ConnectError || Connection.WriteError || error{ + UnsupportedUrlScheme, + UriMissingHost, + + CertificateAuthorityBundleTooBig, + InvalidPadding, + MissingEndCertificateMarker, + Unseekable, + EndOfStream, +}; + +pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request { const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http")) .plain else if (mem.eql(u8, uri.scheme, "https")) @@ -884,34 +292,85 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req const host = uri.host orelse return error.UriMissingHost; if (client.next_https_rescan_certs and protocol == .tls) { - try client.ca_bundle.rescan(client.allocator); - client.next_https_rescan_certs = false; + client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. + defer client.connection_pool.mutex.unlock(); + + if (client.next_https_rescan_certs) { + try client.ca_bundle.rescan(client.allocator); + client.next_https_rescan_certs = false; + } } var req: Request = .{ + .uri = uri, .client = client, .headers = headers, .connection = try client.connect(host, port, protocol), .redirects_left = options.max_redirects, + .handle_redirects = options.handle_redirects, + .compression_init = false, .response = switch (options.header_strategy) { - .dynamic => |max| Request.Response.initDynamic(max), - .static => |buf| Request.Response.initStatic(buf), + .dynamic => |max| Response.initDynamic(max), + .static => |buf| Response.initStatic(buf), }, + .arena = undefined, }; + req.arena = std.heap.ArenaAllocator.init(client.allocator); + { - var h = try std.BoundedArray(u8, 1000).init(0); - try h.appendSlice(@tagName(headers.method)); - try h.appendSlice(" "); - try h.appendSlice(uri.path); - try h.appendSlice(" "); - try h.appendSlice(@tagName(headers.version)); - try h.appendSlice("\r\nHost: "); - try h.appendSlice(host); - try h.appendSlice("\r\nConnection: close\r\n\r\n"); - - const header_bytes = h.slice(); - try req.connection.writeAll(header_bytes); + var buffered = std.io.bufferedWriter(req.connection.data.writer()); + const writer = buffered.writer(); + + const escaped_path = try Uri.escapePath(client.allocator, uri.path); + defer client.allocator.free(escaped_path); + + const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null; + defer if (escaped_query) |q| client.allocator.free(q); + + const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null; + defer if (escaped_fragment) |f| client.allocator.free(f); + + try writer.writeAll(@tagName(headers.method)); + try writer.writeByte(' '); + try writer.writeAll(escaped_path); + if (escaped_query) |q| { + try writer.writeByte('?'); + try writer.writeAll(q); + } + if (escaped_fragment) |f| { + try writer.writeByte('#'); + try writer.writeAll(f); + } + try writer.writeByte(' '); + try writer.writeAll(@tagName(headers.version)); + try writer.writeAll("\r\nHost: "); + try writer.writeAll(host); + try writer.writeAll("\r\nUser-Agent: "); + try writer.writeAll(headers.user_agent); + if (headers.connection == .close) { + try writer.writeAll("\r\nConnection: close"); + } else { + try writer.writeAll("\r\nConnection: keep-alive"); + } + try writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd"); + + switch (headers.transfer_encoding) { + .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"), + .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}), + .none => {}, + } + + for (headers.custom) |header| { + try writer.writeAll("\r\n"); + try writer.writeAll(header.name); + try writer.writeAll(": "); + try writer.writeAll(header.value); + } + + try writer.writeAll("\r\n\r\n"); + + try buffered.flush(); } return req; @@ -925,5 +384,7 @@ test { return error.SkipZigTest; } + if (builtin.os.tag == .wasi) return error.SkipZigTest; + _ = Request; } |
