author     Andrew Kelley <andrew@ziglang.org>       2023-04-09 10:44:52 -0400
committer  GitHub <noreply@github.com>              2023-04-09 10:44:52 -0400
commit     2ee328995a70c5c446f24c5593e0fad760e6d839 (patch)
tree       0e547171b7790ffd182fc298d384ef614571e97e /lib/std
parent     c22a30ac99b9a2b92d9a8e926b9bf0c9dbc3d14e (diff)
parent     7f9a4625fda0b1a33177cdd66819f0a061c6b2da (diff)
Merge pull request #15123 from truemedian/http-server
std.http: add http server
Diffstat (limited to 'lib/std')
-rw-r--r--  lib/std/http.zig                    2
-rw-r--r--  lib/std/http/Client.zig           880
-rw-r--r--  lib/std/http/Client/Request.zig   482
-rw-r--r--  lib/std/http/Client/Response.zig  509
-rw-r--r--  lib/std/http/Server.zig           600
-rw-r--r--  lib/std/http/protocol.zig         842
6 files changed, 2203 insertions, 1112 deletions
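Before the file-by-file diff, a minimal usage sketch (not part of the commit) of the reworked `std.http.Client` API this merge introduces, following the order of operations documented below (`request[ -> write -> finish] -> do -> read`); it assumes a Zig toolchain contemporary with this commit:

```zig
const std = @import("std");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    var client = std.http.Client{ .allocator = gpa.allocator() };
    defer client.deinit();

    const uri = try std.Uri.parse("http://example.com/");

    // `request` connects (reusing the pool when possible) and writes the
    // request head; all-default headers and options are used here.
    var req = try client.request(uri, .{}, .{});
    defer req.deinit();

    // Wait for and parse the response head; redirects are followed by default.
    try req.do();

    var buf: [4096]u8 = undefined;
    const n = try req.readAll(&buf);
    std.debug.print("{s}", .{buf[0..n]});
}
```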
diff --git a/lib/std/http.zig b/lib/std/http.zig
index ef89f09925..6e5f4e0cd9 100644
--- a/lib/std/http.zig
+++ b/lib/std/http.zig
@@ -1,4 +1,6 @@
pub const Client = @import("http/Client.zig");
+pub const Server = @import("http/Server.zig");
+pub const protocol = @import("http/protocol.zig");
pub const Version = enum {
@"HTTP/1.0",
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index 76073c0ce3..010c557f87 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -1,49 +1,105 @@
-//! TODO: send connection: keep-alive and LRU cache a configurable number of
-//! open connections to skip DNS and TLS handshake for subsequent requests.
-//!
-//! This API is *not* thread safe.
+//! Connecting and opening requests are threadsafe. Individual requests are not.
const std = @import("../std.zig");
-const mem = std.mem;
-const assert = std.debug.assert;
+const testing = std.testing;
const http = std.http;
+const mem = std.mem;
const net = std.net;
-const Client = @This();
const Uri = std.Uri;
-const Allocator = std.mem.Allocator;
-const testing = std.testing;
+const Allocator = mem.Allocator;
+const assert = std.debug.assert;
-pub const Request = @import("Client/Request.zig");
-pub const Response = @import("Client/Response.zig");
+const Client = @This();
+const proto = @import("protocol.zig");
pub const default_connection_pool_size = 32;
-const connection_pool_size = std.options.http_connection_pool_size;
+pub const connection_pool_size = std.options.http_connection_pool_size;
-/// Used for tcpConnectToHost and storing HTTP headers when an externally
-/// managed buffer is not provided.
allocator: Allocator,
ca_bundle: std.crypto.Certificate.Bundle = .{},
+ca_bundle_mutex: std.Thread.Mutex = .{},
/// When this is `true`, the next time this client performs an HTTPS request,
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,
+/// The pool of open connections, both in use and available for reuse.
connection_pool: ConnectionPool = .{},
+/// The last error that occurred on this client. This is not threadsafe; do not expect it to be completely accurate.
+last_error: ?ExtraError = null,
+
+pub const ExtraError = union(enum) {
+ fn impliedErrorSet(comptime f: anytype) type {
+ const set = @typeInfo(@typeInfo(@TypeOf(f)).Fn.return_type.?).ErrorUnion.error_set;
+ if (@typeName(set)[0] != '@') @compileError(@typeName(@TypeOf(f)) ++ " doesn't have an implied error set anymore.");
+ return set;
+ }
+
+ // There's apparently a dependency loop with using Client.DeflateDecompressor.
+ const FakeTransferError = proto.HeadersParser.ReadError || error{ReadFailed};
+ const FakeTransferReader = std.io.Reader(void, FakeTransferError, fakeRead);
+ fn fakeRead(ctx: void, buf: []u8) FakeTransferError!usize {
+ _ = .{ buf, ctx };
+ return 0;
+ }
+
+ const FakeDeflateDecompressor = std.compress.zlib.ZlibStream(FakeTransferReader);
+ const FakeGzipDecompressor = std.compress.gzip.Decompress(FakeTransferReader);
+ const FakeZstdDecompressor = std.compress.zstd.DecompressStream(FakeTransferReader, .{});
+
+ pub const TcpConnectError = std.net.TcpConnectToHostError;
+ pub const TlsError = std.crypto.tls.Client.InitError(net.Stream);
+ pub const WriteError = BufferedConnection.WriteError;
+ pub const ReadError = BufferedConnection.ReadError || error{HttpChunkInvalid};
+ pub const CaBundleError = impliedErrorSet(std.crypto.Certificate.Bundle.rescan);
+
+ pub const ZlibInitError = error{ BadHeader, InvalidCompression, InvalidWindowSize, Unsupported, EndOfStream, OutOfMemory } || Request.TransferReadError;
+ pub const GzipInitError = error{ BadHeader, InvalidCompression, OutOfMemory, WrongChecksum, EndOfStream, StreamTooLong } || Request.TransferReadError;
+ // pub const DecompressError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error;
+ pub const DecompressError = FakeDeflateDecompressor.Error || FakeGzipDecompressor.Error || FakeZstdDecompressor.Error;
+
+ zlib_init: ZlibInitError, // error.CompressionInitializationFailed
+ gzip_init: GzipInitError, // error.CompressionInitializationFailed
+ connect: TcpConnectError, // error.ConnectionFailed
+ ca_bundle: CaBundleError, // error.CertificateAuthorityBundleFailed
+ tls: TlsError, // error.TlsInitializationFailed
+ write: WriteError, // error.WriteFailed
+ read: ReadError, // error.ReadFailed
+ decompress: DecompressError, // error.ReadFailed
+};
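Most client methods below collapse detailed failures into coarse errors such as `error.ReadFailed` and stash the underlying cause in `last_error`. A hedged sketch of recovering it (tag names as declared above):

```zig
const std = @import("std");

// Sketch: inspect client.last_error after a coarse error.ReadFailed.
fn readBody(client: *std.http.Client, req: *std.http.Client.Request, buf: []u8) !usize {
    return req.readAll(buf) catch |err| {
        if (err == error.ReadFailed) {
            if (client.last_error) |extra| switch (extra) {
                .read => |e| std.debug.print("read failed: {s}\n", .{@errorName(e)}),
                .decompress => |e| std.debug.print("decompress failed: {s}\n", .{@errorName(e)}),
                else => {},
            };
        }
        return err;
    };
}
```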
+
+/// A set of linked lists of connections that can be reused.
pub const ConnectionPool = struct {
+ /// The criteria for a connection to be considered a match.
pub const Criteria = struct {
host: []const u8,
port: u16,
is_tls: bool,
};
- const Queue = std.TailQueue(Connection);
+ pub const StoredConnection = struct {
+ buffered: BufferedConnection,
+ host: []u8,
+ port: u16,
+
+ closing: bool = false,
+
+ pub fn deinit(self: *StoredConnection, client: *Client) void {
+ self.buffered.close(client);
+ client.allocator.free(self.host);
+ }
+ };
+
+ const Queue = std.TailQueue(StoredConnection);
pub const Node = Queue.Node;
mutex: std.Thread.Mutex = .{},
+ /// Open connections that are currently in use.
used: Queue = .{},
+ /// Open connections that are not currently in use.
free: Queue = .{},
free_len: usize = 0,
- free_size: usize = default_connection_pool_size,
+ free_size: usize = connection_pool_size,
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
/// If no connection is found, null is returned.
@@ -53,9 +109,9 @@ pub const ConnectionPool = struct {
var next = pool.free.last;
while (next) |node| : (next = node.prev) {
- if ((node.data.protocol == .tls) != criteria.is_tls) continue;
+ if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
if (node.data.port != criteria.port) continue;
- if (std.mem.eql(u8, node.data.host, criteria.host)) continue;
+ if (!mem.eql(u8, node.data.host, criteria.host)) continue;
pool.acquireUnsafe(node);
return node;
@@ -89,7 +145,7 @@ pub const ConnectionPool = struct {
pool.used.remove(node);
if (node.data.closing) {
- node.data.close(client);
+ node.data.deinit(client);
return client.allocator.destroy(node);
}
@@ -97,7 +153,7 @@ pub const ConnectionPool = struct {
if (pool.free_len + 1 >= pool.free_size) {
const popped = pool.free.popFirst() orelse unreachable;
- popped.data.close(client);
+ popped.data.deinit(client);
return client.allocator.destroy(popped);
}
@@ -122,7 +178,7 @@ pub const ConnectionPool = struct {
defer client.allocator.destroy(node);
next = node.next;
- node.data.close(client);
+ node.data.deinit(client);
}
next = pool.used.first;
@@ -130,27 +186,19 @@ pub const ConnectionPool = struct {
defer client.allocator.destroy(node);
next = node.next;
- node.data.close(client);
+ node.data.deinit(client);
}
pool.* = undefined;
}
};
-pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
-pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
-pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.ReaderRaw, .{});
-
+/// An interface to either a plain or TLS connection.
pub const Connection = struct {
stream: net.Stream,
/// undefined unless protocol is tls.
- tls_client: *std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
+ tls_client: *std.crypto.tls.Client,
protocol: Protocol,
- host: []u8,
- port: u16,
-
- // This connection has been part of a non keepalive request and cannot be added to the pool.
- closing: bool = false,
pub const Protocol = enum { plain, tls };
@@ -215,11 +263,611 @@ pub const Connection = struct {
}
conn.stream.close();
+ }
+};
+
+/// A buffered (and peekable) Connection.
+pub const BufferedConnection = struct {
+ pub const buffer_size = 0x2000;
+
+ conn: Connection,
+ buf: [buffer_size]u8 = undefined,
+ start: u16 = 0,
+ end: u16 = 0,
+
+ pub fn fill(bconn: *BufferedConnection) ReadError!void {
+ if (bconn.end != bconn.start) return;
+
+ const nread = try bconn.conn.read(bconn.buf[0..]);
+ if (nread == 0) return error.EndOfStream;
+ bconn.start = 0;
+ bconn.end = @truncate(u16, nread);
+ }
+
+ pub fn peek(bconn: *BufferedConnection) []const u8 {
+ return bconn.buf[bconn.start..bconn.end];
+ }
+
+ pub fn clear(bconn: *BufferedConnection, num: u16) void {
+ bconn.start += num;
+ }
+
+ pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
+ var out_index: u16 = 0;
+ while (out_index < len) {
+ const available = bconn.end - bconn.start;
+ const left = buffer.len - out_index;
+
+ if (available > 0) {
+ const can_read = @truncate(u16, @min(available, left));
+
+ std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
+ out_index += can_read;
+ bconn.start += can_read;
+
+ continue;
+ }
+
+ if (left > bconn.buf.len) {
+ // skip the buffer if the output is large enough
+ return out_index + try bconn.conn.read(buffer[out_index..]);
+ }
+
+ try bconn.fill();
+ }
+
+ return out_index;
+ }
+
+ pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
+ return bconn.readAtLeast(buffer, 1);
+ }
+
+ pub const ReadError = Connection.ReadError || error{EndOfStream};
+ pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
+
+ pub fn reader(bconn: *BufferedConnection) Reader {
+ return Reader{ .context = bconn };
+ }
+
+ pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
+ return bconn.conn.writeAll(buffer);
+ }
+
+ pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
+ return bconn.conn.write(buffer);
+ }
+
+ pub const WriteError = Connection.WriteError;
+ pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
+
+ pub fn writer(bconn: *BufferedConnection) Writer {
+ return Writer{ .context = bconn };
+ }
+
+ pub fn close(bconn: *BufferedConnection, client: *const Client) void {
+ bconn.conn.close(client);
+ }
+};
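`fill`, `peek`, and `clear` exist so a parser can inspect buffered bytes and consume only what it accepted. A sketch of the discipline that `Request.do` builds on this type (hedged; error handling elided):

```zig
const std = @import("std");

// Sketch: parse a response head using the fill/peek/clear discipline.
fn readHead(
    bconn: *std.http.Client.BufferedConnection,
    parser: *std.http.protocol.HeadersParser,
    allocator: std.mem.Allocator,
) !void {
    while (!parser.state.isContent()) {
        try bconn.fill(); // no-op when bytes are already buffered
        const nchecked = try parser.checkCompleteHead(allocator, bconn.peek());
        bconn.clear(@intCast(u16, nchecked)); // consume exactly what the parser accepted
    }
}
```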
+
+/// The mode of transport for requests.
+pub const RequestTransfer = union(enum) {
+ content_length: u64,
+ chunked: void,
+ none: void,
+};
+
+/// The decompressor for response messages.
+pub const Compression = union(enum) {
+ pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.TransferReader);
+ pub const GzipDecompressor = std.compress.gzip.Decompress(Request.TransferReader);
+ pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Request.TransferReader, .{});
+
+ deflate: DeflateDecompressor,
+ gzip: GzipDecompressor,
+ zstd: ZstdDecompressor,
+ none: void,
+};
+
+/// An HTTP response originating from a server.
+pub const Response = struct {
+ pub const Headers = struct {
+ status: http.Status,
+ version: http.Version,
+ location: ?[]const u8 = null,
+ content_length: ?u64 = null,
+ transfer_encoding: ?http.TransferEncoding = null,
+ transfer_compression: ?http.ContentEncoding = null,
+ connection: http.Connection = .close,
+ upgrade: ?[]const u8 = null,
+
+ pub const ParseError = error{
+ ShortHttpStatusLine,
+ BadHttpVersion,
+ HttpHeadersInvalid,
+ HttpHeaderContinuationsUnsupported,
+ HttpTransferEncodingUnsupported,
+ HttpConnectionHeaderUnsupported,
+ InvalidContentLength,
+ CompressionNotSupported,
+ };
+
+ pub fn parse(bytes: []const u8) ParseError!Headers {
+ var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
+
+ const first_line = it.next() orelse return error.HttpHeadersInvalid;
+ if (first_line.len < 12)
+ return error.ShortHttpStatusLine;
+
+ const version: http.Version = switch (int64(first_line[0..8])) {
+ int64("HTTP/1.0") => .@"HTTP/1.0",
+ int64("HTTP/1.1") => .@"HTTP/1.1",
+ else => return error.BadHttpVersion,
+ };
+ if (first_line[8] != ' ') return error.HttpHeadersInvalid;
+ const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
+
+ var headers: Headers = .{
+ .version = version,
+ .status = status,
+ };
+
+ while (it.next()) |line| {
+ if (line.len == 0) return error.HttpHeadersInvalid;
+ switch (line[0]) {
+ ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+ else => {},
+ }
+
+ var line_it = mem.tokenize(u8, line, ": ");
+ const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
+ const header_value = line_it.rest();
+ if (std.ascii.eqlIgnoreCase(header_name, "location")) {
+ if (headers.location != null) return error.HttpHeadersInvalid;
+ headers.location = header_value;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+ if (headers.content_length != null) return error.HttpHeadersInvalid;
+ headers.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+ // Transfer-Encoding: second, first
+ // Transfer-Encoding: deflate, chunked
+ var iter = mem.splitBackwards(u8, header_value, ",");
+
+ if (iter.next()) |first| {
+ const trimmed = mem.trim(u8, first, " ");
+
+ if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
+ if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
+ headers.transfer_encoding = te;
+ } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |second| {
+ if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
+
+ const trimmed = mem.trim(u8, second, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+
+ const trimmed = mem.trim(u8, header_value, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+ if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
+ headers.connection = .keep_alive;
+ } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
+ headers.connection = .close;
+ } else {
+ return error.HttpConnectionHeaderUnsupported;
+ }
+ } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
+ headers.upgrade = header_value;
+ }
+ }
+
+ return headers;
+ }
+
+ inline fn int64(array: *const [8]u8) u64 {
+ return @bitCast(u64, array.*);
+ }
+
+ fn parseInt3(nnn: @Vector(3, u8)) u10 {
+ const zero: @Vector(3, u8) = .{ '0', '0', '0' };
+ const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
+ return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
+ }
+
+ test parseInt3 {
+ const expectEqual = testing.expectEqual;
+ try expectEqual(@as(u10, 0), parseInt3("000".*));
+ try expectEqual(@as(u10, 418), parseInt3("418".*));
+ try expectEqual(@as(u10, 999), parseInt3("999".*));
+ }
+ };
+
+ headers: Headers = undefined,
+ parser: proto.HeadersParser,
+ compression: Compression = .none,
+ skip: bool = false,
+};
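The parser above behaves like the one it replaces; a test sketch adapted from the test in the deleted Response.zig:

```zig
const std = @import("std");
const testing = std.testing;
const http = std.http;

test "parse example response head" {
    const example = "HTTP/1.1 301 Moved Permanently\r\n" ++
        "Location: https://www.example.com/\r\n" ++
        "Content-Length: 220\r\n\r\n";
    const parsed = try http.Client.Response.Headers.parse(example);
    try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
    try testing.expectEqual(http.Status.moved_permanently, parsed.status);
    try testing.expectEqual(@as(?u64, 220), parsed.content_length);
}
```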
+
+/// An HTTP request that has been sent.
+///
+/// Order of operations: request[ -> write -> finish] -> do -> read
+pub const Request = struct {
+ pub const Headers = struct {
+ version: http.Version = .@"HTTP/1.1",
+ method: http.Method = .GET,
+ user_agent: []const u8 = "zig (std.http)",
+ connection: http.Connection = .keep_alive,
+ transfer_encoding: RequestTransfer = .none,
+
+ custom: []const http.CustomHeader = &[_]http.CustomHeader{},
+ };
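A sketch of filling these headers for a POST with one custom header (the `name`/`value` fields match how `start` consumes `headers.custom` below):

```zig
const std = @import("std");
const http = std.http;

// Sketch: headers for a chunked POST carrying a custom Accept header.
fn openJsonPost(client: *http.Client, uri: std.Uri) !http.Client.Request {
    return client.request(uri, .{
        .method = .POST,
        .transfer_encoding = .{ .chunked = {} },
        .custom = &[_]http.CustomHeader{
            .{ .name = "Accept", .value = "application/json" },
        },
    }, .{});
}
```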
+
+ uri: Uri,
+ client: *Client,
+ connection: *ConnectionPool.Node,
+ /// These are stored in Request so that they are available when following
+ /// redirects.
+ headers: Headers,
+
+ redirects_left: u32,
+ handle_redirects: bool,
+
+ response: Response,
+
+ /// Used as an allocator for resolving redirect locations.
+ arena: std.heap.ArenaAllocator,
+
+ /// Frees all resources associated with the request.
+ pub fn deinit(req: *Request) void {
+ switch (req.response.compression) {
+ .none => {},
+ .deflate => |*deflate| deflate.deinit(),
+ .gzip => |*gzip| gzip.deinit(),
+ .zstd => |*zstd| zstd.deinit(),
+ }
+
+ if (req.response.parser.header_bytes_owned) {
+ req.response.parser.header_bytes.deinit(req.client.allocator);
+ }
+
+ if (!req.response.parser.done) {
+ // If the response wasn't fully read, then we need to close the connection.
+ req.connection.data.closing = true;
+ req.client.connection_pool.release(req.client, req.connection);
+ }
+
+ req.arena.deinit();
+ req.* = undefined;
+ }
- client.allocator.free(conn.host);
+ pub fn start(req: *Request, uri: Uri, headers: Headers) !void {
+ var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
+ const w = buffered.writer();
+
+ const escaped_path = try Uri.escapePath(req.client.allocator, uri.path);
+ defer req.client.allocator.free(escaped_path);
+
+ const escaped_query = if (uri.query) |q| try Uri.escapeQuery(req.client.allocator, q) else null;
+ defer if (escaped_query) |q| req.client.allocator.free(q);
+
+ const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(req.client.allocator, f) else null;
+ defer if (escaped_fragment) |f| req.client.allocator.free(f);
+
+ try w.writeAll(@tagName(headers.method));
+ try w.writeByte(' ');
+ if (escaped_path.len == 0) {
+ try w.writeByte('/');
+ } else {
+ try w.writeAll(escaped_path);
+ }
+ if (escaped_query) |q| {
+ try w.writeByte('?');
+ try w.writeAll(q);
+ }
+ if (escaped_fragment) |f| {
+ try w.writeByte('#');
+ try w.writeAll(f);
+ }
+ try w.writeByte(' ');
+ try w.writeAll(@tagName(headers.version));
+ try w.writeAll("\r\nHost: ");
+ try w.writeAll(uri.host.?);
+ try w.writeAll("\r\nUser-Agent: ");
+ try w.writeAll(headers.user_agent);
+ if (headers.connection == .close) {
+ try w.writeAll("\r\nConnection: close");
+ } else {
+ try w.writeAll("\r\nConnection: keep-alive");
+ }
+ try w.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd");
+ try w.writeAll("\r\nTE: gzip, deflate"); // TODO: add trailers when someone finds a nice way to integrate them without completely invalidating all pointers to headers.
+
+ switch (headers.transfer_encoding) {
+ .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"),
+ .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}),
+ .none => {},
+ }
+
+ for (headers.custom) |header| {
+ try w.writeAll("\r\n");
+ try w.writeAll(header.name);
+ try w.writeAll(": ");
+ try w.writeAll(header.value);
+ }
+
+ try w.writeAll("\r\n\r\n");
+
+ try buffered.flush();
+ }
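For illustration, with default headers a `GET` for `http://example.com/index.html?a=1` makes `start` emit a head along these lines (each line `\r\n`-terminated; a blank line ends the head):

```
GET /index.html?a=1 HTTP/1.1
Host: example.com
User-Agent: zig (std.http)
Connection: keep-alive
Accept-Encoding: gzip, deflate, zstd
TE: gzip, deflate

```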
+
+ pub const TransferReadError = proto.HeadersParser.ReadError || error{ReadFailed};
+
+ pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
+
+ pub fn transferReader(req: *Request) TransferReader {
+ return .{ .context = req };
+ }
+
+ pub fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
+ if (req.response.parser.done) return 0;
+
+ var index: usize = 0;
+ while (index == 0) {
+ const amt = req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip) catch |err| {
+ req.client.last_error = .{ .read = err };
+ return error.ReadFailed;
+ };
+ if (amt == 0 and req.response.parser.done) break;
+ index += amt;
+ }
+
+ return index;
+ }
+
+ pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.Headers.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed };
+
+ /// Waits for a response from the server and parses any headers that are sent.
+ /// This function will block until the final response is received.
+ ///
+ /// If `handle_redirects` is true, then this function will automatically follow
+ /// redirects.
+ pub fn do(req: *Request) DoError!void {
+ while (true) { // handle redirects
+ while (true) { // read headers
+ req.connection.data.buffered.fill() catch |err| {
+ req.client.last_error = .{ .read = err };
+ return error.ReadFailed;
+ };
+
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
+ req.connection.data.buffered.clear(@intCast(u16, nchecked));
+
+ if (req.response.parser.state.isContent()) break;
+ }
+
+ req.response.headers = try Response.Headers.parse(req.response.parser.header_bytes.items);
+
+ if (req.response.headers.status == .switching_protocols) {
+ req.connection.data.closing = false;
+ req.response.parser.done = true;
+ }
+
+ if (req.headers.connection == .keep_alive and req.response.headers.connection == .keep_alive) {
+ req.connection.data.closing = false;
+ } else {
+ req.connection.data.closing = true;
+ }
+
+ if (req.response.headers.transfer_encoding) |te| {
+ switch (te) {
+ .chunked => {
+ req.response.parser.next_chunk_length = 0;
+ req.response.parser.state = .chunk_head_size;
+ },
+ }
+ } else if (req.response.headers.content_length) |cl| {
+ req.response.parser.next_chunk_length = cl;
+
+ if (cl == 0) req.response.parser.done = true;
+ } else {
+ req.response.parser.done = true;
+ }
+
+ if (req.response.headers.status.class() == .redirect and req.handle_redirects) {
+ req.response.skip = true;
+
+ const empty = @as([*]u8, undefined)[0..0];
+ assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary
+
+ if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+
+ const location = req.response.headers.location orelse
+ return error.HttpRedirectMissingLocation;
+ const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
+
+ var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
+ const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
+ errdefer new_arena.deinit();
+
+ req.arena.deinit();
+ req.arena = new_arena;
+
+ const new_req = try req.client.request(resolved_url, req.headers, .{
+ .max_redirects = req.redirects_left - 1,
+ .header_strategy = if (req.response.parser.header_bytes_owned) .{
+ .dynamic = req.response.parser.max_header_bytes,
+ } else .{
+ .static = req.response.parser.header_bytes.items.ptr[0..req.response.parser.max_header_bytes],
+ },
+ });
+ req.deinit();
+ req.* = new_req;
+ } else {
+ req.response.skip = false;
+ if (!req.response.parser.done) {
+ if (req.response.headers.transfer_compression) |tc| switch (tc) {
+ .compress => return error.CompressionNotSupported,
+ .deflate => req.response.compression = .{
+ .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch |err| {
+ req.client.last_error = .{ .zlib_init = err };
+ return error.CompressionInitializationFailed;
+ },
+ },
+ .gzip => req.response.compression = .{
+ .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch |err| {
+ req.client.last_error = .{ .gzip_init = err };
+ return error.CompressionInitializationFailed;
+ },
+ },
+ .zstd => req.response.compression = .{
+ .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
+ },
+ };
+ }
+
+ break;
+ }
+ }
+ }
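Once `do` returns, the parsed head is available on `req.response.headers`; a sketch:

```zig
const std = @import("std");

// Sketch: wait for the head, then inspect what was parsed.
fn expectOk(req: *std.http.Client.Request) !void {
    try req.do();
    if (req.response.headers.status != .ok) return error.UnexpectedHttpStatus;
    if (req.response.headers.content_length) |len| {
        std.debug.print("expecting {d} bytes\n", .{len});
    }
}
```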
+
+ pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError;
+
+ pub const Reader = std.io.Reader(*Request, ReadError, read);
+
+ pub fn reader(req: *Request) Reader {
+ return .{ .context = req };
+ }
+
+ /// Reads data from the response body. Must be called after `do`.
+ pub fn read(req: *Request, buffer: []u8) ReadError!usize {
+ while (true) {
+ const out_index = switch (req.response.compression) {
+ .deflate => |*deflate| deflate.read(buffer) catch |err| {
+ req.client.last_error = .{ .decompress = err };
+ return error.ReadFailed;
+ },
+ .gzip => |*gzip| gzip.read(buffer) catch |err| {
+ req.client.last_error = .{ .decompress = err };
+ return error.ReadFailed;
+ },
+ .zstd => |*zstd| zstd.read(buffer) catch |err| {
+ req.client.last_error = .{ .decompress = err };
+ return error.ReadFailed;
+ },
+ else => try req.transferRead(buffer),
+ };
+
+ if (out_index == 0) {
+ while (!req.response.parser.state.isContent()) { // read trailing headers
+ req.connection.data.buffered.fill() catch |err| {
+ req.client.last_error = .{ .read = err };
+ return error.ReadFailed;
+ };
+
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
+ req.connection.data.buffered.clear(@intCast(u16, nchecked));
+ }
+ }
+
+ return out_index;
+ }
+ }
+
+ /// Reads data from the response body. Must be called after `do`.
+ pub fn readAll(req: *Request, buffer: []u8) !usize {
+ var index: usize = 0;
+ while (index < buffer.len) {
+ const amt = try read(req, buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
+ }
+ return index;
+ }
+
+ pub const WriteError = error{ WriteFailed, NotWriteable, MessageTooLong };
+
+ pub const Writer = std.io.Writer(*Request, WriteError, write);
+
+ pub fn writer(req: *Request) Writer {
+ return .{ .context = req };
+ }
+
+ /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
+ pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
+ switch (req.headers.transfer_encoding) {
+ .chunked => {
+ req.connection.data.buffered.writer().print("{x}\r\n", .{bytes.len}) catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+ req.connection.data.buffered.writeAll(bytes) catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+ req.connection.data.buffered.writeAll("\r\n") catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+
+ return bytes.len;
+ },
+ .content_length => |*len| {
+ if (len.* < bytes.len) return error.MessageTooLong;
+
+ const amt = req.connection.data.buffered.write(bytes) catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ };
+ len.* -= amt;
+ return amt;
+ },
+ .none => return error.NotWriteable,
+ }
+ }
+
+ /// Finish the body of a request. This notifies the server that you have no more data to send.
+ pub fn finish(req: *Request) !void {
+ switch (req.headers.transfer_encoding) {
+ .chunked => req.connection.data.buffered.writeAll("0\r\n\r\n") catch |err| {
+ req.client.last_error = .{ .write = err };
+ return error.WriteFailed;
+ },
+ .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+ .none => {},
+ }
}
};
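Putting `write` and `finish` together with the rest of the API, a body-bearing request follows the documented order `request -> write -> finish -> do -> read`; a hedged sketch:

```zig
const std = @import("std");

// Sketch: POST a chunked body, then read the response body into `out`.
fn post(client: *std.http.Client, uri: std.Uri, body: []const u8, out: []u8) !usize {
    var req = try client.request(uri, .{
        .method = .POST,
        .transfer_encoding = .{ .chunked = {} },
    }, .{});
    defer req.deinit();

    try req.writer().writeAll(body);
    try req.finish(); // sends the terminating zero-length chunk
    try req.do();
    return req.readAll(out);
}
```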
+/// Release all resources associated with the client.
+/// TODO: currently leaks all request-allocated data
pub fn deinit(client: *Client) void {
client.connection_pool.deinit(client);
@@ -227,8 +875,10 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
-pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
+pub const ConnectError = Allocator.Error || error{ ConnectionFailed, TlsInitializationFailed };
+/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
+/// This function is threadsafe.
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
if (client.connection_pool.findConnection(.{
.host = host,
@@ -241,22 +891,36 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
errdefer client.allocator.destroy(conn);
conn.* = .{ .data = undefined };
+ const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| {
+ client.last_error = .{ .connect = err };
+ return error.ConnectionFailed;
+ };
+ errdefer stream.close();
+
conn.data = .{
- .stream = try net.tcpConnectToHost(client.allocator, host, port),
- .tls_client = undefined,
- .protocol = protocol,
+ .buffered = .{ .conn = .{
+ .stream = stream,
+ .tls_client = undefined,
+ .protocol = protocol,
+ } },
.host = try client.allocator.dupe(u8, host),
.port = port,
};
+ errdefer client.allocator.free(conn.data.host);
switch (protocol) {
.plain => {},
.tls => {
- conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
- conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
+ conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
+ errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client);
+
+ conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch |err| {
+ client.last_error = .{ .tls = err };
+ return error.TlsInitializationFailed;
+ };
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
- conn.data.tls_client.allow_truncation_attacks = true;
+ conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
},
}
@@ -265,24 +929,44 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
return conn;
}
-pub const RequestError = ConnectError || Connection.WriteError || error{
+pub const RequestError = ConnectError || error{
UnsupportedUrlScheme,
UriMissingHost,
- CertificateAuthorityBundleTooBig,
- InvalidPadding,
- MissingEndCertificateMarker,
- Unseekable,
- EndOfStream,
+ CertificateAuthorityBundleFailed,
+ WriteFailed,
+};
+
+pub const Options = struct {
+ handle_redirects: bool = true,
+ max_redirects: u32 = 3,
+ header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
+
+ pub const HeaderStrategy = union(enum) {
+ /// In this case, the client's Allocator will be used to store the
+ /// entire HTTP header. This value is the maximum total size of
+ /// HTTP headers allowed; if it is exceeded,
+ /// error.HttpHeadersExceededSizeLimit is returned from read().
+ dynamic: usize,
+ /// The provided buffer is used to store the entire HTTP header. If the HTTP
+ /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
+ /// is returned from read(). When this is used, `error.OutOfMemory`
+ /// cannot be returned from `read()`.
+ static: []u8,
+ };
};
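A sketch of the `static` strategy, which keeps the response head in a caller-provided buffer so `read()` never allocates:

```zig
// Sketch: the buffer must outlive the request it backs.
var head_buf: [16 * 1024]u8 = undefined;
var req = try client.request(uri, .{}, .{
    .header_strategy = .{ .static = &head_buf },
});
```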
-pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
- const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http"))
- .plain
- else if (mem.eql(u8, uri.scheme, "https"))
- .tls
- else
- return error.UnsupportedUrlScheme;
+pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
+ .{ "http", .plain },
+ .{ "ws", .plain },
+ .{ "https", .tls },
+ .{ "wss", .tls },
+});
+
+/// Form and send an HTTP request to a server.
+/// This function is threadsafe.
+pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Options) RequestError!Request {
+ const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
const port: u16 = uri.port orelse switch (protocol) {
.plain => 80,
@@ -291,91 +975,45 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
const host = uri.host orelse return error.UriMissingHost;
- if (client.next_https_rescan_certs and protocol == .tls) {
- client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
- defer client.connection_pool.mutex.unlock();
+ if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) {
+ client.ca_bundle_mutex.lock();
+ defer client.ca_bundle_mutex.unlock();
if (client.next_https_rescan_certs) {
- try client.ca_bundle.rescan(client.allocator);
- client.next_https_rescan_certs = false;
+ client.ca_bundle.rescan(client.allocator) catch |err| {
+ client.last_error = .{ .ca_bundle = err };
+ return error.CertificateAuthorityBundleFailed;
+ };
+ @atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
}
}
var req: Request = .{
.uri = uri,
.client = client,
- .headers = headers,
.connection = try client.connect(host, port, protocol),
+ .headers = headers,
.redirects_left = options.max_redirects,
.handle_redirects = options.handle_redirects,
- .compression_init = false,
- .response = switch (options.header_strategy) {
- .dynamic => |max| Response.initDynamic(max),
- .static => |buf| Response.initStatic(buf),
+ .response = .{
+ .parser = switch (options.header_strategy) {
+ .dynamic => |max| proto.HeadersParser.initDynamic(max),
+ .static => |buf| proto.HeadersParser.initStatic(buf),
+ },
},
.arena = undefined,
};
+ errdefer req.deinit();
req.arena = std.heap.ArenaAllocator.init(client.allocator);
- {
- var buffered = std.io.bufferedWriter(req.connection.data.writer());
- const writer = buffered.writer();
-
- const escaped_path = try Uri.escapePath(client.allocator, uri.path);
- defer client.allocator.free(escaped_path);
-
- const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
- defer if (escaped_query) |q| client.allocator.free(q);
+ req.start(uri, headers) catch |err| {
+ if (err == error.OutOfMemory) return error.OutOfMemory;
+ const err_casted = @errSetCast(BufferedConnection.WriteError, err);
- const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
- defer if (escaped_fragment) |f| client.allocator.free(f);
-
- try writer.writeAll(@tagName(headers.method));
- try writer.writeByte(' ');
- if (escaped_path.len == 0) {
- try writer.writeByte('/');
- } else {
- try writer.writeAll(escaped_path);
- }
- if (escaped_query) |q| {
- try writer.writeByte('?');
- try writer.writeAll(q);
- }
- if (escaped_fragment) |f| {
- try writer.writeByte('#');
- try writer.writeAll(f);
- }
- try writer.writeByte(' ');
- try writer.writeAll(@tagName(headers.version));
- try writer.writeAll("\r\nHost: ");
- try writer.writeAll(host);
- try writer.writeAll("\r\nUser-Agent: ");
- try writer.writeAll(headers.user_agent);
- if (headers.connection == .close) {
- try writer.writeAll("\r\nConnection: close");
- } else {
- try writer.writeAll("\r\nConnection: keep-alive");
- }
- try writer.writeAll("\r\nAccept-Encoding: gzip, deflate, zstd");
-
- switch (headers.transfer_encoding) {
- .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"),
- .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}),
- .none => {},
- }
-
- for (headers.custom) |header| {
- try writer.writeAll("\r\n");
- try writer.writeAll(header.name);
- try writer.writeAll(": ");
- try writer.writeAll(header.value);
- }
-
- try writer.writeAll("\r\n\r\n");
-
- try buffered.flush();
- }
+ client.last_error = .{ .write = err_casted };
+ return error.WriteFailed;
+ };
return req;
}
@@ -390,5 +1028,5 @@ test {
if (builtin.os.tag == .wasi) return error.SkipZigTest;
- _ = Request;
+ std.testing.refAllDecls(@This());
}
diff --git a/lib/std/http/Client/Request.zig b/lib/std/http/Client/Request.zig
deleted file mode 100644
index 9e2ebd2d6c..0000000000
--- a/lib/std/http/Client/Request.zig
+++ /dev/null
@@ -1,482 +0,0 @@
-const std = @import("std");
-const http = std.http;
-const Uri = std.Uri;
-const mem = std.mem;
-const assert = std.debug.assert;
-
-const Client = @import("../Client.zig");
-const Connection = Client.Connection;
-const ConnectionNode = Client.ConnectionPool.Node;
-const Response = @import("Response.zig");
-
-const Request = @This();
-
-const read_buffer_size = 8192;
-const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
-
-uri: Uri,
-client: *Client,
-connection: *ConnectionNode,
-response: Response,
-/// These are stored in Request so that they are available when following
-/// redirects.
-headers: Headers,
-
-redirects_left: u32,
-handle_redirects: bool,
-compression_init: bool,
-
-/// Used as a allocator for resolving redirects locations.
-arena: std.heap.ArenaAllocator,
-
-/// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning.
-read_buffer: [read_buffer_size]u8 = undefined,
-read_buffer_start: ReadBufferIndex = 0,
-read_buffer_len: ReadBufferIndex = 0,
-
-pub const RequestTransfer = union(enum) {
- content_length: u64,
- chunked: void,
- none: void,
-};
-
-pub const Headers = struct {
- version: http.Version = .@"HTTP/1.1",
- method: http.Method = .GET,
- user_agent: []const u8 = "zig (std.http)",
- connection: http.Connection = .keep_alive,
- transfer_encoding: RequestTransfer = .none,
-
- custom: []const http.CustomHeader = &[_]http.CustomHeader{},
-};
-
-pub const Options = struct {
- handle_redirects: bool = true,
- max_redirects: u32 = 3,
- header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
-
- pub const HeaderStrategy = union(enum) {
- /// In this case, the client's Allocator will be used to store the
- /// entire HTTP header. This value is the maximum total size of
- /// HTTP headers allowed, otherwise
- /// error.HttpHeadersExceededSizeLimit is returned from read().
- dynamic: usize,
- /// This is used to store the entire HTTP header. If the HTTP
- /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
- /// is returned from read(). When this is used, `error.OutOfMemory`
- /// cannot be returned from `read()`.
- static: []u8,
- };
-};
-
-/// Frees all resources associated with the request.
-pub fn deinit(req: *Request) void {
- switch (req.response.compression) {
- .none => {},
- .deflate => |*deflate| deflate.deinit(),
- .gzip => |*gzip| gzip.deinit(),
- .zstd => |*zstd| zstd.deinit(),
- }
-
- if (req.response.header_bytes_owned) {
- req.response.header_bytes.deinit(req.client.allocator);
- }
-
- if (!req.response.done) {
- // If the response wasn't fully read, then we need to close the connection.
- req.connection.data.closing = true;
- req.client.connection_pool.release(req.client, req.connection);
- }
-
- req.arena.deinit();
- req.* = undefined;
-}
-
-pub const ReadRawError = Connection.ReadError || Uri.ParseError || Client.RequestError || error{
- UnexpectedEndOfStream,
- TooManyHttpRedirects,
- HttpRedirectMissingLocation,
- HttpHeadersInvalid,
-};
-
-pub const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw);
-
-/// Read from the underlying stream, without decompressing or parsing the headers. Must be called
-/// after waitForCompleteHead() has returned successfully.
-pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize {
- assert(req.response.state.isContent());
-
- var index: usize = 0;
- while (index == 0) {
- const amt = try req.readRawAdvanced(buffer[index..]);
- if (amt == 0 and req.response.done) break;
- index += amt;
- }
-
- return index;
-}
-
-fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
- switch (req.response.state) {
- .invalid => unreachable,
- .start, .seen_r, .seen_rn, .seen_rnr => {},
- else => return 0, // No more headers to read.
- }
-
- const i = req.response.findHeadersEnd(buffer[0..]);
- if (req.response.state == .invalid) return error.HttpHeadersInvalid;
-
- const headers_data = buffer[0..i];
- if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
- return error.HttpHeadersExceededSizeLimit;
- }
- try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
-
- if (req.response.state == .finished) {
- req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
-
- if (req.response.headers.upgrade) |_| {
- req.connection.data.closing = false;
- req.response.done = true;
- return i;
- }
-
- if (req.response.headers.connection == .keep_alive) {
- req.connection.data.closing = false;
- } else {
- req.connection.data.closing = true;
- }
-
- if (req.response.headers.transfer_encoding) |transfer_encoding| {
- switch (transfer_encoding) {
- .chunked => {
- req.response.next_chunk_length = 0;
- req.response.state = .chunk_size;
- },
- }
- } else if (req.response.headers.content_length) |content_length| {
- req.response.next_chunk_length = content_length;
-
- if (content_length == 0) req.response.done = true;
- } else {
- req.response.done = true;
- }
-
- return i;
- }
-
- return 0;
-}
-
-pub const WaitForCompleteHeadError = ReadRawError || error{
- UnexpectedEndOfStream,
-
- HttpHeadersExceededSizeLimit,
- ShortHttpStatusLine,
- BadHttpVersion,
- HttpHeaderContinuationsUnsupported,
- HttpTransferEncodingUnsupported,
- HttpConnectionHeaderUnsupported,
-};
-
-/// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent.
-pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void {
- if (req.response.state.isContent()) return;
-
- while (true) {
- const nread = try req.connection.data.read(req.read_buffer[0..]);
- const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]);
-
- if (amt != 0) {
- req.read_buffer_start = @intCast(ReadBufferIndex, amt);
- req.read_buffer_len = @intCast(ReadBufferIndex, nread);
- return;
- } else if (nread == 0) {
- return error.UnexpectedEndOfStream;
- }
- }
-}
-
-/// This one can return 0 without meaning EOF.
-fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
- assert(req.response.state.isContent());
- if (req.response.done) return 0;
-
- // var in: []const u8 = undefined;
- if (req.read_buffer_start == req.read_buffer_len) {
- const nread = try req.connection.data.read(req.read_buffer[0..]);
- if (nread == 0) return error.UnexpectedEndOfStream;
-
- req.read_buffer_start = 0;
- req.read_buffer_len = @intCast(ReadBufferIndex, nread);
- }
-
- var out_index: usize = 0;
- while (true) {
- switch (req.response.state) {
- .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable,
- .finished => {
- // TODO https://github.com/ziglang/zig/issues/14039
- const buf_avail = req.read_buffer_len - req.read_buffer_start;
- const data_avail = req.response.next_chunk_length;
- const out_avail = buffer.len;
-
- if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
- const can_read = @intCast(usize, @min(buf_avail, data_avail));
- req.response.next_chunk_length -= can_read;
-
- if (req.response.next_chunk_length == 0) {
- req.client.connection_pool.release(req.client, req.connection);
- req.connection = undefined;
- req.response.done = true;
- }
-
- return 0; // skip over as much data as possible
- }
-
- const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail));
- req.response.next_chunk_length -= can_read;
-
- mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]);
- req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
-
- if (req.response.next_chunk_length == 0) {
- req.client.connection_pool.release(req.client, req.connection);
- req.connection = undefined;
- req.response.done = true;
- }
-
- return can_read;
- },
- .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) {
- 0 => return out_index,
- 1 => switch (req.read_buffer[req.read_buffer_start]) {
- '\r' => {
- req.response.state = .chunk_size_prefix_n;
- return out_index;
- },
- else => {
- req.response.state = .invalid;
- return error.HttpHeadersInvalid;
- },
- },
- else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) {
- int16("\r\n") => {
- req.read_buffer_start += 2;
- req.response.state = .chunk_size;
- continue;
- },
- else => {
- req.response.state = .invalid;
- return error.HttpHeadersInvalid;
- },
- },
- },
- .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) {
- 0 => return out_index,
- else => switch (req.read_buffer[req.read_buffer_start]) {
- '\n' => {
- req.read_buffer_start += 1;
- req.response.state = .chunk_size;
- continue;
- },
- else => {
- req.response.state = .invalid;
- return error.HttpHeadersInvalid;
- },
- },
- },
- .chunk_size, .chunk_r => {
- const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]);
- switch (req.response.state) {
- .invalid => return error.HttpHeadersInvalid,
- .chunk_data => {
- if (req.response.next_chunk_length == 0) {
- req.response.done = true;
- req.client.connection_pool.release(req.client, req.connection);
- req.connection = undefined;
-
- return out_index;
- }
-
- req.read_buffer_start += @intCast(ReadBufferIndex, i);
- continue;
- },
- .chunk_size => return out_index,
- else => unreachable,
- }
- },
- .chunk_data => {
- // TODO https://github.com/ziglang/zig/issues/14039
- const buf_avail = req.read_buffer_len - req.read_buffer_start;
- const data_avail = req.response.next_chunk_length;
- const out_avail = buffer.len - out_index;
-
- if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
- const can_read = @intCast(usize, @min(buf_avail, data_avail));
- req.response.next_chunk_length -= can_read;
-
- if (req.response.next_chunk_length == 0) {
- req.client.connection_pool.release(req.client, req.connection);
- req.connection = undefined;
- req.response.done = true;
- continue;
- }
-
- return 0; // skip over as much data as possible
- }
-
- const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail));
- req.response.next_chunk_length -= can_read;
-
- mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]);
- req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
- out_index += can_read;
-
- if (req.response.next_chunk_length == 0) {
- req.response.state = .chunk_size_prefix_r;
-
- continue;
- }
-
- return out_index;
- },
- }
- }
-}
-
-pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported };
-
-pub const Reader = std.io.Reader(*Request, ReadError, read);
-
-pub fn reader(req: *Request) Reader {
- return .{ .context = req };
-}
-
-pub fn read(req: *Request, buffer: []u8) ReadError!usize {
- while (true) {
- if (!req.response.state.isContent()) try req.waitForCompleteHead();
-
- if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
- assert(try req.readRaw(buffer) == 0);
-
- if (req.redirects_left == 0) return error.TooManyHttpRedirects;
-
- const location = req.response.headers.location orelse
- return error.HttpRedirectMissingLocation;
- const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
-
- var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
- const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
- errdefer new_arena.deinit();
-
- req.arena.deinit();
- req.arena = new_arena;
-
- const new_req = try req.client.request(resolved_url, req.headers, .{
- .max_redirects = req.redirects_left - 1,
- .header_strategy = if (req.response.header_bytes_owned) .{
- .dynamic = req.response.max_header_bytes,
- } else .{
- .static = req.response.header_bytes.unusedCapacitySlice(),
- },
- });
- req.deinit();
- req.* = new_req;
- } else {
- break;
- }
- }
-
- if (req.response.compression == .none) {
- if (req.response.headers.transfer_compression) |compression| {
- switch (compression) {
- .compress => return error.CompressionNotSupported,
- .deflate => req.response.compression = .{
- .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }),
- },
- .gzip => req.response.compression = .{
- .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }),
- },
- .zstd => req.response.compression = .{
- .zstd = std.compress.zstd.decompressStream(req.client.allocator, ReaderRaw{ .context = req }),
- },
- }
- }
- }
-
- return switch (req.response.compression) {
- .deflate => |*deflate| try deflate.read(buffer),
- .gzip => |*gzip| try gzip.read(buffer),
- .zstd => |*zstd| try zstd.read(buffer),
- else => try req.readRaw(buffer),
- };
-}
-
-pub fn readAll(req: *Request, buffer: []u8) !usize {
- var index: usize = 0;
- while (index < buffer.len) {
- const amt = try read(req, buffer[index..]);
- if (amt == 0) break;
- index += amt;
- }
- return index;
-}
-
-pub const WriteError = Connection.WriteError || error{MessageTooLong};
-
-pub const Writer = std.io.Writer(*Request, WriteError, write);
-
-pub fn writer(req: *Request) Writer {
- return .{ .context = req };
-}
-
-/// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
-pub fn write(req: *Request, bytes: []const u8) !usize {
- switch (req.headers.transfer_encoding) {
- .chunked => {
- try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
- try req.connection.data.writeAll(bytes);
- try req.connection.data.writeAll("\r\n");
-
- return bytes.len;
- },
- .content_length => |*len| {
- if (len.* < bytes.len) return error.MessageTooLong;
-
- const amt = try req.connection.data.write(bytes);
- len.* -= amt;
- return amt;
- },
- .none => return error.NotWriteable,
- }
-}
-
-/// Finish the body of a request. This notifies the server that you have no more data to send.
-pub fn finish(req: *Request) !void {
- switch (req.headers.transfer_encoding) {
- .chunked => try req.connection.data.writeAll("0\r\n"),
- .content_length => |len| if (len != 0) return error.MessageNotCompleted,
- .none => {},
- }
-}
-
-inline fn int16(array: *const [2]u8) u16 {
- return @bitCast(u16, array.*);
-}
-
-inline fn int32(array: *const [4]u8) u32 {
- return @bitCast(u32, array.*);
-}
-
-inline fn int64(array: *const [8]u8) u64 {
- return @bitCast(u64, array.*);
-}
-
-test {
- const builtin = @import("builtin");
-
- if (builtin.os.tag == .wasi) return error.SkipZigTest;
-
- _ = Response;
-}
diff --git a/lib/std/http/Client/Response.zig b/lib/std/http/Client/Response.zig
deleted file mode 100644
index f1a3b07dd8..0000000000
--- a/lib/std/http/Client/Response.zig
+++ /dev/null
@@ -1,509 +0,0 @@
-const std = @import("std");
-const http = std.http;
-const mem = std.mem;
-const testing = std.testing;
-const assert = std.debug.assert;
-
-const Client = @import("../Client.zig");
-const Response = @This();
-
-headers: Headers,
-state: State,
-header_bytes_owned: bool,
-/// This could either be a fixed buffer provided by the API user or it
-/// could be our own array list.
-header_bytes: std.ArrayListUnmanaged(u8),
-max_header_bytes: usize,
-next_chunk_length: u64,
-done: bool = false,
-
-compression: union(enum) {
- deflate: Client.DeflateDecompressor,
- gzip: Client.GzipDecompressor,
- zstd: Client.ZstdDecompressor,
- none: void,
-} = .none,
-
-pub const Headers = struct {
- status: http.Status,
- version: http.Version,
- location: ?[]const u8 = null,
- content_length: ?u64 = null,
- transfer_encoding: ?http.TransferEncoding = null,
- transfer_compression: ?http.ContentEncoding = null,
- connection: http.Connection = .close,
- upgrade: ?[]const u8 = null,
-
- number_of_headers: usize = 0,
-
- pub fn parse(bytes: []const u8) !Headers {
- var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
-
- const first_line = it.first();
- if (first_line.len < 12)
- return error.ShortHttpStatusLine;
-
- const version: http.Version = switch (int64(first_line[0..8])) {
- int64("HTTP/1.0") => .@"HTTP/1.0",
- int64("HTTP/1.1") => .@"HTTP/1.1",
- else => return error.BadHttpVersion,
- };
- if (first_line[8] != ' ') return error.HttpHeadersInvalid;
- const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
-
- var headers: Headers = .{
- .version = version,
- .status = status,
- };
-
- while (it.next()) |line| {
- headers.number_of_headers += 1;
-
- if (line.len == 0) return error.HttpHeadersInvalid;
- switch (line[0]) {
- ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
- else => {},
- }
- var line_it = mem.split(u8, line, ": ");
- const header_name = line_it.first();
- const header_value = line_it.rest();
- if (std.ascii.eqlIgnoreCase(header_name, "location")) {
- if (headers.location != null) return error.HttpHeadersInvalid;
- headers.location = header_value;
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
- if (headers.content_length != null) return error.HttpHeadersInvalid;
- headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
- } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
- // Transfer-Encoding: second, first
- // Transfer-Encoding: deflate, chunked
- var iter = std.mem.splitBackwards(u8, header_value, ",");
-
- if (iter.next()) |first| {
- const trimmed = std.mem.trim(u8, first, " ");
-
- if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
- if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
- headers.transfer_encoding = te;
- } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
- if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
- headers.transfer_compression = ce;
- } else {
- return error.HttpTransferEncodingUnsupported;
- }
- }
-
- if (iter.next()) |second| {
- if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
-
- const trimmed = std.mem.trim(u8, second, " ");
-
- if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
- headers.transfer_compression = ce;
- } else {
- return error.HttpTransferEncodingUnsupported;
- }
- }
-
- if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
- } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
- if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
-
- const trimmed = std.mem.trim(u8, header_value, " ");
-
- if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
- headers.transfer_compression = ce;
- } else {
- return error.HttpTransferEncodingUnsupported;
- }
- } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
- if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
- headers.connection = .keep_alive;
- } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
- headers.connection = .close;
- } else {
- return error.HttpConnectionHeaderUnsupported;
- }
- } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
- headers.upgrade = header_value;
- }
- }
-
- return headers;
- }
-
- test "parse headers" {
- const example =
- "HTTP/1.1 301 Moved Permanently\r\n" ++
- "Location: https://www.example.com/\r\n" ++
- "Content-Type: text/html; charset=UTF-8\r\n" ++
- "Content-Length: 220\r\n\r\n";
- const parsed = try Headers.parse(example);
- try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
- try testing.expectEqual(http.Status.moved_permanently, parsed.status);
- try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse
- return error.TestFailed);
- try testing.expectEqual(@as(?u64, 220), parsed.content_length);
- }
-
- test "header continuation" {
- const example =
- "HTTP/1.0 200 OK\r\n" ++
- "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++
- "Content-Length: 220\r\n\r\n";
- try testing.expectError(
- error.HttpHeaderContinuationsUnsupported,
- Headers.parse(example),
- );
- }
-
- test "extra content length" {
- const example =
- "HTTP/1.0 200 OK\r\n" ++
- "Content-Length: 220\r\n" ++
- "Content-Type: text/html; charset=UTF-8\r\n" ++
- "content-length: 220\r\n\r\n";
- try testing.expectError(
- error.HttpHeadersInvalid,
- Headers.parse(example),
- );
- }
-};
-
-inline fn int16(array: *const [2]u8) u16 {
- return @bitCast(u16, array.*);
-}
-
-inline fn int32(array: *const [4]u8) u32 {
- return @bitCast(u32, array.*);
-}
-
-inline fn int64(array: *const [8]u8) u64 {
- return @bitCast(u64, array.*);
-}
-
-pub const State = enum {
- /// Begin header parsing states.
- invalid,
- start,
- seen_r,
- seen_rn,
- seen_rnr,
- finished,
- /// Begin transfer-encoding: chunked parsing states.
- chunk_size_prefix_r,
- chunk_size_prefix_n,
- chunk_size,
- chunk_r,
- chunk_data,
-
- pub fn isContent(self: State) bool {
- return switch (self) {
- .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false,
- .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true,
- };
- }
-};
-
-pub fn initDynamic(max: usize) Response {
- return .{
- .state = .start,
- .headers = undefined,
- .header_bytes = .{},
- .max_header_bytes = max,
- .header_bytes_owned = true,
- .next_chunk_length = undefined,
- };
-}
-
-pub fn initStatic(buf: []u8) Response {
- return .{
- .state = .start,
- .headers = undefined,
- .header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
- .max_header_bytes = buf.len,
- .header_bytes_owned = false,
- .next_chunk_length = undefined,
- };
-}
-
-/// Returns how many bytes are part of HTTP headers. Always less than or
-/// equal to bytes.len. If the amount returned is less than bytes.len, it
-/// means the headers ended and the first byte after the double \r\n\r\n is
-/// located at `bytes[result]`.
-pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize {
- var index: usize = 0;
-
- // TODO: https://github.com/ziglang/zig/issues/8220
- state: while (true) {
- switch (r.state) {
- .invalid => unreachable,
- .finished => unreachable,
- .start => while (true) {
- switch (bytes.len - index) {
- 0 => return index,
- 1 => {
- if (bytes[index] == '\r')
- r.state = .seen_r;
- return index + 1;
- },
- 2 => {
- if (int16(bytes[index..][0..2]) == int16("\r\n")) {
- r.state = .seen_rn;
- } else if (bytes[index + 1] == '\r') {
- r.state = .seen_r;
- }
- return index + 2;
- },
- 3 => {
- if (int16(bytes[index..][0..2]) == int16("\r\n") and
- bytes[index + 2] == '\r')
- {
- r.state = .seen_rnr;
- } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) {
- r.state = .seen_rn;
- } else if (bytes[index + 2] == '\r') {
- r.state = .seen_r;
- }
- return index + 3;
- },
- 4...15 => {
- if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) {
- r.state = .finished;
- return index + 4;
- } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and
- bytes[index + 3] == '\r')
- {
- r.state = .seen_rnr;
- index += 4;
- continue :state;
- } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) {
- r.state = .seen_rn;
- index += 4;
- continue :state;
- } else if (bytes[index + 3] == '\r') {
- r.state = .seen_r;
- index += 4;
- continue :state;
- }
- index += 4;
- continue;
- },
- else => {
- const chunk = bytes[index..][0..16];
- const v: @Vector(16, u8) = chunk.*;
- const matches_r = v == @splat(16, @as(u8, '\r'));
- const iota = std.simd.iota(u8, 16);
- const default = @splat(16, @as(u8, 16));
- const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default));
- switch (sub_index) {
- 0...12 => {
- index += sub_index + 4;
- if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) {
- r.state = .finished;
- return index;
- }
- continue;
- },
- 13 => {
- index += 16;
- if (int16(chunk[14..][0..2]) == int16("\n\r")) {
- r.state = .seen_rnr;
- continue :state;
- }
- continue;
- },
- 14 => {
- index += 16;
- if (chunk[15] == '\n') {
- r.state = .seen_rn;
- continue :state;
- }
- continue;
- },
- 15 => {
- r.state = .seen_r;
- index += 16;
- continue :state;
- },
- 16 => {
- index += 16;
- continue;
- },
- else => unreachable,
- }
- },
- }
- },
-
- .seen_r => switch (bytes.len - index) {
- 0 => return index,
- 1 => {
- switch (bytes[index]) {
- '\n' => r.state = .seen_rn,
- '\r' => r.state = .seen_r,
- else => r.state = .start,
- }
- return index + 1;
- },
- 2 => {
- if (int16(bytes[index..][0..2]) == int16("\n\r")) {
- r.state = .seen_rnr;
- return index + 2;
- }
- r.state = .start;
- return index + 2;
- },
- else => {
- if (int16(bytes[index..][0..2]) == int16("\n\r") and
- bytes[index + 2] == '\n')
- {
- r.state = .finished;
- return index + 3;
- }
- index += 3;
- r.state = .start;
- continue :state;
- },
- },
- .seen_rn => switch (bytes.len - index) {
- 0 => return index,
- 1 => {
- switch (bytes[index]) {
- '\r' => r.state = .seen_rnr,
- else => r.state = .start,
- }
- return index + 1;
- },
- else => {
- if (int16(bytes[index..][0..2]) == int16("\r\n")) {
- r.state = .finished;
- return index + 2;
- }
- index += 2;
- r.state = .start;
- continue :state;
- },
- },
- .seen_rnr => switch (bytes.len - index) {
- 0 => return index,
- else => {
- if (bytes[index] == '\n') {
- r.state = .finished;
- return index + 1;
- }
- index += 1;
- r.state = .start;
- continue :state;
- },
- },
- .chunk_size_prefix_r => unreachable,
- .chunk_size_prefix_n => unreachable,
- .chunk_size => unreachable,
- .chunk_r => unreachable,
- .chunk_data => unreachable,
- }
-
- return index;
- }
-}
-
-pub fn findChunkedLen(r: *Response, bytes: []const u8) usize {
- var i: usize = 0;
- if (r.state == .chunk_size) {
- while (i < bytes.len) : (i += 1) {
- const digit = switch (bytes[i]) {
- '0'...'9' => |b| b - '0',
- 'A'...'Z' => |b| b - 'A' + 10,
- 'a'...'z' => |b| b - 'a' + 10,
- '\r' => {
- r.state = .chunk_r;
- i += 1;
- break;
- },
- else => {
- r.state = .invalid;
- return i;
- },
- };
- const mul = @mulWithOverflow(r.next_chunk_length, 16);
- if (mul[1] != 0) {
- r.state = .invalid;
- return i;
- }
- const add = @addWithOverflow(mul[0], digit);
- if (add[1] != 0) {
- r.state = .invalid;
- return i;
- }
- r.next_chunk_length = add[0];
- } else {
- return i;
- }
- }
- assert(r.state == .chunk_r);
- if (i == bytes.len) return i;
-
- if (bytes[i] == '\n') {
- r.state = .chunk_data;
- return i + 1;
- } else {
- r.state = .invalid;
- return i;
- }
-}
-
-fn parseInt3(nnn: @Vector(3, u8)) u10 {
- const zero: @Vector(3, u8) = .{ '0', '0', '0' };
- const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
- return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
-}
-
-test parseInt3 {
- const expectEqual = std.testing.expectEqual;
- try expectEqual(@as(u10, 0), parseInt3("000".*));
- try expectEqual(@as(u10, 418), parseInt3("418".*));
- try expectEqual(@as(u10, 999), parseInt3("999".*));
-}
-
-test "find headers end basic" {
- var buffer: [1]u8 = undefined;
- var r = Response.initStatic(&buffer);
- try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4"));
- try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18"));
- try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah"));
-}
-
-test "find headers end vectorized" {
- var buffer: [1]u8 = undefined;
- var r = Response.initStatic(&buffer);
- const example =
- "HTTP/1.1 301 Moved Permanently\r\n" ++
- "Location: https://www.example.com/\r\n" ++
- "Content-Type: text/html; charset=UTF-8\r\n" ++
- "Content-Length: 220\r\n" ++
- "\r\ncontent";
- try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example));
-}
-
-test "find headers end bug" {
- var buffer: [1]u8 = undefined;
- var r = Response.initStatic(&buffer);
- const trail = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
- const example =
- "HTTP/1.1 200 OK\r\n" ++
- "Access-Control-Allow-Origin: https://render.githubusercontent.com\r\n" ++
- "content-disposition: attachment; filename=zig-0.10.0.tar.gz\r\n" ++
- "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; sandbox\r\n" ++
- "Content-Type: application/x-gzip\r\n" ++
- "ETag: \"bfae0af6b01c7c0d89eb667cb5f0e65265968aeebda2689177e6b26acd3155ca\"\r\n" ++
- "Strict-Transport-Security: max-age=31536000\r\n" ++
- "Vary: Authorization,Accept-Encoding,Origin\r\n" ++
- "X-Content-Type-Options: nosniff\r\n" ++
- "X-Frame-Options: deny\r\n" ++
- "X-XSS-Protection: 1; mode=block\r\n" ++
- "Date: Fri, 06 Jan 2023 22:26:22 GMT\r\n" ++
- "Transfer-Encoding: chunked\r\n" ++
- "X-GitHub-Request-Id: 89C6:17E9:A7C9E:124B51:63B8A00E\r\n" ++
- "connection: close\r\n\r\n" ++ trail;
- try testing.expectEqual(@as(usize, example.len - trail.len), r.findHeadersEnd(example));
-}
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig
new file mode 100644
index 0000000000..85fbf25265
--- /dev/null
+++ b/lib/std/http/Server.zig
@@ -0,0 +1,600 @@
+const std = @import("../std.zig");
+const testing = std.testing;
+const http = std.http;
+const mem = std.mem;
+const net = std.net;
+const Uri = std.Uri;
+const Allocator = mem.Allocator;
+const assert = std.debug.assert;
+
+const Server = @This();
+const proto = @import("protocol.zig");
+
+allocator: Allocator,
+
+socket: net.StreamServer,
+
+/// An interface to either a plain or TLS connection. Only plain connections are currently supported.
+pub const Connection = struct {
+ stream: net.Stream,
+ protocol: Protocol,
+
+ closing: bool = true,
+
+ pub const Protocol = enum { plain };
+
+ pub fn read(conn: *Connection, buffer: []u8) !usize {
+ switch (conn.protocol) {
+ .plain => return conn.stream.read(buffer),
+ // .tls => return conn.tls_client.read(conn.stream, buffer),
+ }
+ }
+
+ pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
+ switch (conn.protocol) {
+ .plain => return conn.stream.readAtLeast(buffer, len),
+ // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
+ }
+ }
+
+ pub const ReadError = net.Stream.ReadError;
+
+ pub const Reader = std.io.Reader(*Connection, ReadError, read);
+
+ pub fn reader(conn: *Connection) Reader {
+ return Reader{ .context = conn };
+ }
+
+ pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
+ switch (conn.protocol) {
+ .plain => return conn.stream.writeAll(buffer),
+ // .tls => return conn.tls_client.writeAll(conn.stream, buffer),
+ }
+ }
+
+ pub fn write(conn: *Connection, buffer: []const u8) !usize {
+ switch (conn.protocol) {
+ .plain => return conn.stream.write(buffer),
+ // .tls => return conn.tls_client.write(conn.stream, buffer),
+ }
+ }
+
+ pub const WriteError = net.Stream.WriteError || error{};
+ pub const Writer = std.io.Writer(*Connection, WriteError, write);
+
+ pub fn writer(conn: *Connection) Writer {
+ return Writer{ .context = conn };
+ }
+
+ pub fn close(conn: *Connection) void {
+ conn.stream.close();
+ }
+};
+
+/// A buffered (and peekable) Connection.
+pub const BufferedConnection = struct {
+ pub const buffer_size = 0x2000;
+
+ conn: Connection,
+ buf: [buffer_size]u8 = undefined,
+ start: u16 = 0,
+ end: u16 = 0,
+
+ pub fn fill(bconn: *BufferedConnection) ReadError!void {
+ if (bconn.end != bconn.start) return;
+
+ const nread = try bconn.conn.read(bconn.buf[0..]);
+ if (nread == 0) return error.EndOfStream;
+ bconn.start = 0;
+ bconn.end = @truncate(u16, nread);
+ }
+
+ pub fn peek(bconn: *BufferedConnection) []const u8 {
+ return bconn.buf[bconn.start..bconn.end];
+ }
+
+ pub fn clear(bconn: *BufferedConnection, num: u16) void {
+ bconn.start += num;
+ }
+
+ pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
+ var out_index: u16 = 0;
+ while (out_index < len) {
+ const available = bconn.end - bconn.start;
+ const left = buffer.len - out_index;
+
+ if (available > 0) {
+ const can_read = @truncate(u16, @min(available, left));
+
+ std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
+ out_index += can_read;
+ bconn.start += can_read;
+
+ continue;
+ }
+
+ if (left > bconn.buf.len) {
+ // skip the buffer if the output is large enough
+ return bconn.conn.read(buffer[out_index..]);
+ }
+
+ try bconn.fill();
+ }
+
+ return out_index;
+ }
+
+ pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
+ return bconn.readAtLeast(buffer, 1);
+ }
+
+ pub const ReadError = Connection.ReadError || error{EndOfStream};
+ pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
+
+ pub fn reader(bconn: *BufferedConnection) Reader {
+ return Reader{ .context = bconn };
+ }
+
+ pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
+ return bconn.conn.writeAll(buffer);
+ }
+
+ pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
+ return bconn.conn.write(buffer);
+ }
+
+ pub const WriteError = Connection.WriteError;
+ pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
+
+ pub fn writer(bconn: *BufferedConnection) Writer {
+ return Writer{ .context = bconn };
+ }
+
+ pub fn close(bconn: *BufferedConnection) void {
+ bconn.conn.close();
+ }
+};
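+
+// A sketch of the fill/peek/clear pattern used throughout this file (`parser`
+// and `allocator` here are stand-ins supplied by the caller): fill() blocks
+// until at least one byte is buffered, peek() exposes the buffered bytes
+// without consuming them, and clear(n) marks n bytes as consumed.
+//
+//     try bconn.fill();
+//     const nchecked = try parser.checkCompleteHead(allocator, bconn.peek());
+//     bconn.clear(@intCast(u16, nchecked));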
+
+/// An HTTP request originating from a client.
+pub const Request = struct {
+ pub const Headers = struct {
+ method: http.Method,
+ target: []const u8,
+ version: http.Version,
+ content_length: ?u64 = null,
+ transfer_encoding: ?http.TransferEncoding = null,
+ transfer_compression: ?http.ContentEncoding = null,
+ connection: http.Connection = .close,
+ host: ?[]const u8 = null,
+
+ pub const ParseError = error{
+ ShortHttpStatusLine,
+ BadHttpVersion,
+ UnknownHttpMethod,
+ HttpHeadersInvalid,
+ HttpHeaderContinuationsUnsupported,
+ HttpTransferEncodingUnsupported,
+ HttpConnectionHeaderUnsupported,
+ InvalidCharacter,
+ };
+
+ pub fn parse(bytes: []const u8) !Headers {
+ var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
+
+ const first_line = it.next() orelse return error.HttpHeadersInvalid;
+ if (first_line.len < 10)
+ return error.ShortHttpStatusLine;
+
+ const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
+ const method_str = first_line[0..method_end];
+ const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod;
+
+ const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid;
+ if (version_start == method_end) return error.HttpHeadersInvalid;
+
+ const version_str = first_line[version_start + 1 ..];
+ if (version_str.len != 8) return error.HttpHeadersInvalid;
+ const version: http.Version = switch (int64(version_str[0..8])) {
+ int64("HTTP/1.0") => .@"HTTP/1.0",
+ int64("HTTP/1.1") => .@"HTTP/1.1",
+ else => return error.BadHttpVersion,
+ };
+
+ const target = first_line[method_end + 1 .. version_start];
+
+ var headers: Headers = .{
+ .method = method,
+ .target = target,
+ .version = version,
+ };
+
+ while (it.next()) |line| {
+ if (line.len == 0) return error.HttpHeadersInvalid;
+ switch (line[0]) {
+ ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+ else => {},
+ }
+
+ var line_it = mem.tokenize(u8, line, ": ");
+ const header_name = line_it.next() orelse return error.HttpHeadersInvalid;
+ const header_value = line_it.rest();
+ if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+ if (headers.content_length != null) return error.HttpHeadersInvalid;
+ headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
+ } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+ // The Transfer-Encoding header lists encodings in the order they were
+ // applied, e.g. "Transfer-Encoding: deflate, chunked", so it is parsed
+ // backwards: chunked must come last, optionally preceded by a compression.
+ var iter = mem.splitBackwards(u8, header_value, ",");
+
+ if (iter.next()) |first| {
+ const trimmed = mem.trim(u8, first, " ");
+
+ if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| {
+ if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
+ headers.transfer_encoding = te;
+ } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |second| {
+ if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
+
+ const trimmed = mem.trim(u8, second, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ }
+
+ if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+
+ const trimmed = mem.trim(u8, header_value, " ");
+
+ if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
+ headers.transfer_compression = ce;
+ } else {
+ return error.HttpTransferEncodingUnsupported;
+ }
+ } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+ if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
+ headers.connection = .keep_alive;
+ } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
+ headers.connection = .close;
+ } else {
+ return error.HttpConnectionHeaderUnsupported;
+ }
+ } else if (std.ascii.eqlIgnoreCase(header_name, "host")) {
+ headers.host = header_value;
+ }
+ }
+
+ return headers;
+ }
+
+ inline fn int64(array: *const [8]u8) u64 {
+ return @bitCast(u64, array.*);
+ }
+ };
+
+ headers: Headers = undefined,
+ parser: proto.HeadersParser,
+ compression: Compression = .none,
+};
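+
+// An illustrative sketch of Request.Headers.parse; the request text below is
+// an example chosen for this test, not something mandated by the API.
+test "parse request headers" {
+ const example =
+ "GET /hello HTTP/1.1\r\n" ++
+ "Host: localhost\r\n" ++
+ "Content-Length: 5\r\n\r\n";
+ const parsed = try Request.Headers.parse(example);
+ try testing.expectEqual(http.Method.GET, parsed.method);
+ try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
+ try testing.expectEqualStrings("/hello", parsed.target);
+ try testing.expectEqual(@as(?u64, 5), parsed.content_length);
+ try testing.expectEqualStrings("localhost", parsed.host orelse return error.TestFailed);
+}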
+
+/// An HTTP response waiting to be sent.
+///
+///                                [/ <--------------------------------------- \]
+/// Order of operations: accept -> wait -> do [ -> write -> finish][ -> reset /]
+///                                            \ -> read /
+pub const Response = struct {
+ pub const Headers = struct {
+ version: http.Version = .@"HTTP/1.1",
+ status: http.Status = .ok,
+ reason: ?[]const u8 = null,
+
+ server: ?[]const u8 = "zig (std.http)",
+ connection: http.Connection = .keep_alive,
+ transfer_encoding: RequestTransfer = .none,
+
+ custom: []const http.CustomHeader = &[_]http.CustomHeader{},
+ };
+
+ server: *Server,
+ address: net.Address,
+ connection: BufferedConnection,
+
+ headers: Headers = .{},
+ request: Request,
+
+ /// Reset this response to its initial state. This must be called before handling a second request on the same connection.
+ pub fn reset(res: *Response) void {
+ switch (res.request.compression) {
+ .none => {},
+ .deflate => |*deflate| deflate.deinit(),
+ .gzip => |*gzip| gzip.deinit(),
+ .zstd => |*zstd| zstd.deinit(),
+ }
+
+ if (!res.request.parser.done) {
+ // If the request wasn't fully read, then we need to close the connection.
+ res.connection.conn.closing = true;
+ }
+
+ if (res.connection.conn.closing) {
+ res.connection.close();
+
+ if (res.request.parser.header_bytes_owned) {
+ res.request.parser.header_bytes.deinit(res.server.allocator);
+ }
+
+ res.* = undefined;
+ } else {
+ res.request.parser.reset();
+ }
+ }
+
+ /// Send the response headers.
+ pub fn do(res: *Response) !void {
+ var buffered = std.io.bufferedWriter(res.connection.writer());
+ const w = buffered.writer();
+
+ try w.writeAll(@tagName(res.headers.version));
+ try w.writeByte(' ');
+ try w.print("{d}", .{@enumToInt(res.headers.status)});
+ try w.writeByte(' ');
+ if (res.headers.reason) |reason| {
+ try w.writeAll(reason);
+ } else if (res.headers.status.phrase()) |phrase| {
+ try w.writeAll(phrase);
+ }
+
+ if (res.headers.server) |server| {
+ try w.writeAll("\r\nServer: ");
+ try w.writeAll(server);
+ }
+
+ if (res.headers.connection == .close) {
+ try w.writeAll("\r\nConnection: close");
+ } else {
+ try w.writeAll("\r\nConnection: keep-alive");
+ }
+
+ switch (res.headers.transfer_encoding) {
+ .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"),
+ .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}),
+ .none => {},
+ }
+
+ for (res.headers.custom) |header| {
+ try w.writeAll("\r\n");
+ try w.writeAll(header.name);
+ try w.writeAll(": ");
+ try w.writeAll(header.value);
+ }
+
+ try w.writeAll("\r\n\r\n");
+
+ try buffered.flush();
+ }
+
+ pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
+
+ pub const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead);
+
+ pub fn transferReader(res: *Response) TransferReader {
+ return .{ .context = res };
+ }
+
+ pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
+ if (res.request.parser.isComplete()) return 0;
+
+ var index: usize = 0;
+ while (index == 0) {
+ const amt = try res.request.parser.read(&res.connection, buf[index..], false);
+ if (amt == 0 and res.request.parser.isComplete()) break;
+ index += amt;
+ }
+
+ return index;
+ }
+
+ pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
+
+ /// Wait for the client to send a complete request head.
+ pub fn wait(res: *Response) !void {
+ while (true) {
+ try res.connection.fill();
+
+ const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
+ res.connection.clear(@intCast(u16, nchecked));
+
+ if (res.request.parser.state.isContent()) break;
+ }
+
+ res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items);
+
+ if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) {
+ res.connection.conn.closing = false;
+ } else {
+ res.connection.conn.closing = true;
+ }
+
+ if (res.request.headers.transfer_encoding) |te| {
+ switch (te) {
+ .chunked => {
+ res.request.parser.next_chunk_length = 0;
+ res.request.parser.state = .chunk_head_size;
+ },
+ }
+ } else if (res.request.headers.content_length) |cl| {
+ res.request.parser.next_chunk_length = cl;
+
+ if (cl == 0) res.request.parser.done = true;
+ } else {
+ res.request.parser.done = true;
+ }
+
+ if (!res.request.parser.done) {
+ if (res.request.headers.transfer_compression) |tc| switch (tc) {
+ .compress => return error.CompressionNotSupported,
+ .deflate => res.request.compression = .{
+ .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()),
+ },
+ .gzip => res.request.compression = .{
+ .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()),
+ },
+ .zstd => res.request.compression = .{
+ .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()),
+ },
+ };
+ }
+ }
+
+ pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError;
+
+ pub const Reader = std.io.Reader(*Response, ReadError, read);
+
+ pub fn reader(res: *Response) Reader {
+ return .{ .context = res };
+ }
+
+ pub fn read(res: *Response, buffer: []u8) ReadError!usize {
+ return switch (res.request.compression) {
+ .deflate => |*deflate| try deflate.read(buffer),
+ .gzip => |*gzip| try gzip.read(buffer),
+ .zstd => |*zstd| try zstd.read(buffer),
+ else => try res.transferRead(buffer),
+ };
+ }
+
+ pub fn readAll(res: *Response, buffer: []u8) !usize {
+ var index: usize = 0;
+ while (index < buffer.len) {
+ const amt = try read(res, buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
+ }
+ return index;
+ }
+
+ pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
+
+ pub const Writer = std.io.Writer(*Response, WriteError, write);
+
+ pub fn writer(res: *Response) Writer {
+ return .{ .context = res };
+ }
+
+ /// Write `bytes` to the client. The `transfer_encoding` response header determines how the data will be sent.
+ pub fn write(res: *Response, bytes: []const u8) WriteError!usize {
+ switch (res.headers.transfer_encoding) {
+ .chunked => {
+ // A zero-length chunk terminates the body, so skip empty writes
+ // and let finish() send the terminating chunk.
+ if (bytes.len > 0) {
+ try res.connection.writer().print("{x}\r\n", .{bytes.len});
+ try res.connection.writeAll(bytes);
+ try res.connection.writeAll("\r\n");
+ }
+
+ return bytes.len;
+ },
+ .content_length => |*len| {
+ if (len.* < bytes.len) return error.MessageTooLong;
+
+ const amt = try res.connection.write(bytes);
+ len.* -= amt;
+ return amt;
+ },
+ .none => return error.NotWriteable,
+ }
+ }
+
+ /// Finish the body of the response. This notifies the client that you have no more data to send.
+ pub fn finish(res: *Response) !void {
+ switch (res.headers.transfer_encoding) {
+ // The chunked body ends with a zero-length chunk and an empty trailer section.
+ .chunked => try res.connection.writeAll("0\r\n\r\n"),
+ .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+ .none => {},
+ }
+ }
+};
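+
+// A short sketch of sending a fixed-length body (the values are illustrative):
+// the transfer encoding must be chosen before do() so that the matching
+// Content-Length or Transfer-Encoding header is written.
+//
+//     res.headers.transfer_encoding = .{ .content_length = 5 };
+//     try res.do();
+//     try res.writer().writeAll("Hello");
+//     try res.finish();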
+
+/// The mode of transport for responses.
+pub const RequestTransfer = union(enum) {
+ content_length: u64,
+ chunked: void,
+ none: void,
+};
+
+/// The decompressor for request messages.
+pub const Compression = union(enum) {
+ pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
+ pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
+ pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
+
+ deflate: DeflateDecompressor,
+ gzip: GzipDecompressor,
+ zstd: ZstdDecompressor,
+ none: void,
+};
+
+pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server {
+ return .{
+ .allocator = allocator,
+ .socket = net.StreamServer.init(options),
+ };
+}
+
+pub fn deinit(server: *Server) void {
+ server.socket.deinit();
+}
+
+pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError;
+
+/// Start the HTTP server listening on the given address.
+pub fn listen(server: *Server, address: net.Address) !void {
+ try server.socket.listen(address);
+}
+
+pub const AcceptError = net.StreamServer.AcceptError || Allocator.Error;
+
+pub const HeaderStrategy = union(enum) {
+ /// In this case, the server's allocator will be used to store the
+ /// entire HTTP header. This value is the maximum total size of
+ /// HTTP headers allowed; if it is exceeded,
+ /// error.HttpHeadersExceededSizeLimit is returned from wait().
+ dynamic: usize,
+ /// This buffer is used to store the entire HTTP header. If the HTTP
+ /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
+ /// is returned from wait(). When this is used, `error.OutOfMemory`
+ /// cannot be returned from `wait()`.
+ static: []u8,
+};
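+
+// For example (an illustrative sketch): server.accept(.{ .dynamic = 8192 })
+// stores up to 8192 bytes of request head using the server's allocator, while
+// server.accept(.{ .static = &buf }) parses the head into a caller-provided
+// buffer instead.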
+
+/// Accept a new connection and allocate a Response for it.
+pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response {
+ const in = try server.socket.accept();
+
+ const res = try server.allocator.create(Response);
+ res.* = .{
+ .server = server,
+ .address = in.address,
+ .connection = .{ .conn = .{
+ .stream = in.stream,
+ .protocol = .plain,
+ } },
+ .request = .{
+ .parser = switch (options) {
+ .dynamic => |max| proto.HeadersParser.initDynamic(max),
+ .static => |buf| proto.HeadersParser.initStatic(buf),
+ },
+ },
+ };
+
+ return res;
+}
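+
+// A minimal end-to-end sketch of the accept -> wait -> do -> write -> finish
+// -> reset flow documented on Response. The allocator `gpa`, the address,
+// port, header size limit, and body are illustrative assumptions; error
+// handling is elided.
+//
+//     var server = Server.init(gpa, .{ .reuse_address = true });
+//     defer server.deinit();
+//     try server.listen(try net.Address.parseIp("127.0.0.1", 8080));
+//
+//     while (true) {
+//         const res = try server.accept(.{ .dynamic = 8192 });
+//         try res.wait(); // read the request head
+//
+//         res.headers.transfer_encoding = .{ .content_length = 2 };
+//         try res.do(); // send the response head
+//         try res.writer().writeAll("ok");
+//         try res.finish();
+//         res.reset();
+//         // With keep-alive, loop back to wait() on the same res instead.
+//     }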
diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig
new file mode 100644
index 0000000000..5e63d3092b
--- /dev/null
+++ b/lib/std/http/protocol.zig
@@ -0,0 +1,842 @@
+const std = @import("std");
+const testing = std.testing;
+const mem = std.mem;
+
+const assert = std.debug.assert;
+
+pub const State = enum {
+ /// Begin header parsing states.
+ invalid,
+ start,
+ seen_n,
+ seen_r,
+ seen_rn,
+ seen_rnr,
+ finished,
+ /// Begin transfer-encoding: chunked parsing states.
+ chunk_head_size,
+ chunk_head_ext,
+ chunk_head_r,
+ chunk_data,
+ chunk_data_suffix,
+ chunk_data_suffix_r,
+
+ /// Returns true if the parser is in a content state (i.e. not waiting for more headers).
+ pub fn isContent(self: State) bool {
+ return switch (self) {
+ .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false,
+ .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true,
+ };
+ }
+};
+
+pub const HeadersParser = struct {
+ state: State = .start,
+ /// Whether or not `header_bytes` is allocated or was provided as a fixed buffer.
+ header_bytes_owned: bool,
+ /// Either a fixed buffer of len `max_header_bytes` or a dynamic buffer that can grow up to `max_header_bytes`.
+ /// Pointers into this buffer are not stable until after a message is complete.
+ header_bytes: std.ArrayListUnmanaged(u8),
+ /// The maximum allowed size of `header_bytes`.
+ max_header_bytes: usize,
+ next_chunk_length: u64 = 0,
+ /// Whether this parser is done parsing a complete message.
+ /// A message is only done when the entire payload has been read.
+ done: bool = false,
+
+ /// Initializes the parser with a dynamically growing header buffer of up to `max` bytes.
+ pub fn initDynamic(max: usize) HeadersParser {
+ return .{
+ .header_bytes = .{},
+ .max_header_bytes = max,
+ .header_bytes_owned = true,
+ };
+ }
+
+ /// Initializes the parser with a provided buffer `buf`.
+ pub fn initStatic(buf: []u8) HeadersParser {
+ return .{
+ .header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
+ .max_header_bytes = buf.len,
+ .header_bytes_owned = false,
+ };
+ }
+
+ /// Completely resets the parser to its initial state.
+ /// This must be called after a message is complete.
+ pub fn reset(r: *HeadersParser) void {
+ assert(r.done); // The message must be completely read before reset, otherwise the parser is in an invalid state.
+
+ r.header_bytes.clearRetainingCapacity();
+
+ r.* = .{
+ .header_bytes = r.header_bytes,
+ .max_header_bytes = r.max_header_bytes,
+ .header_bytes_owned = r.header_bytes_owned,
+ };
+ }
+
+ /// Returns the number of bytes consumed by headers. This is always less than or equal to `bytes.len`.
+ /// You should check `r.state.isContent()` after this to check if the headers are done.
+ ///
+ /// If the amount returned is less than `bytes.len`, you may assume that the parser is in a content state and the
+ /// first byte of content is located at `bytes[result]`.
+ pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 {
+ const vector_len: comptime_int = comptime std.math.max(std.simd.suggestVectorSize(u8) orelse 1, 8);
+ const len = @intCast(u32, bytes.len);
+ var index: u32 = 0;
+
+ while (true) {
+ switch (r.state) {
+ .invalid => unreachable,
+ .finished => return index,
+ .start => switch (len - index) {
+ 0 => return index,
+ 1 => {
+ switch (bytes[index]) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ return index + 1;
+ },
+ 2 => {
+ const b16 = int16(bytes[index..][0..2]);
+ const b8 = intShift(u8, b16);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ return index + 2;
+ },
+ 3 => {
+ const b24 = int24(bytes[index..][0..3]);
+ const b16 = intShift(u16, b24);
+ const b8 = intShift(u8, b24);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ switch (b24) {
+ int24("\r\n\r") => r.state = .seen_rnr,
+ else => {},
+ }
+
+ return index + 3;
+ },
+ 4...vector_len - 1 => {
+ const b32 = int32(bytes[index..][0..4]);
+ const b24 = intShift(u24, b32);
+ const b16 = intShift(u16, b32);
+ const b8 = intShift(u8, b32);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ switch (b24) {
+ int24("\r\n\r") => r.state = .seen_rnr,
+ else => {},
+ }
+
+ switch (b32) {
+ int32("\r\n\r\n") => r.state = .finished,
+ else => {},
+ }
+
+ index += 4;
+ continue;
+ },
+ else => {
+ const Vector = @Vector(vector_len, u8);
+ const BitVector = @Vector(vector_len, u1);
+ const SizeVector = @Vector(vector_len, u8);
+
+ const chunk = bytes[index..][0..vector_len];
+ const v: Vector = chunk.*;
+ const matches_r = @bitCast(BitVector, v == @splat(vector_len, @as(u8, '\r')));
+ const matches_n = @bitCast(BitVector, v == @splat(vector_len, @as(u8, '\n')));
+ const matches_or: SizeVector = matches_r | matches_n;
+
+ const matches = @reduce(.Add, matches_or);
+ switch (matches) {
+ 0 => {},
+ 1 => switch (chunk[vector_len - 1]) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ },
+ 2 => {
+ const b16 = int16(chunk[vector_len - 2 ..][0..2]);
+ const b8 = intShift(u8, b16);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+ },
+ 3 => {
+ const b24 = int24(chunk[vector_len - 3 ..][0..3]);
+ const b16 = intShift(u16, b24);
+ const b8 = intShift(u8, b24);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ switch (b24) {
+ int24("\r\n\r") => r.state = .seen_rnr,
+ else => {},
+ }
+ },
+ 4...vector_len => {
+ inline for (0..vector_len - 3) |i_usize| {
+ const i = @truncate(u32, i_usize);
+
+ const b32 = int32(chunk[i..][0..4]);
+ const b16 = intShift(u16, b32);
+
+ if (b32 == int32("\r\n\r\n")) {
+ r.state = .finished;
+ return index + i + 4;
+ } else if (b16 == int16("\n\n")) {
+ r.state = .finished;
+ return index + i + 2;
+ }
+ }
+
+ const b24 = int24(chunk[vector_len - 3 ..][0..3]);
+ const b16 = intShift(u16, b24);
+ const b8 = intShift(u8, b24);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => {},
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ switch (b24) {
+ int24("\r\n\r") => r.state = .seen_rnr,
+ else => {},
+ }
+ },
+ else => unreachable,
+ }
+
+ index += vector_len;
+ continue;
+ },
+ },
+ .seen_n => switch (len - index) {
+ 0 => return index,
+ else => {
+ switch (bytes[index]) {
+ '\n' => r.state = .finished,
+ else => r.state = .start,
+ }
+
+ index += 1;
+ continue;
+ },
+ },
+ .seen_r => switch (len - index) {
+ 0 => return index,
+ 1 => {
+ switch (bytes[index]) {
+ '\n' => r.state = .seen_rn,
+ '\r' => r.state = .seen_r,
+ else => r.state = .start,
+ }
+
+ return index + 1;
+ },
+ 2 => {
+ const b16 = int16(bytes[index..][0..2]);
+ const b8 = intShift(u8, b16);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => r.state = .start,
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\r") => r.state = .seen_rnr,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ return index + 2;
+ },
+ else => {
+ const b24 = int24(bytes[index..][0..3]);
+ const b16 = intShift(u16, b24);
+ const b8 = intShift(u8, b24);
+
+ switch (b8) {
+ '\r' => r.state = .seen_r,
+ '\n' => r.state = .seen_n,
+ else => r.state = .start,
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .seen_rn,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ switch (b24) {
+ int24("\n\r\n") => r.state = .finished,
+ else => {},
+ }
+
+ index += 3;
+ continue;
+ },
+ },
+ .seen_rn => switch (len - index) {
+ 0 => return index,
+ 1 => {
+ switch (bytes[index]) {
+ '\r' => r.state = .seen_rnr,
+ '\n' => r.state = .seen_n,
+ else => r.state = .start,
+ }
+
+ return index + 1;
+ },
+ else => {
+ const b16 = int16(bytes[index..][0..2]);
+ const b8 = intShift(u8, b16);
+
+ switch (b8) {
+ '\r' => r.state = .seen_rnr,
+ '\n' => r.state = .seen_n,
+ else => r.state = .start,
+ }
+
+ switch (b16) {
+ int16("\r\n") => r.state = .finished,
+ int16("\n\n") => r.state = .finished,
+ else => {},
+ }
+
+ index += 2;
+ continue;
+ },
+ },
+ .seen_rnr => switch (len - index) {
+ 0 => return index,
+ else => {
+ switch (bytes[index]) {
+ '\n' => r.state = .finished,
+ else => r.state = .start,
+ }
+
+ index += 1;
+ continue;
+ },
+ },
+ .chunk_head_size => unreachable,
+ .chunk_head_ext => unreachable,
+ .chunk_head_r => unreachable,
+ .chunk_data => unreachable,
+ .chunk_data_suffix => unreachable,
+ .chunk_data_suffix_r => unreachable,
+ }
+
+ return index;
+ }
+ }
+
+ /// Returns the number of bytes consumed by the chunk size. This is always less than or equal to `bytes.len`.
+ /// You should check `r.state == .chunk_data` after this to check if the chunk size has been fully parsed.
+ ///
+ /// If the amount returned is less than `bytes.len`, you may assume that the parser is in the `chunk_data` state
+ /// and that the first byte of the chunk is at `bytes[result]`.
+ pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 {
+ const len = @intCast(u32, bytes.len);
+
+ for (bytes[0..], 0..) |c, i| {
+ const index = @intCast(u32, i);
+ switch (r.state) {
+ .chunk_data_suffix => switch (c) {
+ '\r' => r.state = .chunk_data_suffix_r,
+ '\n' => r.state = .chunk_head_size,
+ else => {
+ r.state = .invalid;
+ return index;
+ },
+ },
+ .chunk_data_suffix_r => switch (c) {
+ '\n' => r.state = .chunk_head_size,
+ else => {
+ r.state = .invalid;
+ return index;
+ },
+ },
+ .chunk_head_size => {
+ const digit = switch (c) {
+ '0'...'9' => |b| b - '0',
+ 'A'...'Z' => |b| b - 'A' + 10,
+ 'a'...'z' => |b| b - 'a' + 10,
+ '\r' => {
+ r.state = .chunk_head_r;
+ continue;
+ },
+ '\n' => {
+ r.state = .chunk_data;
+ return index + 1;
+ },
+ else => {
+ r.state = .chunk_head_ext;
+ continue;
+ },
+ };
+
+ const new_len = r.next_chunk_length *% 16 +% digit;
+ if (new_len <= r.next_chunk_length and r.next_chunk_length != 0) {
+ r.state = .invalid;
+ return index;
+ }
+
+ r.next_chunk_length = new_len;
+ },
+ .chunk_head_ext => switch (c) {
+ '\r' => r.state = .chunk_head_r,
+ '\n' => {
+ r.state = .chunk_data;
+ return index + 1;
+ },
+ else => continue,
+ },
+ .chunk_head_r => switch (c) {
+ '\n' => {
+ r.state = .chunk_data;
+ return index + 1;
+ },
+ else => {
+ r.state = .invalid;
+ return index;
+ },
+ },
+ else => unreachable,
+ }
+ }
+
+ return len;
+ }
+
+ /// Returns whether or not the parser has finished parsing a complete message. A message is only complete after the
+ /// entire body has been read and any trailing headers have been parsed.
+ pub fn isComplete(r: *HeadersParser) bool {
+ return r.done and r.state == .finished;
+ }
+
+ pub const CheckCompleteHeadError = mem.Allocator.Error || error{HttpHeadersExceededSizeLimit};
+
+ /// Pushes `in` into the parser. Returns the number of bytes consumed by the header. Any header bytes are appended
+ /// to the `header_bytes` buffer.
+ ///
+ /// `allocator` is only used when `r.header_bytes_owned` is true; otherwise it may be `undefined`.
+ pub fn checkCompleteHead(r: *HeadersParser, allocator: std.mem.Allocator, in: []const u8) CheckCompleteHeadError!u32 {
+ if (r.state.isContent()) return 0;
+
+ const i = r.findHeadersEnd(in);
+ const data = in[0..i];
+ if (r.header_bytes.items.len + data.len > r.max_header_bytes) {
+ return error.HttpHeadersExceededSizeLimit;
+ } else {
+ if (r.header_bytes_owned) try r.header_bytes.ensureUnusedCapacity(allocator, data.len);
+
+ r.header_bytes.appendSliceAssumeCapacity(data);
+ }
+
+ return i;
+ }
+
+ pub const ReadError = error{
+ HttpChunkInvalid,
+ };
+
+ /// Reads the body of the message into `buffer`. Returns the number of bytes placed in the buffer.
+ ///
+ /// If `skip` is true, the buffer will be unused and the body will be skipped.
+ ///
+ /// See `std.http.Client.BufferedConnection` for an example of `bconn`.
+ pub fn read(r: *HeadersParser, bconn: anytype, buffer: []u8, skip: bool) !usize {
+ assert(r.state.isContent());
+ if (r.done) return 0;
+
+ var out_index: usize = 0;
+ while (true) {
+ switch (r.state) {
+ .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable,
+ .finished => {
+ const data_avail = r.next_chunk_length;
+
+ if (skip) {
+ try bconn.fill();
+
+ const nread = @min(bconn.peek().len, data_avail);
+ bconn.clear(@intCast(u16, nread));
+ r.next_chunk_length -= nread;
+
+ if (r.next_chunk_length == 0) r.done = true;
+
+ return 0;
+ } else {
+ const out_avail = buffer.len;
+
+ const can_read = @intCast(usize, @min(data_avail, out_avail));
+ const nread = try bconn.read(buffer[0..can_read]);
+ r.next_chunk_length -= nread;
+
+ if (r.next_chunk_length == 0) r.done = true;
+
+ return nread;
+ }
+ },
+ .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
+ try bconn.fill();
+
+ const i = r.findChunkedLen(bconn.peek());
+ bconn.clear(@intCast(u16, i));
+
+ switch (r.state) {
+ .invalid => return error.HttpChunkInvalid,
+ .chunk_data => if (r.next_chunk_length == 0) {
+ // The trailer section is formatted identically to the header section.
+ r.state = .seen_rn;
+ r.done = true;
+
+ return out_index;
+ },
+ else => return out_index,
+ }
+
+ continue;
+ },
+ .chunk_data => {
+ const data_avail = r.next_chunk_length;
+ const out_avail = buffer.len - out_index;
+
+ if (skip) {
+ try bconn.fill();
+
+ const nread = @min(bconn.peek().len, data_avail);
+ bconn.clear(@intCast(u16, nread));
+ r.next_chunk_length -= nread;
+ } else {
+ const can_read = @intCast(usize, @min(data_avail, out_avail));
+ const nread = try bconn.read(buffer[out_index..][0..can_read]);
+ r.next_chunk_length -= nread;
+ out_index += nread;
+ }
+
+ if (r.next_chunk_length == 0) {
+ r.state = .chunk_data_suffix;
+ continue;
+ }
+
+ return out_index;
+ },
+ }
+ }
+ }
+};
+
+inline fn int16(array: *const [2]u8) u16 {
+ return @bitCast(u16, array.*);
+}
+
+inline fn int24(array: *const [3]u8) u24 {
+ return @bitCast(u24, array.*);
+}
+
+inline fn int32(array: *const [4]u8) u32 {
+ return @bitCast(u32, array.*);
+}
+
+inline fn intShift(comptime T: type, x: anytype) T {
+ switch (@import("builtin").cpu.arch.endian()) {
+ .Little => return @truncate(T, x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))),
+ .Big => return @truncate(T, x),
+ }
+}
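+
+// A small sanity check for the helpers above (added as an illustrative
+// sketch): on either endianness, intShift extracts the trailing bytes of a
+// value in memory order.
+test intShift {
+ try testing.expectEqual(@as(u8, '\n'), intShift(u8, int16("\r\n")));
+ try testing.expectEqual(int16("\r\n"), intShift(u16, int24("x\r\n")));
+}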
+
+/// A buffered (and peekable) Connection.
+const MockBufferedConnection = struct {
+ pub const buffer_size = 0x2000;
+
+ conn: std.io.FixedBufferStream([]const u8),
+ buf: [buffer_size]u8 = undefined,
+ start: u16 = 0,
+ end: u16 = 0,
+
+ pub fn fill(bconn: *MockBufferedConnection) ReadError!void {
+ if (bconn.end != bconn.start) return;
+
+ const nread = try bconn.conn.read(bconn.buf[0..]);
+ if (nread == 0) return error.EndOfStream;
+ bconn.start = 0;
+ bconn.end = @truncate(u16, nread);
+ }
+
+ pub fn peek(bconn: *MockBufferedConnection) []const u8 {
+ return bconn.buf[bconn.start..bconn.end];
+ }
+
+ pub fn clear(bconn: *MockBufferedConnection, num: u16) void {
+ bconn.start += num;
+ }
+
+ pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
+ var out_index: u16 = 0;
+ while (out_index < len) {
+ const available = bconn.end - bconn.start;
+ const left = buffer.len - out_index;
+
+ if (available > 0) {
+ const can_read = @truncate(u16, @min(available, left));
+
+ std.mem.copy(u8, buffer[out_index..], bconn.buf[bconn.start..][0..can_read]);
+ out_index += can_read;
+ bconn.start += can_read;
+
+ continue;
+ }
+
+ if (left > bconn.buf.len) {
+ // skip the buffer if the output is large enough
+ return bconn.conn.read(buffer[out_index..]);
+ }
+
+ try bconn.fill();
+ }
+
+ return out_index;
+ }
+
+ pub fn read(bconn: *MockBufferedConnection, buffer: []u8) ReadError!usize {
+ return bconn.readAtLeast(buffer, 1);
+ }
+
+ pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream};
+ pub const Reader = std.io.Reader(*MockBufferedConnection, ReadError, read);
+
+ pub fn reader(bconn: *MockBufferedConnection) Reader {
+ return Reader{ .context = bconn };
+ }
+
+ pub fn writeAll(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!void {
+ return bconn.conn.writeAll(buffer);
+ }
+
+ pub fn write(bconn: *MockBufferedConnection, buffer: []const u8) WriteError!usize {
+ return bconn.conn.write(buffer);
+ }
+
+ pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError;
+ pub const Writer = std.io.Writer(*MockBufferedConnection, WriteError, write);
+
+ pub fn writer(bconn: *MockBufferedConnection) Writer {
+ return Writer{ .context = bconn };
+ }
+};
+
+test "HeadersParser.findHeadersEnd" {
+ var r: HeadersParser = undefined;
+ const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\nHello";
+
+ for (0..36) |i| {
+ r = HeadersParser.initDynamic(0);
+ try std.testing.expectEqual(@intCast(u32, i), r.findHeadersEnd(data[0..i]));
+ try std.testing.expectEqual(@intCast(u32, 35 - i), r.findHeadersEnd(data[i..]));
+ }
+}
+
+test "HeadersParser.findChunkedLen" {
+ var r: HeadersParser = undefined;
+ const data = "Ff\r\nf0f000 ; ext\n0\r\nffffffffffffffffffffffffffffffffffffffff\r\n";
+
+ r = HeadersParser.initDynamic(0);
+ r.state = .chunk_head_size;
+ r.next_chunk_length = 0;
+
+ const first = r.findChunkedLen(data[0..]);
+ try testing.expectEqual(@as(u32, 4), first);
+ try testing.expectEqual(@as(u64, 0xff), r.next_chunk_length);
+ try testing.expectEqual(State.chunk_data, r.state);
+ r.state = .chunk_head_size;
+ r.next_chunk_length = 0;
+
+ const second = r.findChunkedLen(data[first..]);
+ try testing.expectEqual(@as(u32, 13), second);
+ try testing.expectEqual(@as(u64, 0xf0f000), r.next_chunk_length);
+ try testing.expectEqual(State.chunk_data, r.state);
+ r.state = .chunk_head_size;
+ r.next_chunk_length = 0;
+
+ const third = r.findChunkedLen(data[first + second ..]);
+ try testing.expectEqual(@as(u32, 3), third);
+ try testing.expectEqual(@as(u64, 0), r.next_chunk_length);
+ try testing.expectEqual(State.chunk_data, r.state);
+ r.state = .chunk_head_size;
+ r.next_chunk_length = 0;
+
+ const fourth = r.findChunkedLen(data[first + second + third ..]);
+ try testing.expectEqual(@as(u32, 16), fourth);
+ try testing.expectEqual(@as(u64, 0xffffffffffffffff), r.next_chunk_length);
+ try testing.expectEqual(State.invalid, r.state);
+}
+
+test "HeadersParser.read length" {
+ // mock BufferedConnection for read
+
+ var r = HeadersParser.initDynamic(256);
+ defer r.header_bytes.deinit(std.testing.allocator);
+ const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello";
+ var fbs = std.io.fixedBufferStream(data);
+
+ var bconn = MockBufferedConnection{
+ .conn = fbs,
+ };
+
+ while (true) { // read headers
+ try bconn.fill();
+
+ const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+ bconn.clear(@intCast(u16, nchecked));
+
+ if (r.state.isContent()) break;
+ }
+
+ var buf: [8]u8 = undefined;
+
+ r.next_chunk_length = 5;
+ const len = try r.read(&bconn, &buf, false);
+ try std.testing.expectEqual(@as(usize, 5), len);
+ try std.testing.expectEqualStrings("Hello", buf[0..len]);
+
+ try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.header_bytes.items);
+}
+
+test "HeadersParser.read chunked" {
+ // mock BufferedConnection for read
+
+ var r = HeadersParser.initDynamic(256);
+ defer r.header_bytes.deinit(std.testing.allocator);
+ const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n";
+ var fbs = std.io.fixedBufferStream(data);
+
+ var bconn = MockBufferedConnection{
+ .conn = fbs,
+ };
+
+ while (true) { // read headers
+ try bconn.fill();
+
+ const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+ bconn.clear(@intCast(u16, nchecked));
+
+ if (r.state.isContent()) break;
+ }
+ var buf: [8]u8 = undefined;
+
+ r.state = .chunk_head_size;
+ const len = try r.read(&bconn, &buf, false);
+ try std.testing.expectEqual(@as(usize, 5), len);
+ try std.testing.expectEqualStrings("Hello", buf[0..len]);
+
+ try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.header_bytes.items);
+}
+
+test "HeadersParser.read chunked trailer" {
+ // mock BufferedConnection for read
+
+ var r = HeadersParser.initDynamic(256);
+ defer r.header_bytes.deinit(std.testing.allocator);
+ const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n";
+ var fbs = std.io.fixedBufferStream(data);
+
+ var bconn = MockBufferedConnection{
+ .conn = fbs,
+ };
+
+ while (true) { // read headers
+ try bconn.fill();
+
+ const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+ bconn.clear(@intCast(u16, nchecked));
+
+ if (r.state.isContent()) break;
+ }
+ var buf: [8]u8 = undefined;
+
+ r.state = .chunk_head_size;
+ const len = try r.read(&bconn, &buf, false);
+ try std.testing.expectEqual(@as(usize, 5), len);
+ try std.testing.expectEqualStrings("Hello", buf[0..len]);
+
+ while (true) { // read headers
+ try bconn.fill();
+
+ const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
+ bconn.clear(@intCast(u16, nchecked));
+
+ if (r.state.isContent()) break;
+ }
+
+ try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.header_bytes.items);
+}