| author | Andrew Kelley <andrew@ziglang.org> | 2023-04-09 10:44:52 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-09 10:44:52 -0400 |
| commit | 2ee328995a70c5c446f24c5593e0fad760e6d839 (patch) | |
| tree | 0e547171b7790ffd182fc298d384ef614571e97e /lib/std | |
| parent | c22a30ac99b9a2b92d9a8e926b9bf0c9dbc3d14e (diff) | |
| parent | 7f9a4625fda0b1a33177cdd66819f0a061c6b2da (diff) | |
Merge pull request #15123 from truemedian/http-server
std.http: add http server
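A minimal usage sketch of the client API this PR introduces (not part of the diff itself; the URL is illustrative). It follows the order of operations documented on `Request`: `request[ -> write -> finish] -> do -> read`.

```zig
const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    // A Client only needs an allocator; the CA bundle is rescanned lazily
    // on the first HTTPS request, and connections are pooled for reuse.
    var client: std.http.Client = .{ .allocator = gpa.allocator() };
    defer client.deinit();

    // Illustrative URL, not taken from the diff.
    const uri = try std.Uri.parse("https://example.com/");

    // request() opens (or reuses) a connection and sends the request head.
    var req = try client.request(uri, .{}, .{});
    defer req.deinit();

    // do() waits for the response head, parses headers, and follows
    // redirects unless options.handle_redirects is false.
    try req.do();

    // read()/readAll() must be called after do(); decompression is applied
    // transparently when the server used gzip/deflate/zstd.
    var buf: [4096]u8 = undefined;
    const n = try req.readAll(&buf);
    std.debug.print("{s}\n", .{buf[0..n]});
}
```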
Diffstat (limited to 'lib/std')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | lib/std/http.zig | 2 |
| -rw-r--r-- | lib/std/http/Client.zig | 880 |
| -rw-r--r-- | lib/std/http/Client/Request.zig | 482 |
| -rw-r--r-- | lib/std/http/Client/Response.zig | 509 |
| -rw-r--r-- | lib/std/http/Server.zig | 600 |
| -rw-r--r-- | lib/std/http/protocol.zig | 842 |
6 files changed, 2203 insertions, 1112 deletions
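For requests with a body, the diff's `Request.write`/`finish` path sits between `request()` and `do()`. A hedged sketch of that flow, reusing `client` and `uri` from the sketch above (method and payload are illustrative, not from the diff):

```zig
// Chunked upload: each write() is framed as "{x}\r\n" + data + "\r\n",
// and finish() sends the terminating "0\r\n" chunk.
var req = try client.request(uri, .{
    .method = .POST,
    .transfer_encoding = .chunked,
}, .{});
defer req.deinit();

try req.writer().writeAll("hello from zig\n");
try req.finish(); // no more body data to send
try req.do();     // only now wait for and parse the response head
```

With `.transfer_encoding = .{ .content_length = n }` instead, `write()` enforces the declared length and `finish()` returns `error.MessageNotCompleted` if fewer than `n` bytes were written.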
diff --git a/lib/std/http.zig b/lib/std/http.zig index ef89f09925..6e5f4e0cd9 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,4 +1,6 @@ pub const Client = @import("http/Client.zig"); +pub const Server = @import("http/Server.zig"); +pub const protocol = @import("http/protocol.zig"); pub const Version = enum { @"HTTP/1.0", diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 76073c0ce3..010c557f87 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1,49 +1,105 @@ -//! TODO: send connection: keep-alive and LRU cache a configurable number of -//! open connections to skip DNS and TLS handshake for subsequent requests. -//! -//! This API is *not* thread safe. +//! Connecting and opening requests are threadsafe. Individual requests are not. const std = @import("../std.zig"); -const mem = std.mem; -const assert = std.debug.assert; +const testing = std.testing; const http = std.http; +const mem = std.mem; const net = std.net; -const Client = @This(); const Uri = std.Uri; -const Allocator = std.mem.Allocator; -const testing = std.testing; +const Allocator = mem.Allocator; +const assert = std.debug.assert; -pub const Request = @import("Client/Request.zig"); -pub const Response = @import("Client/Response.zig"); +const Client = @This(); +const proto = @import("protocol.zig"); pub const default_connection_pool_size = 32; -const connection_pool_size = std.options.http_connection_pool_size; +pub const connection_pool_size = std.options.http_connection_pool_size; -/// Used for tcpConnectToHost and storing HTTP headers when an externally -/// managed buffer is not provided. allocator: Allocator, ca_bundle: std.crypto.Certificate.Bundle = .{}, +ca_bundle_mutex: std.Thread.Mutex = .{}, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. next_https_rescan_certs: bool = true, +/// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, +/// The last error that occurred on this client. This is not threadsafe, do not expect it to be completely accurate. +last_error: ?ExtraError = null, + +pub const ExtraError = union(enum) { + fn impliedErrorSet(comptime f: anytype) type { + const set = @typeInfo(@typeInfo(@TypeOf(f)).Fn.return_type.?).ErrorUnion.error_set; + if (@typeName(set)[0] != '@') @compileError(@typeName(f) ++ " doesn't have an implied error set any more."); + return set; + } + + // There's apparently a dependency loop with using Client.DeflateDecompressor. 
+ const FakeTransferError = proto.HeadersParser.ReadError || error{ReadFailed}; + const FakeTransferReader = std.io.Reader(void, FakeTransferError, fakeRead); + fn fakeRead(ctx: void, buf: []u8) FakeTransferError!usize { + _ = .{ buf, ctx }; + return 0; + } + + const FakeDeflateDecompressor = std.compress.zlib.ZlibStream(FakeTransferReader); + const FakeGzipDecompressor = std.compress.gzip.Decompress(FakeTransferReader); + const FakeZstdDecompressor = std.compress.zstd.DecompressStream(FakeTransferReader, .{}); + + pub const TcpConnectError = std.net.TcpConnectToHostError; + pub const TlsError = std.crypto.tls.Client.InitError(net.Stream); + pub const WriteError = BufferedConnection.WriteError; + pub const ReadError = BufferedConnection.ReadError || error{HttpChunkInvalid}; + pub const CaBundleError = impliedErrorSet(std.crypto.Certificate.Bundle.rescan); + + pub const ZlibInitError = error{ BadHeader, InvalidCompression, InvalidWindowSize, Unsupported, EndOfStream, OutOfMemory } || Request.TransferReadError; + pub const GzipInitError = error{ BadHeader, InvalidCompression, OutOfMemory, WrongChecksum, EndOfStream, StreamTooLong } || Request.TransferReadError; + // pub const DecompressError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error; + pub const DecompressError = FakeDeflateDecompressor.Error || FakeGzipDecompressor.Error || FakeZstdDecompressor.Error; + + zlib_init: ZlibInitError, // error.CompressionInitializationFailed + gzip_init: GzipInitError, // error.CompressionInitializationFailed + connect: TcpConnectError, // error.ConnectionFailed + ca_bundle: CaBundleError, // error.CertificateAuthorityBundleFailed + tls: TlsError, // error.TlsInitializationFailed + write: WriteError, // error.WriteFailed + read: ReadError, // error.ReadFailed + decompress: DecompressError, // error.ReadFailed +}; + +/// A set of linked lists of connections that can be reused. pub const ConnectionPool = struct { + /// The criteria for a connection to be considered a match. pub const Criteria = struct { host: []const u8, port: u16, is_tls: bool, }; - const Queue = std.TailQueue(Connection); + pub const StoredConnection = struct { + buffered: BufferedConnection, + host: []u8, + port: u16, + + closing: bool = false, + + pub fn deinit(self: *StoredConnection, client: *Client) void { + self.buffered.close(client); + client.allocator.free(self.host); + } + }; + + const Queue = std.TailQueue(StoredConnection); pub const Node = Queue.Node; mutex: std.Thread.Mutex = .{}, + /// Open connections that are currently in use. used: Queue = .{}, + /// Open connections that are not currently in use. free: Queue = .{}, free_len: usize = 0, - free_size: usize = default_connection_pool_size, + free_size: usize = connection_pool_size, /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. /// If no connection is found, null is returned. 
@@ -53,9 +109,9 @@ pub const ConnectionPool = struct { var next = pool.free.last; while (next) |node| : (next = node.prev) { - if ((node.data.protocol == .tls) != criteria.is_tls) continue; + if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue; if (node.data.port != criteria.port) continue; - if (std.mem.eql(u8, node.data.host, criteria.host)) continue; + if (mem.eql(u8, node.data.host, criteria.host)) continue; pool.acquireUnsafe(node); return node; @@ -89,7 +145,7 @@ pub const ConnectionPool = struct { pool.used.remove(node); if (node.data.closing) { - node.data.close(client); + node.data.deinit(client); return client.allocator.destroy(node); } @@ -97,7 +153,7 @@ pub const ConnectionPool = struct { if (pool.free_len + 1 >= pool.free_size) { const popped = pool.free.popFirst() orelse unreachable; - popped.data.close(client); + popped.data.deinit(client); return client.allocator.destroy(popped); } @@ -122,7 +178,7 @@ pub const ConnectionPool = struct { defer client.allocator.destroy(node); next = node.next; - node.data.close(client); + node.data.deinit(client); } next = pool.used.first; @@ -130,27 +186,19 @@ pub const ConnectionPool = struct { defer client.allocator.destroy(node); next = node.next; - node.data.close(client); + node.data.deinit(client); } pool.* = undefined; } }; -pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw); -pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw); -pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{}); - +/// An interface to either a plain or TLS connection. pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB. + tls_client: *std.crypto.tls.Client, protocol: Protocol, - host: []u8, - port: u16, - - // This connection has been part of a non keepalive request and cannot be added to the pool. - closing: bool = false, pub const Protocol = enum { plain, tls }; @@ -215,11 +263,611 @@ pub const Connection = struct { } conn.stream.close(); + } +}; + +/// A buffered (and peekable) Connection. 
+pub const BufferedConnection = struct { + pub const buffer_size = 0x2000; + + conn: Connection, + buf: [buffer_size]u8 = undefined, + start: u16 = 0, + end: u16 = 0, + + pub fn fill(bconn: *BufferedConnection) ReadError!void { + if (bconn.end != bconn.start) return; + + const nread = try bconn.conn.read(bconn.buf[0..]); + if (nread == 0) return error.EndOfStream; + bconn.start = 0; + bconn.end = @truncate(u16, nread); + } + + pub fn peek(bconn: *BufferedConnection) []const u8 { + return bconn.buf[bconn.start..bconn.end]; + } + + pub fn clear(bconn: *BufferedConnection, num: u16) void { + bconn.start += num; + } + + pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { + var out_index: u16 = 0; + while (out_index < len) { + const available = bconn.end - bconn.start; + const left = buffer.len - out_index; + + if (available > 0) { + const can_read = @truncate(u16, @min(available, left)); + + std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]); + out_index += can_read; + bconn.start += can_read; + + continue; + } + + if (left > bconn.buf.len) { + // skip the buffer if the output is large enough + return bconn.conn.read(buffer[out_index..]); + } + + try bconn.fill(); + } + + return out_index; + } + + pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { + return bconn.readAtLeast(buffer, 1); + } + + pub const ReadError = Connection.ReadError || error{EndOfStream}; + pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); + + pub fn reader(bconn: *BufferedConnection) Reader { + return Reader{ .context = bconn }; + } + + pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { + return bconn.conn.writeAll(buffer); + } + + pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { + return bconn.conn.write(buffer); + } + + pub const WriteError = Connection.WriteError; + pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); + + pub fn writer(bconn: *BufferedConnection) Writer { + return Writer{ .context = bconn }; + } + + pub fn close(bconn: *BufferedConnection, client: *const Client) void { + bconn.conn.close(client); + } +}; + +/// The mode of transport for requests. +pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +/// The decompressor for response messages. +pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompress(Request.TransferReader); + pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{}); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, +}; + +/// A HTTP response originating from a server. 
+pub const Response = struct { + pub const Headers = struct { + status: http.Status, + version: http.Version, + location: ?[]const u8 = null, + content_length: ?u64 = null, + transfer_encoding: ?http.TransferEncoding = null, + transfer_compression: ?http.ContentEncoding = null, + connection: http.Connection = .close, + upgrade: ?[]const u8 = null, + + pub const ParseError = error{ + ShortHttpStatusLine, + BadHttpVersion, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionNotSupported, + }; + + pub fn parse(bytes: []const u8) ParseError!Headers { + var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n"); + + const first_line = it.next() orelse return error.HttpHeadersInvalid; + if (first_line.len < 12) + return error.ShortHttpStatusLine; + + const version: http.Version = switch (int64(first_line[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.BadHttpVersion, + }; + if (first_line[8] != ' ') return error.HttpHeadersInvalid; + const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); + + var headers: Headers = .{ + .version = version, + .status = status, + }; + + while (it.next()) |line| { + if (line.len == 0) return error.HttpHeadersInvalid; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.tokenize(u8, line, ": "); + const header_name = line_it.next() orelse return error.HttpHeadersInvalid; + const header_value = line_it.rest(); + if (std.ascii.eqlIgnoreCase(header_name, "location")) { + if (headers.location != null) return error.HttpHeadersInvalid; + headers.location = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (headers.content_length != null) return error.HttpHeadersInvalid; + headers.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwards(u8, header_value, ","); + + if (iter.next()) |first| { + const trimmed = mem.trim(u8, first, " "); + + if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { + if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; + headers.transfer_encoding = te; + } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |second| { + if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + + const trimmed = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } else if 
(std.ascii.eqlIgnoreCase(header_name, "connection")) { + if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { + headers.connection = .keep_alive; + } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { + headers.connection = .close; + } else { + return error.HttpConnectionHeaderUnsupported; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) { + headers.upgrade = header_value; + } + } + + return headers; + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(u64, array.*); + } + + fn parseInt3(nnn: @Vector(3, u8)) u10 { + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); + } + + test parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000".*)); + try expectEqual(@as(u10, 418), parseInt3("418".*)); + try expectEqual(@as(u10, 999), parseInt3("999".*)); + } + }; + + headers: Headers = undefined, + parser: proto.HeadersParser, + compression: Compression = .none, + skip: bool = false, +}; + +/// A HTTP request that has been sent. +/// +/// Order of operations: request[ -> write -> finish] -> do -> read +pub const Request = struct { + pub const Headers = struct { + version: http.Version = .@"HTTP/1.1", + method: http.Method = .GET, + user_agent: []const u8 = "zig (std.http)", + connection: http.Connection = .keep_alive, + transfer_encoding: RequestTransfer = .none, + + custom: []const http.CustomHeader = &[_]http.CustomHeader{}, + }; + + uri: Uri, + client: *Client, + connection: *ConnectionPool.Node, + /// These are stored in Request so that they are available when following + /// redirects. + headers: Headers, + + redirects_left: u32, + handle_redirects: bool, + + response: Response, + + /// Used as a allocator for resolving redirects locations. + arena: std.heap.ArenaAllocator, + + /// Frees all resources associated with the request. + pub fn deinit(req: *Request) void { + switch (req.response.compression) { + .none => {}, + .deflate => |*deflate| deflate.deinit(), + .gzip => |*gzip| gzip.deinit(), + .zstd => |*zstd| zstd.deinit(), + } + + if (req.response.parser.header_bytes_owned) { + req.response.parser.header_bytes.deinit(req.client.allocator); + } + + if (!req.response.parser.done) { + // If the response wasn't fully read, then we need to close the connection. 
+ req.connection.data.closing = true; + req.client.connection_pool.release(req.client, req.connection); + } + + req.arena.deinit(); + req.* = undefined; + } - client.allocator.free(conn.host); + pub fn start(req: *Request, uri: Uri, headers: Headers) !void { + var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer()); + const w = buffered.writer(); + + const escaped_path = try Uri.escapePath(req.client.allocator, uri.path); + defer req.client.allocator.free(escaped_path); + + const escaped_query = if (uri.query) |q| try Uri.escapeQuery(req.client.allocator, q) else null; + defer if (escaped_query) |q| req.client.allocator.free(q); + + const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(req.client.allocator, f) else null; + defer if (escaped_fragment) |f| req.client.allocator.free(f); + + try w.writeAll(@tagName(headers.method)); + try w.writeByte(' '); + if (escaped_path.len == 0) { + try w.writeByte('/'); + } else { + try w.writeAll(escaped_path); + } + if (escaped_query) |q| { + try w.writeByte('?'); + try w.writeAll(q); + } + if (escaped_fragment) |f| { + try w.writeByte('#'); + try w.writeAll(f); + } + try w.writeByte(' '); + try w.writeAll(@tagName(headers.version)); + try w.writeAll("\r\nHost: "); + try w.writeAll(uri.host.?); + try w.writeAll("\r\nUser-Agent: "); + try w.writeAll(headers.user_agent); + if (headers.connection == .close) { + try w.writeAll("\r\nConnection: close"); + } else { + try w.writeAll("\r\nConnection: keep-alive"); + } + try w.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd"); + try w.writeAll("\r\nTE: gzip, deflate"); // TODO: add trailers when someone finds a nice way to integrate them without completely invalidating all pointers to headers. + + switch (headers.transfer_encoding) { + .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"), + .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}), + .none => {}, + } + + for (headers.custom) |header| { + try w.writeAll("\r\n"); + try w.writeAll(header.name); + try w.writeAll(": "); + try w.writeAll(header.value); + } + + try w.writeAll("\r\n\r\n"); + + try buffered.flush(); + } + + pub const TransferReadError = proto.HeadersParser.ReadError || error{ReadFailed}; + + pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + + pub fn transferReader(req: *Request) TransferReader { + return .{ .context = req }; + } + + pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + if (req.response.parser.done) return 0; + + var index: usize = 0; + while (index == 0) { + const amt = req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip) catch |err| { + req.client.last_error = .{ .read = err }; + return error.ReadFailed; + }; + if (amt == 0 and req.response.parser.done) break; + index += amt; + } + + return index; + } + + pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.Headers.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed }; + + /// Waits for a response from the server and parses any headers that are sent. + /// This function will block until the final response is received. + /// + /// If `handle_redirects` is true, then this function will automatically follow + /// redirects. 
+ pub fn do(req: *Request) DoError!void { + while (true) { // handle redirects + while (true) { // read headers + req.connection.data.buffered.fill() catch |err| { + req.client.last_error = .{ .read = err }; + return error.ReadFailed; + }; + + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); + req.connection.data.buffered.clear(@intCast(u16, nchecked)); + + if (req.response.parser.state.isContent()) break; + } + + req.response.headers = try Response.Headers.parse(req.response.parser.header_bytes.items); + + if (req.response.headers.status == .switching_protocols) { + req.connection.data.closing = false; + req.response.parser.done = true; + } + + if (req.headers.connection == .keep_alive and req.response.headers.connection == .keep_alive) { + req.connection.data.closing = false; + } else { + req.connection.data.closing = true; + } + + if (req.response.headers.transfer_encoding) |te| { + switch (te) { + .chunked => { + req.response.parser.next_chunk_length = 0; + req.response.parser.state = .chunk_head_size; + }, + } + } else if (req.response.headers.content_length) |cl| { + req.response.parser.next_chunk_length = cl; + + if (cl == 0) req.response.parser.done = true; + } else { + req.response.parser.done = true; + } + + if (req.response.headers.status.class() == .redirect and req.handle_redirects) { + req.response.skip = true; + + const empty = @as([*]u8, undefined)[0..0]; + assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary + + if (req.redirects_left == 0) return error.TooManyHttpRedirects; + + const location = req.response.headers.location orelse + return error.HttpRedirectMissingLocation; + const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location); + + var new_arena = std.heap.ArenaAllocator.init(req.client.allocator); + const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator()); + errdefer new_arena.deinit(); + + req.arena.deinit(); + req.arena = new_arena; + + const new_req = try req.client.request(resolved_url, req.headers, .{ + .max_redirects = req.redirects_left - 1, + .header_strategy = if (req.response.parser.header_bytes_owned) .{ + .dynamic = req.response.parser.max_header_bytes, + } else .{ + .static = req.response.parser.header_bytes.items.ptr[0..req.response.parser.max_header_bytes], + }, + }); + req.deinit(); + req.* = new_req; + } else { + req.response.skip = false; + if (!req.response.parser.done) { + if (req.response.headers.transfer_compression) |tc| switch (tc) { + .compress => return error.CompressionNotSupported, + .deflate => req.response.compression = .{ + .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch |err| { + req.client.last_error = .{ .zlib_init = err }; + return error.CompressionInitializationFailed; + }, + }, + .gzip => req.response.compression = .{ + .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch |err| { + req.client.last_error = .{ .gzip_init = err }; + return error.CompressionInitializationFailed; + }, + }, + .zstd => req.response.compression = .{ + .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), + }, + }; + } + + break; + } + } + } + + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError; + + pub const Reader = std.io.Reader(*Request, ReadError, read); + + pub fn reader(req: *Request) Reader { + return .{ .context = req }; + } + + /// Reads data from the 
response body. Must be called after `do`. + pub fn read(req: *Request, buffer: []u8) ReadError!usize { + while (true) { + const out_index = switch (req.response.compression) { + .deflate => |*deflate| deflate.read(buffer) catch |err| { + req.client.last_error = .{ .decompress = err }; + err catch {}; + return error.ReadFailed; + }, + .gzip => |*gzip| gzip.read(buffer) catch |err| { + req.client.last_error = .{ .decompress = err }; + err catch {}; + return error.ReadFailed; + }, + .zstd => |*zstd| zstd.read(buffer) catch |err| { + req.client.last_error = .{ .decompress = err }; + err catch {}; + return error.ReadFailed; + }, + else => try req.transferRead(buffer), + }; + + if (out_index == 0) { + while (!req.response.parser.state.isContent()) { // read trailing headers + req.connection.data.buffered.fill() catch |err| { + req.client.last_error = .{ .read = err }; + return error.ReadFailed; + }; + + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); + req.connection.data.buffered.clear(@intCast(u16, nchecked)); + } + } + + return out_index; + } + } + + /// Reads data from the response body. Must be called after `do`. + pub fn readAll(req: *Request, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = read(req, buffer[index..]) catch |err| { + req.client.last_error = .{ .read = err }; + return error.ReadFailed; + }; + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = error{ WriteFailed, NotWriteable, MessageTooLong }; + + pub const Writer = std.io.Writer(*Request, WriteError, write); + + pub fn writer(req: *Request) Writer { + return .{ .context = req }; + } + + /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + pub fn write(req: *Request, bytes: []const u8) WriteError!usize { + switch (req.headers.transfer_encoding) { + .chunked => { + req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len}) catch |err| { + req.client.last_error = .{ .write = err }; + return error.WriteFailed; + }; + req.connection.data.conn.writeAll(bytes) catch |err| { + req.client.last_error = .{ .write = err }; + return error.WriteFailed; + }; + req.connection.data.conn.writeAll("\r\n") catch |err| { + req.client.last_error = .{ .write = err }; + return error.WriteFailed; + }; + + return bytes.len; + }, + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = req.connection.data.conn.write(bytes) catch |err| { + req.client.last_error = .{ .write = err }; + return error.WriteFailed; + }; + len.* -= amt; + return amt; + }, + .none => return error.NotWriteable, + } + } + + /// Finish the body of a request. This notifies the server that you have no more data to send. + pub fn finish(req: *Request) !void { + switch (req.headers.transfer_encoding) { + .chunked => req.connection.data.conn.writeAll("0\r\n") catch |err| { + req.client.last_error = .{ .write = err }; + return error.WriteFailed; + }, + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } } }; +/// Release all associated resources with the client. 
+/// TODO: currently leaks all request allocated data pub fn deinit(client: *Client) void { client.connection_pool.deinit(client); @@ -227,8 +875,10 @@ pub fn deinit(client: *Client) void { client.* = undefined; } -pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream); +pub const ConnectError = Allocator.Error || error{ ConnectionFailed, TlsInitializationFailed }; +/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. +/// This function is threadsafe. pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node { if (client.connection_pool.findConnection(.{ .host = host, @@ -241,22 +891,36 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; + const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| { + client.last_error = .{ .connect = err }; + return error.ConnectionFailed; + }; + errdefer stream.close(); + conn.data = .{ - .stream = try net.tcpConnectToHost(client.allocator, host, port), - .tls_client = undefined, - .protocol = protocol, + .buffered = .{ .conn = .{ + .stream = stream, + .tls_client = undefined, + .protocol = protocol, + } }, .host = try client.allocator.dupe(u8, host), .port = port, }; + errdefer client.allocator.free(conn.data.host); switch (protocol) { .plain => {}, .tls => { - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); - conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host); + conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client); + errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client); + + conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch |err| { + client.last_error = .{ .tls = err }; + return error.TlsInitializationFailed; + }; // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; + conn.data.buffered.conn.tls_client.allow_truncation_attacks = true; }, } @@ -265,24 +929,44 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio return conn; } -pub const RequestError = ConnectError || Connection.WriteError || error{ +pub const RequestError = ConnectError || error{ UnsupportedUrlScheme, UriMissingHost, - CertificateAuthorityBundleTooBig, - InvalidPadding, - MissingEndCertificateMarker, - Unseekable, - EndOfStream, + CertificateAuthorityBundleFailed, + WriteFailed, +}; + +pub const Options = struct { + handle_redirects: bool = true, + max_redirects: u32 = 3, + header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, + + pub const HeaderStrategy = union(enum) { + /// In this case, the client's Allocator will be used to store the + /// entire HTTP header. This value is the maximum total size of + /// HTTP headers allowed, otherwise + /// error.HttpHeadersExceededSizeLimit is returned from read(). + dynamic: usize, + /// This is used to store the entire HTTP header. If the HTTP + /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` + /// is returned from read(). When this is used, `error.OutOfMemory` + /// cannot be returned from `read()`. 
+ static: []u8, + }; }; -pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request { - const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http")) - .plain - else if (mem.eql(u8, uri.scheme, "https")) - .tls - else - return error.UnsupportedUrlScheme; +pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ + .{ "http", .plain }, + .{ "ws", .plain }, + .{ "https", .tls }, + .{ "wss", .tls }, +}); + +/// Form and send a http request to a server. +/// This function is threadsafe. +pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Options) RequestError!Request { + const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = uri.port orelse switch (protocol) { .plain => 80, @@ -291,91 +975,45 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req const host = uri.host orelse return error.UriMissingHost; - if (client.next_https_rescan_certs and protocol == .tls) { - client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex. - defer client.connection_pool.mutex.unlock(); + if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) { + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); if (client.next_https_rescan_certs) { - try client.ca_bundle.rescan(client.allocator); - client.next_https_rescan_certs = false; + client.ca_bundle.rescan(client.allocator) catch |err| { + client.last_error = .{ .ca_bundle = err }; + return error.CertificateAuthorityBundleFailed; + }; + @atomicStore(bool, &client.next_https_rescan_certs, false, .Release); } } var req: Request = .{ .uri = uri, .client = client, - .headers = headers, .connection = try client.connect(host, port, protocol), + .headers = headers, .redirects_left = options.max_redirects, .handle_redirects = options.handle_redirects, - .compression_init = false, - .response = switch (options.header_strategy) { - .dynamic => |max| Response.initDynamic(max), - .static => |buf| Response.initStatic(buf), + .response = .{ + .parser = switch (options.header_strategy) { + .dynamic => |max| proto.HeadersParser.initDynamic(max), + .static => |buf| proto.HeadersParser.initStatic(buf), + }, }, .arena = undefined, }; + errdefer req.deinit(); req.arena = std.heap.ArenaAllocator.init(client.allocator); - { - var buffered = std.io.bufferedWriter(req.connection.data.writer()); - const writer = buffered.writer(); - - const escaped_path = try Uri.escapePath(client.allocator, uri.path); - defer client.allocator.free(escaped_path); - - const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null; - defer if (escaped_query) |q| client.allocator.free(q); + req.start(uri, headers) catch |err| { + if (err == error.OutOfMemory) return error.OutOfMemory; + const err_casted = @errSetCast(BufferedConnection.WriteError, err); - const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null; - defer if (escaped_fragment) |f| client.allocator.free(f); - - try writer.writeAll(@tagName(headers.method)); - try writer.writeByte(' '); - if (escaped_path.len == 0) { - try writer.writeByte('/'); - } else { - try writer.writeAll(escaped_path); - } - if (escaped_query) |q| { - try writer.writeByte('?'); - try writer.writeAll(q); - } - if (escaped_fragment) |f| { - try writer.writeByte('#'); - try writer.writeAll(f); - } - try 
writer.writeByte(' '); - try writer.writeAll(@tagName(headers.version)); - try writer.writeAll("\r\nHost: "); - try writer.writeAll(host); - try writer.writeAll("\r\nUser-Agent: "); - try writer.writeAll(headers.user_agent); - if (headers.connection == .close) { - try writer.writeAll("\r\nConnection: close"); - } else { - try writer.writeAll("\r\nConnection: keep-alive"); - } - try writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd"); - - switch (headers.transfer_encoding) { - .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"), - .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}), - .none => {}, - } - - for (headers.custom) |header| { - try writer.writeAll("\r\n"); - try writer.writeAll(header.name); - try writer.writeAll(": "); - try writer.writeAll(header.value); - } - - try writer.writeAll("\r\n\r\n"); - - try buffered.flush(); - } + client.last_error = .{ .write = err_casted }; + return error.WriteFailed; + }; return req; } @@ -390,5 +1028,5 @@ test { if (builtin.os.tag == .wasi) return error.SkipZigTest; - _ = Request; + std.testing.refAllDecls(@This()); } diff --git a/lib/std/http/Client/Request.zig b/lib/std/http/Client/Request.zig deleted file mode 100644 index 9e2ebd2d6c..0000000000 --- a/lib/std/http/Client/Request.zig +++ /dev/null @@ -1,482 +0,0 @@ -const std = @import("std"); -const http = std.http; -const Uri = std.Uri; -const mem = std.mem; -const assert = std.debug.assert; - -const Client = @import("../Client.zig"); -const Connection = Client.Connection; -const ConnectionNode = Client.ConnectionPool.Node; -const Response = @import("Response.zig"); - -const Request = @This(); - -const read_buffer_size = 8192; -const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size); - -uri: Uri, -client: *Client, -connection: *ConnectionNode, -response: Response, -/// These are stored in Request so that they are available when following -/// redirects. -headers: Headers, - -redirects_left: u32, -handle_redirects: bool, -compression_init: bool, - -/// Used as a allocator for resolving redirects locations. -arena: std.heap.ArenaAllocator, - -/// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning. -read_buffer: [read_buffer_size]u8 = undefined, -read_buffer_start: ReadBufferIndex = 0, -read_buffer_len: ReadBufferIndex = 0, - -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -pub const Headers = struct { - version: http.Version = .@"HTTP/1.1", - method: http.Method = .GET, - user_agent: []const u8 = "zig (std.http)", - connection: http.Connection = .keep_alive, - transfer_encoding: RequestTransfer = .none, - - custom: []const http.CustomHeader = &[_]http.CustomHeader{}, -}; - -pub const Options = struct { - handle_redirects: bool = true, - max_redirects: u32 = 3, - header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 }, - - pub const HeaderStrategy = union(enum) { - /// In this case, the client's Allocator will be used to store the - /// entire HTTP header. This value is the maximum total size of - /// HTTP headers allowed, otherwise - /// error.HttpHeadersExceededSizeLimit is returned from read(). - dynamic: usize, - /// This is used to store the entire HTTP header. If the HTTP - /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` - /// is returned from read(). 
When this is used, `error.OutOfMemory` - /// cannot be returned from `read()`. - static: []u8, - }; -}; - -/// Frees all resources associated with the request. -pub fn deinit(req: *Request) void { - switch (req.response.compression) { - .none => {}, - .deflate => |*deflate| deflate.deinit(), - .gzip => |*gzip| gzip.deinit(), - .zstd => |*zstd| zstd.deinit(), - } - - if (req.response.header_bytes_owned) { - req.response.header_bytes.deinit(req.client.allocator); - } - - if (!req.response.done) { - // If the response wasn't fully read, then we need to close the connection. - req.connection.data.closing = true; - req.client.connection_pool.release(req.client, req.connection); - } - - req.arena.deinit(); - req.* = undefined; -} - -pub const ReadRawError = Connection.ReadError || Uri.ParseError || Client.RequestError || error{ - UnexpectedEndOfStream, - TooManyHttpRedirects, - HttpRedirectMissingLocation, - HttpHeadersInvalid, -}; - -pub const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw); - -/// Read from the underlying stream, without decompressing or parsing the headers. Must be called -/// after waitForCompleteHead() has returned successfully. -pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize { - assert(req.response.state.isContent()); - - var index: usize = 0; - while (index == 0) { - const amt = try req.readRawAdvanced(buffer[index..]); - if (amt == 0 and req.response.done) break; - index += amt; - } - - return index; -} - -fn checkForCompleteHead(req: *Request, buffer: []u8) !usize { - switch (req.response.state) { - .invalid => unreachable, - .start, .seen_r, .seen_rn, .seen_rnr => {}, - else => return 0, // No more headers to read. - } - - const i = req.response.findHeadersEnd(buffer[0..]); - if (req.response.state == .invalid) return error.HttpHeadersInvalid; - - const headers_data = buffer[0..i]; - if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) { - return error.HttpHeadersExceededSizeLimit; - } - try req.response.header_bytes.appendSlice(req.client.allocator, headers_data); - - if (req.response.state == .finished) { - req.response.headers = try Response.Headers.parse(req.response.header_bytes.items); - - if (req.response.headers.upgrade) |_| { - req.connection.data.closing = false; - req.response.done = true; - return i; - } - - if (req.response.headers.connection == .keep_alive) { - req.connection.data.closing = false; - } else { - req.connection.data.closing = true; - } - - if (req.response.headers.transfer_encoding) |transfer_encoding| { - switch (transfer_encoding) { - .chunked => { - req.response.next_chunk_length = 0; - req.response.state = .chunk_size; - }, - } - } else if (req.response.headers.content_length) |content_length| { - req.response.next_chunk_length = content_length; - - if (content_length == 0) req.response.done = true; - } else { - req.response.done = true; - } - - return i; - } - - return 0; -} - -pub const WaitForCompleteHeadError = ReadRawError || error{ - UnexpectedEndOfStream, - - HttpHeadersExceededSizeLimit, - ShortHttpStatusLine, - BadHttpVersion, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, -}; - -/// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent. 
-pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void { - if (req.response.state.isContent()) return; - - while (true) { - const nread = try req.connection.data.read(req.read_buffer[0..]); - const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]); - - if (amt != 0) { - req.read_buffer_start = @intCast(ReadBufferIndex, amt); - req.read_buffer_len = @intCast(ReadBufferIndex, nread); - return; - } else if (nread == 0) { - return error.UnexpectedEndOfStream; - } - } -} - -/// This one can return 0 without meaning EOF. -fn readRawAdvanced(req: *Request, buffer: []u8) !usize { - assert(req.response.state.isContent()); - if (req.response.done) return 0; - - // var in: []const u8 = undefined; - if (req.read_buffer_start == req.read_buffer_len) { - const nread = try req.connection.data.read(req.read_buffer[0..]); - if (nread == 0) return error.UnexpectedEndOfStream; - - req.read_buffer_start = 0; - req.read_buffer_len = @intCast(ReadBufferIndex, nread); - } - - var out_index: usize = 0; - while (true) { - switch (req.response.state) { - .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - // TODO https://github.com/ziglang/zig/issues/14039 - const buf_avail = req.read_buffer_len - req.read_buffer_start; - const data_avail = req.response.next_chunk_length; - const out_avail = buffer.len; - - if (req.handle_redirects and req.response.headers.status.class() == .redirect) { - const can_read = @intCast(usize, @min(buf_avail, data_avail)); - req.response.next_chunk_length -= can_read; - - if (req.response.next_chunk_length == 0) { - req.client.connection_pool.release(req.client, req.connection); - req.connection = undefined; - req.response.done = true; - } - - return 0; // skip over as much data as possible - } - - const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); - req.response.next_chunk_length -= can_read; - - mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]); - req.read_buffer_start += @intCast(ReadBufferIndex, can_read); - - if (req.response.next_chunk_length == 0) { - req.client.connection_pool.release(req.client, req.connection); - req.connection = undefined; - req.response.done = true; - } - - return can_read; - }, - .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) { - 0 => return out_index, - 1 => switch (req.read_buffer[req.read_buffer_start]) { - '\r' => { - req.response.state = .chunk_size_prefix_n; - return out_index; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) { - int16("\r\n") => { - req.read_buffer_start += 2; - req.response.state = .chunk_size; - continue; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - }, - .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) { - 0 => return out_index, - else => switch (req.read_buffer[req.read_buffer_start]) { - '\n' => { - req.read_buffer_start += 1; - req.response.state = .chunk_size; - continue; - }, - else => { - req.response.state = .invalid; - return error.HttpHeadersInvalid; - }, - }, - }, - .chunk_size, .chunk_r => { - const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]); - switch (req.response.state) { - .invalid => return error.HttpHeadersInvalid, - .chunk_data => { - if (req.response.next_chunk_length == 0) { - req.response.done = true; - 
req.client.connection_pool.release(req.client, req.connection); - req.connection = undefined; - - return out_index; - } - - req.read_buffer_start += @intCast(ReadBufferIndex, i); - continue; - }, - .chunk_size => return out_index, - else => unreachable, - } - }, - .chunk_data => { - // TODO https://github.com/ziglang/zig/issues/14039 - const buf_avail = req.read_buffer_len - req.read_buffer_start; - const data_avail = req.response.next_chunk_length; - const out_avail = buffer.len - out_index; - - if (req.handle_redirects and req.response.headers.status.class() == .redirect) { - const can_read = @intCast(usize, @min(buf_avail, data_avail)); - req.response.next_chunk_length -= can_read; - - if (req.response.next_chunk_length == 0) { - req.client.connection_pool.release(req.client, req.connection); - req.connection = undefined; - req.response.done = true; - continue; - } - - return 0; // skip over as much data as possible - } - - const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail)); - req.response.next_chunk_length -= can_read; - - mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]); - req.read_buffer_start += @intCast(ReadBufferIndex, can_read); - out_index += can_read; - - if (req.response.next_chunk_length == 0) { - req.response.state = .chunk_size_prefix_r; - - continue; - } - - return out_index; - }, - } - } -} - -pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported }; - -pub const Reader = std.io.Reader(*Request, ReadError, read); - -pub fn reader(req: *Request) Reader { - return .{ .context = req }; -} - -pub fn read(req: *Request, buffer: []u8) ReadError!usize { - while (true) { - if (!req.response.state.isContent()) try req.waitForCompleteHead(); - - if (req.handle_redirects and req.response.headers.status.class() == .redirect) { - assert(try req.readRaw(buffer) == 0); - - if (req.redirects_left == 0) return error.TooManyHttpRedirects; - - const location = req.response.headers.location orelse - return error.HttpRedirectMissingLocation; - const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location); - - var new_arena = std.heap.ArenaAllocator.init(req.client.allocator); - const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator()); - errdefer new_arena.deinit(); - - req.arena.deinit(); - req.arena = new_arena; - - const new_req = try req.client.request(resolved_url, req.headers, .{ - .max_redirects = req.redirects_left - 1, - .header_strategy = if (req.response.header_bytes_owned) .{ - .dynamic = req.response.max_header_bytes, - } else .{ - .static = req.response.header_bytes.unusedCapacitySlice(), - }, - }); - req.deinit(); - req.* = new_req; - } else { - break; - } - } - - if (req.response.compression == .none) { - if (req.response.headers.transfer_compression) |compression| { - switch (compression) { - .compress => return error.CompressionNotSupported, - .deflate => req.response.compression = .{ - .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }), - }, - .gzip => req.response.compression = .{ - .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }), - }, - .zstd => req.response.compression = .{ - .zstd = std.compress.zstd.decompressStream(req.client.allocator, ReaderRaw{ .context = req }), - }, - } - } - } - - 
return switch (req.response.compression) { - .deflate => |*deflate| try deflate.read(buffer), - .gzip => |*gzip| try gzip.read(buffer), - .zstd => |*zstd| try zstd.read(buffer), - else => try req.readRaw(buffer), - }; -} - -pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; -} - -pub const WriteError = Connection.WriteError || error{MessageTooLong}; - -pub const Writer = std.io.Writer(*Request, WriteError, write); - -pub fn writer(req: *Request) Writer { - return .{ .context = req }; -} - -/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. -pub fn write(req: *Request, bytes: []const u8) !usize { - switch (req.headers.transfer_encoding) { - .chunked => { - try req.connection.data.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.data.writeAll(bytes); - try req.connection.data.writeAll("\r\n"); - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try req.connection.data.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } -} - -/// Finish the body of a request. This notifies the server that you have no more data to send. -pub fn finish(req: *Request) !void { - switch (req.headers.transfer_encoding) { - .chunked => try req.connection.data.writeAll("0\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } -} - -inline fn int16(array: *const [2]u8) u16 { - return @bitCast(u16, array.*); -} - -inline fn int32(array: *const [4]u8) u32 { - return @bitCast(u32, array.*); -} - -inline fn int64(array: *const [8]u8) u64 { - return @bitCast(u64, array.*); -} - -test { - const builtin = @import("builtin"); - - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - _ = Response; -} diff --git a/lib/std/http/Client/Response.zig b/lib/std/http/Client/Response.zig deleted file mode 100644 index f1a3b07dd8..0000000000 --- a/lib/std/http/Client/Response.zig +++ /dev/null @@ -1,509 +0,0 @@ -const std = @import("std"); -const http = std.http; -const mem = std.mem; -const testing = std.testing; -const assert = std.debug.assert; - -const Client = @import("../Client.zig"); -const Response = @This(); - -headers: Headers, -state: State, -header_bytes_owned: bool, -/// This could either be a fixed buffer provided by the API user or it -/// could be our own array list. -header_bytes: std.ArrayListUnmanaged(u8), -max_header_bytes: usize, -next_chunk_length: u64, -done: bool = false, - -compression: union(enum) { - deflate: Client.DeflateDecompressor, - gzip: Client.GzipDecompressor, - zstd: Client.ZstdDecompressor, - none: void, -} = .none, - -pub const Headers = struct { - status: http.Status, - version: http.Version, - location: ?[]const u8 = null, - content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, - connection: http.Connection = .close, - upgrade: ?[]const u8 = null, - - number_of_headers: usize = 0, - - pub fn parse(bytes: []const u8) !Headers { - var it = mem.split(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); - - const first_line = it.first(); - if (first_line.len < 12) - return error.ShortHttpStatusLine; - - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*)); - - var headers: Headers = .{ - .version = version, - .status = status, - }; - - while (it.next()) |line| { - headers.number_of_headers += 1; - - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - var line_it = mem.split(u8, line, ": "); - const header_name = line_it.first(); - const header_value = line_it.rest(); - if (std.ascii.eqlIgnoreCase(header_name, "location")) { - if (headers.location != null) return error.HttpHeadersInvalid; - headers.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (headers.content_length != null) return error.HttpHeadersInvalid; - headers.content_length = try std.fmt.parseInt(u64, header_value, 10); - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = std.mem.splitBackwards(u8, header_value, ","); - - if (iter.next()) |first| { - const trimmed = std.mem.trim(u8, first, " "); - - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; - headers.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |second| { - if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; - - const trimmed = std.mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; - - const trimmed = std.mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { - headers.connection = .keep_alive; - } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { - headers.connection = .close; - } else { - return error.HttpConnectionHeaderUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) { - headers.upgrade = header_value; - } - } - - return headers; - } - - test "parse headers" { - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - const parsed = try Headers.parse(example); - try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version); - try 
testing.expectEqual(http.Status.moved_permanently, parsed.status); - try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse - return error.TestFailed); - try testing.expectEqual(@as(?u64, 220), parsed.content_length); - } - - test "header continuation" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeaderContinuationsUnsupported, - Headers.parse(example), - ); - } - - test "extra content length" { - const example = - "HTTP/1.0 200 OK\r\n" ++ - "Content-Length: 220\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "content-length: 220\r\n\r\n"; - try testing.expectError( - error.HttpHeadersInvalid, - Headers.parse(example), - ); - } -}; - -inline fn int16(array: *const [2]u8) u16 { - return @bitCast(u16, array.*); -} - -inline fn int32(array: *const [4]u8) u32 { - return @bitCast(u32, array.*); -} - -inline fn int64(array: *const [8]u8) u64 { - return @bitCast(u64, array.*); -} - -pub const State = enum { - /// Begin header parsing states. - invalid, - start, - seen_r, - seen_rn, - seen_rnr, - finished, - /// Begin transfer-encoding: chunked parsing states. - chunk_size_prefix_r, - chunk_size_prefix_n, - chunk_size, - chunk_r, - chunk_data, - - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true, - }; - } -}; - -pub fn initDynamic(max: usize) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{}, - .max_header_bytes = max, - .header_bytes_owned = true, - .next_chunk_length = undefined, - }; -} - -pub fn initStatic(buf: []u8) Response { - return .{ - .state = .start, - .headers = undefined, - .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, - .max_header_bytes = buf.len, - .header_bytes_owned = false, - .next_chunk_length = undefined, - }; -} - -/// Returns how many bytes are part of HTTP headers. Always less than or -/// equal to bytes.len. If the amount returned is less than bytes.len, it -/// means the headers ended and the first byte after the double \r\n\r\n is -/// located at `bytes[result]`. 
-pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize { - var index: usize = 0; - - // TODO: https://github.com/ziglang/zig/issues/8220 - state: while (true) { - switch (r.state) { - .invalid => unreachable, - .finished => unreachable, - .start => while (true) { - switch (bytes.len - index) { - 0 => return index, - 1 => { - if (bytes[index] == '\r') - r.state = .seen_r; - return index + 1; - }, - 2 => { - if (int16(bytes[index..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - } else if (bytes[index + 1] == '\r') { - r.state = .seen_r; - } - return index + 2; - }, - 3 => { - if (int16(bytes[index..][0..2]) == int16("\r\n") and - bytes[index + 2] == '\r') - { - r.state = .seen_rnr; - } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - } else if (bytes[index + 2] == '\r') { - r.state = .seen_r; - } - return index + 3; - }, - 4...15 => { - if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) { - r.state = .finished; - return index + 4; - } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and - bytes[index + 3] == '\r') - { - r.state = .seen_rnr; - index += 4; - continue :state; - } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) { - r.state = .seen_rn; - index += 4; - continue :state; - } else if (bytes[index + 3] == '\r') { - r.state = .seen_r; - index += 4; - continue :state; - } - index += 4; - continue; - }, - else => { - const chunk = bytes[index..][0..16]; - const v: @Vector(16, u8) = chunk.*; - const matches_r = v == @splat(16, @as(u8, '\r')); - const iota = std.simd.iota(u8, 16); - const default = @splat(16, @as(u8, 16)); - const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default)); - switch (sub_index) { - 0...12 => { - index += sub_index + 4; - if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) { - r.state = .finished; - return index; - } - continue; - }, - 13 => { - index += 16; - if (int16(chunk[14..][0..2]) == int16("\n\r")) { - r.state = .seen_rnr; - continue :state; - } - continue; - }, - 14 => { - index += 16; - if (chunk[15] == '\n') { - r.state = .seen_rn; - continue :state; - } - continue; - }, - 15 => { - r.state = .seen_r; - index += 16; - continue :state; - }, - 16 => { - index += 16; - continue; - }, - else => unreachable, - } - }, - } - }, - - .seen_r => switch (bytes.len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\n' => r.state = .seen_rn, - '\r' => r.state = .seen_r, - else => r.state = .start, - } - return index + 1; - }, - 2 => { - if (int16(bytes[index..][0..2]) == int16("\n\r")) { - r.state = .seen_rnr; - return index + 2; - } - r.state = .start; - return index + 2; - }, - else => { - if (int16(bytes[index..][0..2]) == int16("\n\r") and - bytes[index + 2] == '\n') - { - r.state = .finished; - return index + 3; - } - index += 3; - r.state = .start; - continue :state; - }, - }, - .seen_rn => switch (bytes.len - index) { - 0 => return index, - 1 => { - switch (bytes[index]) { - '\r' => r.state = .seen_rnr, - else => r.state = .start, - } - return index + 1; - }, - else => { - if (int16(bytes[index..][0..2]) == int16("\r\n")) { - r.state = .finished; - return index + 2; - } - index += 2; - r.state = .start; - continue :state; - }, - }, - .seen_rnr => switch (bytes.len - index) { - 0 => return index, - else => { - if (bytes[index] == '\n') { - r.state = .finished; - return index + 1; - } - index += 1; - r.state = .start; - continue :state; - }, - }, - .chunk_size_prefix_r => unreachable, - .chunk_size_prefix_n => unreachable, - .chunk_size => 
unreachable, - .chunk_r => unreachable, - .chunk_data => unreachable, - } - - return index; - } -} - -pub fn findChunkedLen(r: *Response, bytes: []const u8) usize { - var i: usize = 0; - if (r.state == .chunk_size) { - while (i < bytes.len) : (i += 1) { - const digit = switch (bytes[i]) { - '0'...'9' => |b| b - '0', - 'A'...'Z' => |b| b - 'A' + 10, - 'a'...'z' => |b| b - 'a' + 10, - '\r' => { - r.state = .chunk_r; - i += 1; - break; - }, - else => { - r.state = .invalid; - return i; - }, - }; - const mul = @mulWithOverflow(r.next_chunk_length, 16); - if (mul[1] != 0) { - r.state = .invalid; - return i; - } - const add = @addWithOverflow(mul[0], digit); - if (add[1] != 0) { - r.state = .invalid; - return i; - } - r.next_chunk_length = add[0]; - } else { - return i; - } - } - assert(r.state == .chunk_r); - if (i == bytes.len) return i; - - if (bytes[i] == '\n') { - r.state = .chunk_data; - return i + 1; - } else { - r.state = .invalid; - return i; - } -} - -fn parseInt3(nnn: @Vector(3, u8)) u10 { - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm); -} - -test parseInt3 { - const expectEqual = std.testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000".*)); - try expectEqual(@as(u10, 418), parseInt3("418".*)); - try expectEqual(@as(u10, 999), parseInt3("999".*)); -} - -test "find headers end basic" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4")); - try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18")); - try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah")); -} - -test "find headers end vectorized" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const example = - "HTTP/1.1 301 Moved Permanently\r\n" ++ - "Location: https://www.example.com/\r\n" ++ - "Content-Type: text/html; charset=UTF-8\r\n" ++ - "Content-Length: 220\r\n" ++ - "\r\ncontent"; - try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example)); -} - -test "find headers end bug" { - var buffer: [1]u8 = undefined; - var r = Response.initStatic(&buffer); - const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - const example = - "HTTP/1.1 200 OK\r\n" ++ - "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++ - "content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++ - "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++ - "Content-Type: application/x-gzip\r\n" ++ - "ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++ - "Strict-Transport-Security: max-age=31536000\r\n" ++ - "Vary: Authorization,Accept-Encoding,Origin\r\n" ++ - "X-Content-Type-Options: nosniff\r\n" ++ - "X-Frame-Options: deny\r\n" ++ - "X-XSS-Protection: 1; mode=block\r\n" ++ - "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++ - "connection: close\r\n\r\n" ++ trail; - try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example)); -} diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig new file mode 100644 index 0000000000..85fbf25265 --- /dev/null +++ b/lib/std/http/Server.zig @@ -0,0 +1,600 @@ +const std = @import("../std.zig"); +const testing = std.testing; +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = 
std.Uri; +const Allocator = mem.Allocator; +const assert = std.debug.assert; + +const Server = @This(); +const proto = @import("protocol.zig"); + +allocator: Allocator, + +socket: net.StreamServer, + +/// An interface to either a plain or TLS connection. +pub const Connection = struct { + stream: net.Stream, + protocol: Protocol, + + closing: bool = true, + + pub const Protocol = enum { plain }; + + pub fn read(conn: *Connection, buffer: []u8) !usize { + switch (conn.protocol) { + .plain => return conn.stream.read(buffer), + // .tls => return conn.tls_client.read(conn.stream, buffer), + } + } + + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize { + switch (conn.protocol) { + .plain => return conn.stream.readAtLeast(buffer, len), + // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), + } + } + + pub const ReadError = net.Stream.ReadError; + + pub const Reader = std.io.Reader(*Connection, ReadError, read); + + pub fn reader(conn: *Connection) Reader { + return Reader{ .context = conn }; + } + + pub fn writeAll(conn: *Connection, buffer: []const u8) !void { + switch (conn.protocol) { + .plain => return conn.stream.writeAll(buffer), + // .tls => return conn.tls_client.writeAll(conn.stream, buffer), + } + } + + pub fn write(conn: *Connection, buffer: []const u8) !usize { + switch (conn.protocol) { + .plain => return conn.stream.write(buffer), + // .tls => return conn.tls_client.write(conn.stream, buffer), + } + } + + pub const WriteError = net.Stream.WriteError || error{}; + pub const Writer = std.io.Writer(*Connection, WriteError, write); + + pub fn writer(conn: *Connection) Writer { + return Writer{ .context = conn }; + } + + pub fn close(conn: *Connection) void { + conn.stream.close(); + } +}; + +/// A buffered (and peekable) Connection. 
+pub const BufferedConnection = struct { + pub const buffer_size = 0x2000; + + conn: Connection, + buf: [buffer_size]u8 = undefined, + start: u16 = 0, + end: u16 = 0, + + pub fn fill(bconn: *BufferedConnection) ReadError!void { + if (bconn.end != bconn.start) return; + + const nread = try bconn.conn.read(bconn.buf[0..]); + if (nread == 0) return error.EndOfStream; + bconn.start = 0; + bconn.end = @truncate(u16, nread); + } + + pub fn peek(bconn: *BufferedConnection) []const u8 { + return bconn.buf[bconn.start..bconn.end]; + } + + pub fn clear(bconn: *BufferedConnection, num: u16) void { + bconn.start += num; + } + + pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { + var out_index: u16 = 0; + while (out_index < len) { + const available = bconn.end - bconn.start; + const left = buffer.len - out_index; + + if (available > 0) { + const can_read = @truncate(u16, @min(available, left)); + + std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]); + out_index += can_read; + bconn.start += can_read; + + continue; + } + + if (left > bconn.buf.len) { + // skip the buffer if the output is large enough + return bconn.conn.read(buffer[out_index..]); + } + + try bconn.fill(); + } + + return out_index; + } + + pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize { + return bconn.readAtLeast(buffer, 1); + } + + pub const ReadError = Connection.ReadError || error{EndOfStream}; + pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read); + + pub fn reader(bconn: *BufferedConnection) Reader { + return Reader{ .context = bconn }; + } + + pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { + return bconn.conn.writeAll(buffer); + } + + pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { + return bconn.conn.write(buffer); + } + + pub const WriteError = Connection.WriteError; + pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write); + + pub fn writer(bconn: *BufferedConnection) Writer { + return Writer{ .context = bconn }; + } + + pub fn close(bconn: *BufferedConnection) void { + bconn.conn.close(); + } +}; + +/// A HTTP request originating from a client. +pub const Request = struct { + pub const Headers = struct { + method: http.Method, + target: []const u8, + version: http.Version, + content_length: ?u64 = null, + transfer_encoding: ?http.TransferEncoding = null, + transfer_compression: ?http.ContentEncoding = null, + connection: http.Connection = .close, + host: ?[]const u8 = null, + + pub const ParseError = error{ + ShortHttpStatusLine, + BadHttpVersion, + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidCharacter, + }; + + pub fn parse(bytes: []const u8) !Headers { + var it = mem.tokenize(u8, bytes[0 .. 
bytes.len - 4], "\r\n"); + + const first_line = it.next() orelse return error.HttpHeadersInvalid; + if (first_line.len < 10) + return error.ShortHttpStatusLine; + + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; + const method_str = first_line[0..method_end]; + const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod; + + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; + + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.BadHttpVersion, + }; + + const target = first_line[method_end + 1 .. version_start]; + + var headers: Headers = .{ + .method = method, + .target = target, + .version = version, + }; + + while (it.next()) |line| { + if (line.len == 0) return error.HttpHeadersInvalid; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } + + var line_it = mem.tokenize(u8, line, ": "); + const header_name = line_it.next() orelse return error.HttpHeadersInvalid; + const header_value = line_it.rest(); + if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (headers.content_length != null) return error.HttpHeadersInvalid; + headers.content_length = try std.fmt.parseInt(u64, header_value, 10); + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwards(u8, header_value, ","); + + if (iter.next()) |first| { + const trimmed = mem.trim(u8, first, " "); + + if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { + if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; + headers.transfer_encoding = te; + } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |second| { + if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + + const trimmed = mem.trim(u8, second, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } + + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + headers.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { + headers.connection = .keep_alive; + } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { + headers.connection = .close; + } else { + return error.HttpConnectionHeaderUnsupported; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "host")) { + headers.host = header_value; + } + } + + return 
headers; + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(u64, array.*); + } + }; + + headers: Headers = undefined, + parser: proto.HeadersParser, + compression: Compression = .none, +}; + +/// A HTTP response waiting to be sent. +/// +/// [/ <----------------------------------- \] +/// Order of operations: accept -> wait -> do [ -> write -> finish][ -> reset /] +/// \ -> read / +pub const Response = struct { + pub const Headers = struct { + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, + + server: ?[]const u8 = "zig (std.http)", + connection: http.Connection = .keep_alive, + transfer_encoding: RequestTransfer = .none, + + custom: []const http.CustomHeader = &[_]http.CustomHeader{}, + }; + + server: *Server, + address: net.Address, + connection: BufferedConnection, + + headers: Headers = .{}, + request: Request, + + /// Reset this response to its initial state. This must be called before handling a second request on the same connection. + pub fn reset(res: *Response) void { + switch (res.request.compression) { + .none => {}, + .deflate => |*deflate| deflate.deinit(), + .gzip => |*gzip| gzip.deinit(), + .zstd => |*zstd| zstd.deinit(), + } + + if (!res.request.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + res.connection.conn.closing = true; + } + + if (res.connection.conn.closing) { + res.connection.close(); + + if (res.request.parser.header_bytes_owned) { + res.request.parser.header_bytes.deinit(res.server.allocator); + } + + res.* = undefined; + } else { + res.request.parser.reset(); + } + } + + /// Send the response headers. + pub fn do(res: *Response) !void { + var buffered = std.io.bufferedWriter(res.connection.writer()); + const w = buffered.writer(); + + try w.writeAll(@tagName(res.headers.version)); + try w.writeByte(' '); + try w.print("{d}", .{@enumToInt(res.headers.status)}); + try w.writeByte(' '); + if (res.headers.reason) |reason| { + try w.writeAll(reason); + } else if (res.headers.status.phrase()) |phrase| { + try w.writeAll(phrase); + } + + if (res.headers.server) |server| { + try w.writeAll("\r\nServer: "); + try w.writeAll(server); + } + + if (res.headers.connection == .close) { + try w.writeAll("\r\nConnection: close"); + } else { + try w.writeAll("\r\nConnection: keep-alive"); + } + + switch (res.headers.transfer_encoding) { + .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"), + .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}), + .none => {}, + } + + for (res.headers.custom) |header| { + try w.writeAll("\r\n"); + try w.writeAll(header.name); + try w.writeAll(": "); + try w.writeAll(header.value); + } + + try w.writeAll("\r\n\r\n"); + + try buffered.flush(); + } + + pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; + + pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); + + pub fn transferReader(res: *Response) TransferReader { + return .{ .context = res }; + } + + pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize { + if (res.request.parser.isComplete()) return 0; + + var index: usize = 0; + while (index == 0) { + const amt = try res.request.parser.read(&res.connection, buf[index..], false); + if (amt == 0 and res.request.parser.isComplete()) break; + index += amt; + } + + return index; + } + + pub const WaitForCompleteHeadError = BufferedConnection.ReadError || 
proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported}; + + /// Wait for the client to send a complete request head. + pub fn wait(res: *Response) !void { + while (true) { + try res.connection.fill(); + + const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + res.connection.clear(@intCast(u16, nchecked)); + + if (res.request.parser.state.isContent()) break; + } + + res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items); + + if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) { + res.connection.conn.closing = false; + } else { + res.connection.conn.closing = true; + } + + if (res.request.headers.transfer_encoding) |te| { + switch (te) { + .chunked => { + res.request.parser.next_chunk_length = 0; + res.request.parser.state = .chunk_head_size; + }, + } + } else if (res.request.headers.content_length) |cl| { + res.request.parser.next_chunk_length = cl; + + if (cl == 0) res.request.parser.done = true; + } else { + res.request.parser.done = true; + } + + if (!res.request.parser.done) { + if (res.request.headers.transfer_compression) |tc| switch (tc) { + .compress => return error.CompressionNotSupported, + .deflate => res.request.compression = .{ + .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()), + }, + .gzip => res.request.compression = .{ + .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()), + }, + .zstd => res.request.compression = .{ + .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()), + }, + }; + } + } + + pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError; + + pub const Reader = std.io.Reader(*Response, ReadError, read); + + pub fn reader(res: *Response) Reader { + return .{ .context = res }; + } + + pub fn read(res: *Response, buffer: []u8) ReadError!usize { + return switch (res.request.compression) { + .deflate => |*deflate| try deflate.read(buffer), + .gzip => |*gzip| try gzip.read(buffer), + .zstd => |*zstd| try zstd.read(buffer), + else => try res.transferRead(buffer), + }; + } + + pub fn readAll(res: *Response, buffer: []u8) !usize { + var index: usize = 0; + while (index < buffer.len) { + const amt = try read(res, buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong }; + + pub const Writer = std.io.Writer(*Response, WriteError, write); + + pub fn writer(res: *Response) Writer { + return .{ .context = res }; + } + + /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. + pub fn write(res: *Response, bytes: []const u8) WriteError!usize { + switch (res.headers.transfer_encoding) { + .chunked => { + try res.connection.writer().print("{x}\r\n", .{bytes.len}); + try res.connection.writeAll(bytes); + try res.connection.writeAll("\r\n"); + + return bytes.len; + }, + .content_length => |*len| { + if (len.* < bytes.len) return error.MessageTooLong; + + const amt = try res.connection.write(bytes); + len.* -= amt; + return amt; + }, + .none => return error.NotWriteable, + } + } + + /// Finish the body of a request. 
This notifies the server that you have no more data to send. + pub fn finish(res: *Response) !void { + switch (res.headers.transfer_encoding) { + .chunked => try res.connection.writeAll("0\r\n"), + .content_length => |len| if (len != 0) return error.MessageNotCompleted, + .none => {}, + } + } +}; + +/// The mode of transport for responses. +pub const RequestTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +/// The decompressor for request messages. +pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader); + pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{}); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, +}; + +pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server { + return .{ + .allocator = allocator, + .socket = net.StreamServer.init(options), + }; +} + +pub fn deinit(server: *Server) void { + server.socket.deinit(); +} + +pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError; + +/// Start the HTTP server listening on the given address. +pub fn listen(server: *Server, address: net.Address) !void { + try server.socket.listen(address); +} + +pub const AcceptError = net.StreamServer.AcceptError || Allocator.Error; + +pub const HeaderStrategy = union(enum) { + /// In this case, the client's Allocator will be used to store the + /// entire HTTP header. This value is the maximum total size of + /// HTTP headers allowed, otherwise + /// error.HttpHeadersExceededSizeLimit is returned from read(). + dynamic: usize, + /// This is used to store the entire HTTP header. If the HTTP + /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` + /// is returned from read(). When this is used, `error.OutOfMemory` + /// cannot be returned from `read()`. + static: []u8, +}; + +/// Accept a new connection and allocate a Response for it. +pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response { + const in = try server.socket.accept(); + + const res = try server.allocator.create(Response); + res.* = .{ + .server = server, + .address = in.address, + .connection = .{ .conn = .{ + .stream = in.stream, + .protocol = .plain, + } }, + .request = .{ + .parser = switch (options) { + .dynamic => |max| proto.HeadersParser.initDynamic(max), + .static => |buf| proto.HeadersParser.initStatic(buf), + }, + }, + }; + + return res; +} diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig new file mode 100644 index 0000000000..5e63d3092b --- /dev/null +++ b/lib/std/http/protocol.zig @@ -0,0 +1,842 @@ +const std = @import("std"); +const testing = std.testing; +const mem = std.mem; + +const assert = std.debug.assert; + +pub const State = enum { + /// Begin header parsing states. + invalid, + start, + seen_n, + seen_r, + seen_rn, + seen_rnr, + finished, + /// Begin transfer-encoding: chunked parsing states. + chunk_head_size, + chunk_head_ext, + chunk_head_r, + chunk_data, + chunk_data_suffix, + chunk_data_suffix_r, + + /// Returns true if the parser is in a content state (ie. not waiting for more headers). 
+ pub fn isContent(self: State) bool { + return switch (self) { + .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, + .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, + }; + } +}; + +pub const HeadersParser = struct { + state: State = .start, + /// Whether or not `header_bytes` is allocated or was provided as a fixed buffer. + header_bytes_owned: bool, + /// Either a fixed buffer of len `max_header_bytes` or a dynamic buffer that can grow up to `max_header_bytes`. + /// Pointers into this buffer are not stable until after a message is complete. + header_bytes: std.ArrayListUnmanaged(u8), + /// The maximum allowed size of `header_bytes`. + max_header_bytes: usize, + next_chunk_length: u64 = 0, + /// Whether this parser is done parsing a complete message. + /// A message is only done when the entire payload has been read. + done: bool = false, + + /// Initializes the parser with a dynamically growing header buffer of up to `max` bytes. + pub fn initDynamic(max: usize) HeadersParser { + return .{ + .header_bytes = .{}, + .max_header_bytes = max, + .header_bytes_owned = true, + }; + } + + /// Initializes the parser with a provided buffer `buf`. + pub fn initStatic(buf: []u8) HeadersParser { + return .{ + .header_bytes = .{ .items = buf[0..0], .capacity = buf.len }, + .max_header_bytes = buf.len, + .header_bytes_owned = false, + }; + } + + /// Completely resets the parser to it's initial state. + /// This must be called after a message is complete. + pub fn reset(r: *HeadersParser) void { + assert(r.done); // The message must be completely read before reset, otherwise the parser is in an invalid state. + + r.header_bytes.clearRetainingCapacity(); + + r.* = .{ + .header_bytes = r.header_bytes, + .max_header_bytes = r.max_header_bytes, + .header_bytes_owned = r.header_bytes_owned, + }; + } + + /// Returns the number of bytes consumed by headers. This is always less than or equal to `bytes.len`. + /// You should check `r.state.isContent()` after this to check if the headers are done. + /// + /// If the amount returned is less than `bytes.len`, you may assume that the parser is in a content state and the + /// first byte of content is located at `bytes[result]`. 
+ pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { + const vector_len: comptime_int = comptime std.math.max(std.simd.suggestVectorSize(u8) orelse 1, 8); + const len = @intCast(u32, bytes.len); + var index: u32 = 0; + + while (true) { + switch (r.state) { + .invalid => unreachable, + .finished => return index, + .start => switch (len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + return index + 1; + }, + 2 => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + return index + 2; + }, + 3 => { + const b24 = int24(bytes[index..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => r.state = .seen_rnr, + else => {}, + } + + return index + 3; + }, + 4...vector_len - 1 => { + const b32 = int32(bytes[index..][0..4]); + const b24 = intShift(u24, b32); + const b16 = intShift(u16, b32); + const b8 = intShift(u8, b32); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => r.state = .seen_rnr, + else => {}, + } + + switch (b32) { + int32("\r\n\r\n") => r.state = .finished, + else => {}, + } + + index += 4; + continue; + }, + else => { + const Vector = @Vector(vector_len, u8); + // const BoolVector = @Vector(vector_len, bool); + const BitVector = @Vector(vector_len, u1); + const SizeVector = @Vector(vector_len, u8); + + const chunk = bytes[index..][0..vector_len]; + const v: Vector = chunk.*; + const matches_r = @bitCast(BitVector, v == @splat(vector_len, @as(u8, '\r'))); + const matches_n = @bitCast(BitVector, v == @splat(vector_len, @as(u8, '\n'))); + const matches_or: SizeVector = matches_r | matches_n; + + const matches = @reduce(.Add, matches_or); + switch (matches) { + 0 => {}, + 1 => switch (chunk[vector_len - 1]) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + }, + 2 => { + const b16 = int16(chunk[vector_len - 2 ..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + }, + 3 => { + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => r.state = .seen_rnr, + else => {}, + } + }, + 4...vector_len => { + inline for (0..vector_len - 3) |i_usize| { + const i = @truncate(u32, i_usize); + + const b32 = int32(chunk[i..][0..4]); + const b16 = intShift(u16, b32); + + if (b32 == int32("\r\n\r\n")) { + r.state = .finished; + return index + i + 4; + } else if (b16 == 
int16("\n\n")) { + r.state = .finished; + return index + i + 2; + } + } + + const b24 = int24(chunk[vector_len - 3 ..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => {}, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + switch (b24) { + int24("\r\n\r") => r.state = .seen_rnr, + else => {}, + } + }, + else => unreachable, + } + + index += vector_len; + continue; + }, + }, + .seen_n => switch (len - index) { + 0 => return index, + else => { + switch (bytes[index]) { + '\n' => r.state = .finished, + else => r.state = .start, + } + + index += 1; + continue; + }, + }, + .seen_r => switch (len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\n' => r.state = .seen_rn, + '\r' => r.state = .seen_r, + else => r.state = .start, + } + + return index + 1; + }, + 2 => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_rn, + else => r.state = .start, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\r") => r.state = .seen_rnr, + int16("\n\n") => r.state = .finished, + else => {}, + } + + return index + 2; + }, + else => { + const b24 = int24(bytes[index..][0..3]); + const b16 = intShift(u16, b24); + const b8 = intShift(u8, b24); + + switch (b8) { + '\r' => r.state = .seen_r, + '\n' => r.state = .seen_n, + else => r.state = .start, + } + + switch (b16) { + int16("\r\n") => r.state = .seen_rn, + int16("\n\n") => r.state = .finished, + else => {}, + } + + switch (b24) { + int24("\n\r\n") => r.state = .finished, + else => {}, + } + + index += 3; + continue; + }, + }, + .seen_rn => switch (len - index) { + 0 => return index, + 1 => { + switch (bytes[index]) { + '\r' => r.state = .seen_rnr, + '\n' => r.state = .seen_n, + else => r.state = .start, + } + + return index + 1; + }, + else => { + const b16 = int16(bytes[index..][0..2]); + const b8 = intShift(u8, b16); + + switch (b8) { + '\r' => r.state = .seen_rnr, + '\n' => r.state = .seen_n, + else => r.state = .start, + } + + switch (b16) { + int16("\r\n") => r.state = .finished, + int16("\n\n") => r.state = .finished, + else => {}, + } + + index += 2; + continue; + }, + }, + .seen_rnr => switch (len - index) { + 0 => return index, + else => { + switch (bytes[index]) { + '\n' => r.state = .finished, + else => r.state = .start, + } + + index += 1; + continue; + }, + }, + .chunk_head_size => unreachable, + .chunk_head_ext => unreachable, + .chunk_head_r => unreachable, + .chunk_data => unreachable, + .chunk_data_suffix => unreachable, + .chunk_data_suffix_r => unreachable, + } + + return index; + } + } + + /// Returns the number of bytes consumed by the chunk size. This is always less than or equal to `bytes.len`. + /// You should check `r.state == .chunk_data` after this to check if the chunk size has been fully parsed. + /// + /// If the amount returned is less than `bytes.len`, you may assume that the parser is in the `chunk_data` state + /// and that the first byte of the chunk is at `bytes[result]`. + pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { + const len = @intCast(u32, bytes.len); + + for (bytes[0..], 0..) 
|c, i| { + const index = @intCast(u32, i); + switch (r.state) { + .chunk_data_suffix => switch (c) { + '\r' => r.state = .chunk_data_suffix_r, + '\n' => r.state = .chunk_head_size, + else => { + r.state = .invalid; + return index; + }, + }, + .chunk_data_suffix_r => switch (c) { + '\n' => r.state = .chunk_head_size, + else => { + r.state = .invalid; + return index; + }, + }, + .chunk_head_size => { + const digit = switch (c) { + '0'...'9' => |b| b - '0', + 'A'...'Z' => |b| b - 'A' + 10, + 'a'...'z' => |b| b - 'a' + 10, + '\r' => { + r.state = .chunk_head_r; + continue; + }, + '\n' => { + r.state = .chunk_data; + return index + 1; + }, + else => { + r.state = .chunk_head_ext; + continue; + }, + }; + + const new_len = r.next_chunk_length *% 16 +% digit; + if (new_len <= r.next_chunk_length and r.next_chunk_length != 0) { + r.state = .invalid; + return index; + } + + r.next_chunk_length = new_len; + }, + .chunk_head_ext => switch (c) { + '\r' => r.state = .chunk_head_r, + '\n' => { + r.state = .chunk_data; + return index + 1; + }, + else => continue, + }, + .chunk_head_r => switch (c) { + '\n' => { + r.state = .chunk_data; + return index + 1; + }, + else => { + r.state = .invalid; + return index; + }, + }, + else => unreachable, + } + } + + return len; + } + + /// Returns whether or not the parser has finished parsing a complete message. A message is only complete after the + /// entire body has been read and any trailing headers have been parsed. + pub fn isComplete(r: *HeadersParser) bool { + return r.done and r.state == .finished; + } + + pub const CheckCompleteHeadError = mem.Allocator.Error || error{HttpHeadersExceededSizeLimit}; + + /// Pushes `in` into the parser. Returns the number of bytes consumed by the header. Any header bytes are appended + /// to the `header_bytes` buffer. + /// + /// This function only uses `allocator` if `r.header_bytes_owned` is true, and may be undefined otherwise. + pub fn checkCompleteHead(r: *HeadersParser, allocator: std.mem.Allocator, in: []const u8) CheckCompleteHeadError!u32 { + if (r.state.isContent()) return 0; + + const i = r.findHeadersEnd(in); + const data = in[0..i]; + if (r.header_bytes.items.len + data.len > r.max_header_bytes) { + return error.HttpHeadersExceededSizeLimit; + } else { + if (r.header_bytes_owned) try r.header_bytes.ensureUnusedCapacity(allocator, data.len); + + r.header_bytes.appendSliceAssumeCapacity(data); + } + + return i; + } + + pub const ReadError = error{ + HttpChunkInvalid, + }; + + /// Reads the body of the message into `buffer`. Returns the number of bytes placed in the buffer. + /// + /// If `skip` is true, the buffer will be unused and the body will be skipped. + /// + /// See `std.http.Client.BufferedConnection for an example of `bconn`. 
+ pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize { + assert(r.state.isContent()); + if (r.done) return 0; + + var out_index: usize = 0; + while (true) { + switch (r.state) { + .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, + .finished => { + const data_avail = r.next_chunk_length; + + if (skip) { + try bconn.fill(); + + const nread = @min(bconn.peek().len, data_avail); + bconn.clear(@intCast(u16, nread)); + r.next_chunk_length -= nread; + + if (r.next_chunk_length == 0) r.done = true; + + return 0; + } else { + const out_avail = buffer.len; + + const can_read = @intCast(usize, @min(data_avail, out_avail)); + const nread = try bconn.read(buffer[0..can_read]); + r.next_chunk_length -= nread; + + if (r.next_chunk_length == 0) r.done = true; + + return nread; + } + }, + .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { + try bconn.fill(); + + const i = r.findChunkedLen(bconn.peek()); + bconn.clear(@intCast(u16, i)); + + switch (r.state) { + .invalid => return error.HttpChunkInvalid, + .chunk_data => if (r.next_chunk_length == 0) { + // The trailer section is formatted identically to the header section. + r.state = .seen_rn; + r.done = true; + + return out_index; + }, + else => return out_index, + } + + continue; + }, + .chunk_data => { + const data_avail = r.next_chunk_length; + const out_avail = buffer.len - out_index; + + if (skip) { + try bconn.fill(); + + const nread = @min(bconn.peek().len, data_avail); + bconn.clear(@intCast(u16, nread)); + r.next_chunk_length -= nread; + } else { + const can_read = @intCast(usize, @min(data_avail, out_avail)); + const nread = try bconn.read(buffer[out_index..][0..can_read]); + r.next_chunk_length -= nread; + out_index += nread; + } + + if (r.next_chunk_length == 0) { + r.state = .chunk_data_suffix; + continue; + } + + return out_index; + }, + } + } + } +}; + +inline fn int16(array: *const [2]u8) u16 { + return @bitCast(u16, array.*); +} + +inline fn int24(array: *const [3]u8) u24 { + return @bitCast(u24, array.*); +} + +inline fn int32(array: *const [4]u8) u32 { + return @bitCast(u32, array.*); +} + +inline fn intShift(comptime T: type, x: anytype) T { + switch (@import("builtin").cpu.arch.endian()) { + .Little => return @truncate(T, x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))), + .Big => return @truncate(T, x), + } +} + +/// A buffered (and peekable) Connection. 
+const MockBufferedConnection = struct { + pub const buffer_size = 0x2000; + + conn: std.io.FixedBufferStream([]const u8), + buf: [buffer_size]u8 = undefined, + start: u16 = 0, + end: u16 = 0, + + pub fn fill(bconn: *MockBufferedConnection) ReadError!void { + if (bconn.end != bconn.start) return; + + const nread = try bconn.conn.read(bconn.buf[0..]); + if (nread == 0) return error.EndOfStream; + bconn.start = 0; + bconn.end = @truncate(u16, nread); + } + + pub fn peek(bconn: *MockBufferedConnection) []const u8 { + return bconn.buf[bconn.start..bconn.end]; + } + + pub fn clear(bconn: *MockBufferedConnection, num: u16) void { + bconn.start += num; + } + + pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { + var out_index: u16 = 0; + while (out_index < len) { + const available = bconn.end - bconn.start; + const left = buffer.len - out_index; + + if (available > 0) { + const can_read = @truncate(u16, @min(available, left)); + + std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]); + out_index += can_read; + bconn.start += can_read; + + continue; + } + + if (left > bconn.buf.len) { + // skip the buffer if the output is large enough + return bconn.conn.read(buffer[out_index..]); + } + + try bconn.fill(); + } + + return out_index; + } + + pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize { + return bconn.readAtLeast(buffer, 1); + } + + pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; + pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read); + + pub fn reader(bconn: *MockBufferedConnection) Reader { + return Reader{ .context = bconn }; + } + + pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void { + return bconn.conn.writeAll(buffer); + } + + pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { + return bconn.conn.write(buffer); + } + + pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; + pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write); + + pub fn writer(bconn: *MockBufferedConnection) Writer { + return Writer{ .context = bconn }; + } +}; + +test "HeadersParser.findHeadersEnd" { + var r: HeadersParser = undefined; + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello"; + + for (0..36) |i| { + r = HeadersParser.initDynamic(0); + try std.testing.expectEqual(@intCast(u32, i), r.findHeadersEnd(data[0..i])); + try std.testing.expectEqual(@intCast(u32, 35 - i), r.findHeadersEnd(data[i..])); + } +} + +test "HeadersParser.findChunkedLen" { + var r: HeadersParser = undefined; + const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n"; + + r = HeadersParser.initDynamic(0); + r.state = .chunk_head_size; + r.next_chunk_length = 0; + + const first = r.findChunkedLen(data[0..]); + try testing.expectEqual(@as(u32, 4), first); + try testing.expectEqual(@as(u64, 0xff), r.next_chunk_length); + try testing.expectEqual(State.chunk_data, r.state); + r.state = .chunk_head_size; + r.next_chunk_length = 0; + + const second = r.findChunkedLen(data[first..]); + try testing.expectEqual(@as(u32, 13), second); + try testing.expectEqual(@as(u64, 0xf0f000), r.next_chunk_length); + try testing.expectEqual(State.chunk_data, r.state); + r.state = .chunk_head_size; + r.next_chunk_length = 0; + + const third = r.findChunkedLen(data[first + second ..]); + try testing.expectEqual(@as(u32, 3), third); + try testing.expectEqual(@as(u64, 
0), r.next_chunk_length); + try testing.expectEqual(State.chunk_data, r.state); + r.state = .chunk_head_size; + r.next_chunk_length = 0; + + const fourth = r.findChunkedLen(data[first + second + third ..]); + try testing.expectEqual(@as(u32, 16), fourth); + try testing.expectEqual(@as(u64, 0xffffffffffffffff), r.next_chunk_length); + try testing.expectEqual(State.invalid, r.state); +} + +test "HeadersParser.read length" { + // mock BufferedConnection for read + + var r = HeadersParser.initDynamic(256); + defer r.header_bytes.deinit(std.testing.allocator); + const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; + var fbs = std.io.fixedBufferStream(data); + + var bconn = MockBufferedConnection{ + .conn = fbs, + }; + + while (true) { // read headers + try bconn.fill(); + + const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); + bconn.clear(@intCast(u16, nchecked)); + + if (r.state.isContent()) break; + } + + var buf: [8]u8 = undefined; + + r.next_chunk_length = 5; + const len = try r.read(&bconn, &buf, false); + try std.testing.expectEqual(@as(usize, 5), len); + try std.testing.expectEqualStrings("Hello", buf[0..len]); + + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.header_bytes.items); +} + +test "HeadersParser.read chunked" { + // mock BufferedConnection for read + + var r = HeadersParser.initDynamic(256); + defer r.header_bytes.deinit(std.testing.allocator); + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; + var fbs = std.io.fixedBufferStream(data); + + var bconn = MockBufferedConnection{ + .conn = fbs, + }; + + while (true) { // read headers + try bconn.fill(); + + const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); + bconn.clear(@intCast(u16, nchecked)); + + if (r.state.isContent()) break; + } + var buf: [8]u8 = undefined; + + r.state = .chunk_head_size; + const len = try r.read(&bconn, &buf, false); + try std.testing.expectEqual(@as(usize, 5), len); + try std.testing.expectEqualStrings("Hello", buf[0..len]); + + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.header_bytes.items); +} + +test "HeadersParser.read chunked trailer" { + // mock BufferedConnection for read + + var r = HeadersParser.initDynamic(256); + defer r.header_bytes.deinit(std.testing.allocator); + const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; + var fbs = std.io.fixedBufferStream(data); + + var bconn = MockBufferedConnection{ + .conn = fbs, + }; + + while (true) { // read headers + try bconn.fill(); + + const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); + bconn.clear(@intCast(u16, nchecked)); + + if (r.state.isContent()) break; + } + var buf: [8]u8 = undefined; + + r.state = .chunk_head_size; + const len = try r.read(&bconn, &buf, false); + try std.testing.expectEqual(@as(usize, 5), len); + try std.testing.expectEqualStrings("Hello", buf[0..len]); + + while (true) { // read headers + try bconn.fill(); + + const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek()); + bconn.clear(@intCast(u16, nchecked)); + + if (r.state.isContent()) break; + } + + try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.header_bytes.items); +} |
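
The doc comment on `Response` above spells out the intended order of operations (accept -> wait -> do [-> write -> finish][-> reset]). Below is a minimal, single-threaded sketch of that lifecycle; the address, port, 8 KiB dynamic header limit, the one-request-per-connection policy, and the explicit `destroy` of the accepted `Response` are illustrative assumptions, not part of the patch itself.

```zig
const std = @import("std");
const net = std.net;

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var server = std.http.Server.init(allocator, .{ .reuse_address = true });
    defer server.deinit();

    try server.listen(try net.Address.parseIp("127.0.0.1", 8080));

    while (true) {
        // Allow up to 8 KiB of request headers, allocated dynamically;
        // `.{ .static = &buf }` would use a caller-provided buffer instead.
        const res = try server.accept(.{ .dynamic = 8192 });
        defer allocator.destroy(res); // assumption: the caller owns the accepted Response

        // Handle exactly one request per connection, for simplicity.
        res.headers.connection = .close;

        try res.wait(); // read and parse the request head

        // Read (and discard) any request body; content-length, chunked framing
        // and transfer compression were already set up by wait().
        const req_body = try res.reader().readAllAlloc(allocator, 1 << 20);
        allocator.free(req_body);

        const body = "Hello from std.http.Server\n";
        res.headers.transfer_encoding = .{ .content_length = body.len };

        try res.do(); // send the status line and headers
        try res.writer().writeAll(body);
        try res.finish();

        res.reset(); // closes the connection because of `connection: close`
    }
}
```

A real server would catch per-connection errors instead of letting `try` propagate them out of the accept loop.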
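
When the body length is not known up front, the `.chunked` arm of `Response.write` shown above can be used instead of a Content-Length. A short sketch under the same assumptions; the helper name `sendChunked` and the three-part body are illustrative:

```zig
const std = @import("std");

/// Streams a response of unknown length using chunked framing.
/// Assumes `res` came from accept() and that wait() has already been called.
fn sendChunked(res: *std.http.Server.Response) !void {
    res.headers.transfer_encoding = .chunked;
    try res.do();

    var i: usize = 0;
    while (i < 3) : (i += 1) {
        // Every underlying write() becomes one hex-length-prefixed chunk.
        try res.writer().print("part {d}\n", .{i});
    }

    try res.finish(); // emits the terminating zero-length chunk
}
```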
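
The new server-side `Request.Headers.parse` can also be exercised on its own. A hedged test sketch with an illustrative request head; note that `parse` expects the terminating `\r\n\r\n` to still be present, since it tokenizes `bytes[0 .. bytes.len - 4]`:

```zig
const std = @import("std");

test "Server.Request.Headers.parse sketch" {
    const head =
        "POST /upload HTTP/1.1\r\n" ++
        "Host: localhost\r\n" ++
        "Content-Length: 13\r\n\r\n";

    const parsed = try std.http.Server.Request.Headers.parse(head);

    try std.testing.expectEqual(std.http.Method.POST, parsed.method);
    try std.testing.expectEqualStrings("/upload", parsed.target);
    try std.testing.expectEqual(std.http.Version.@"HTTP/1.1", parsed.version);
    try std.testing.expectEqual(@as(?u64, 13), parsed.content_length);
    try std.testing.expectEqualStrings("localhost", parsed.host orelse return error.TestFailed);
}
```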
