diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2024-02-23 17:41:38 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-23 17:41:38 -0800 |
| commit | cfce81f7d5f11ab93b2d5fd26df41edf967f333b (patch) | |
| tree | 11e52ad0a44620f4a4519683abd945146c11b312 /lib/std/http/Server.zig | |
| parent | 7230b68b350b16c637e84f3ff224be24d23214ce (diff) | |
| parent | 653d4158cdcb20be82ff525e122277064e6acb92 (diff) | |
| download | zig-cfce81f7d5f11ab93b2d5fd26df41edf967f333b.tar.gz zig-cfce81f7d5f11ab93b2d5fd26df41edf967f333b.zip | |
Merge pull request #18955 from ziglang/std.http.Server
take std.http in a different direction
Diffstat (limited to 'lib/std/http/Server.zig')
| -rw-r--r-- | lib/std/http/Server.zig | 1671 |
1 files changed, 921 insertions, 750 deletions
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 4659041779..2d360d40a4 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,873 +1,1044 @@ -//! HTTP Server implementation. -//! -//! This server assumes *all* clients are well behaved and standard compliant; it can and will deadlock if a client holds a connection open without sending a request. -//! -//! Example usage: -//! -//! ```zig -//! var server = Server.init(.{ .reuse_address = true }); -//! defer server.deinit(); -//! -//! try server.listen(bind_addr); -//! -//! while (true) { -//! var res = try server.accept(.{ .allocator = gpa }); -//! defer res.deinit(); -//! -//! while (res.reset() != .closing) { -//! res.wait() catch |err| switch (err) { -//! error.HttpHeadersInvalid => break, -//! error.HttpHeadersExceededSizeLimit => { -//! res.status = .request_header_fields_too_large; -//! res.send() catch break; -//! break; -//! }, -//! else => { -//! res.status = .bad_request; -//! res.send() catch break; -//! break; -//! }, -//! } -//! -//! res.status = .ok; -//! res.transfer_encoding = .chunked; -//! -//! try res.send(); -//! try res.writeAll("Hello, World!\n"); -//! try res.finish(); -//! } -//! } -//! ``` - -const std = @import("../std.zig"); -const testing = std.testing; -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const Allocator = mem.Allocator; -const assert = std.debug.assert; - -const Server = @This(); -const proto = @import("protocol.zig"); - -/// The underlying server socket. -socket: net.StreamServer, - -/// An interface to a plain connection. -pub const Connection = struct { - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - pub const Protocol = enum { plain }; +//! Blocking HTTP server implementation. +//! Handles a single connection's lifecycle. + +connection: net.Server.Connection, +/// Keeps track of whether the Server is ready to accept a new request on the +/// same connection, and makes invalid API usage cause assertion failures +/// rather than HTTP protocol violations. +state: State, +/// User-provided buffer that must outlive this Server. +/// Used to store the client's entire HTTP header. +read_buffer: []u8, +/// Amount of available data inside read_buffer. +read_buffer_len: usize, +/// Index into `read_buffer` of the first byte of the next HTTP request. +next_request_start: usize, + +pub const State = enum { + /// The connection is available to be used for the first time, or reused. + ready, + /// An error occurred in `receiveHead`. + receiving_head, + /// A Request object has been obtained and from there a Response can be + /// opened. + received_head, + /// The client is uploading something to this Server. + receiving_body, + /// The connection is eligible for another HTTP request, however the client + /// and server did not negotiate connection: keep-alive. + closing, +}; - stream: net.Stream, - protocol: Protocol, - - closing: bool = true, - - read_buf: [buffer_size]u8 = undefined, - read_start: u16 = 0, - read_end: u16 = 0, - - pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - return switch (conn.protocol) { - .plain => conn.stream.readAtLeast(buffer, len), - // .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len), - } catch |err| { - switch (err) { - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } - }; - } +/// Initialize an HTTP server that can respond to multiple requests on the same +/// connection. +/// The returned `Server` is ready for `receiveHead` to be called. +pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { + return .{ + .connection = connection, + .state = .ready, + .read_buffer = read_buffer, + .read_buffer_len = 0, + .next_request_start = 0, + }; +} - pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; +pub const ReceiveHeadError = error{ + /// Client sent too many bytes of HTTP headers. + /// The HTTP specification suggests to respond with a 431 status code + /// before closing the connection. + HttpHeadersOversize, + /// Client sent headers that did not conform to the HTTP protocol. + HttpHeadersInvalid, + /// A low level I/O error occurred trying to read the headers. + HttpHeadersUnreadable, + /// Partial HTTP request was received but the connection was closed before + /// fully receiving the headers. + HttpRequestTruncated, + /// The client sent 0 bytes of headers before closing the stream. + /// In other words, a keep-alive connection was finally closed. + HttpConnectionClosing, +}; - const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @as(u16, @intCast(nread)); +/// The header bytes reference the read buffer that Server was initialized with +/// and remain alive until the next call to receiveHead. +pub fn receiveHead(s: *Server) ReceiveHeadError!Request { + assert(s.state == .ready); + s.state = .received_head; + errdefer s.state = .receiving_head; + + // In case of a reused connection, move the next request's bytes to the + // beginning of the buffer. + if (s.next_request_start > 0) { + if (s.read_buffer_len > s.next_request_start) { + rebase(s, 0); + } else { + s.read_buffer_len = 0; + } } - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; - } + var hp: http.HeadParser = .{}; - pub fn drop(conn: *Connection, num: u16) void { - conn.read_start += num; + if (s.read_buffer_len > 0) { + const bytes = s.read_buffer[0..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, end); } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - - var out_index: u16 = 0; - while (out_index < len) { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len - out_index; - - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - out_index += @as(u16, @intCast(available_buffer)); - conn.read_start += @as(u16, @intCast(available_buffer)); - - break; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - out_index += available_read; - conn.read_start += available_read; - - if (out_index >= len) break; - } - - const leftover_buffer = available_buffer - available_read; - const leftover_len = len - out_index; - - if (leftover_buffer > conn.read_buf.len) { - // skip the buffer if the output is large enough - return conn.rawReadAtLeast(buffer[out_index..], leftover_len); + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = s.connection.stream.read(buf) catch + return error.HttpHeadersUnreadable; + if (read_n == 0) { + if (s.read_buffer_len > 0) { + return error.HttpRequestTruncated; + } else { + return error.HttpConnectionClosing; } - - try conn.fill(); } - - return out_index; - } - - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = error{ - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, - }; - - pub const Reader = std.io.Reader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { - return switch (conn.protocol) { - .plain => conn.stream.writeAll(buffer), - // .tls => return conn.tls_client.writeAll(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - return switch (conn.protocol) { - .plain => conn.stream.write(buffer), - // .tls => return conn.tls_client.write(conn.stream, buffer), - } catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) + return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); } +} - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, +fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { + return .{ + .server = s, + .head_end = head_end, + .head = Request.Head.parse(s.read_buffer[0..head_end]) catch + return error.HttpHeadersInvalid, + .reader_state = undefined, }; +} - pub const Writer = std.io.Writer(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; - } - - pub fn close(conn: *Connection) void { - conn.stream.close(); - } -}; - -/// The mode of transport for responses. -pub const ResponseTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for request messages. -pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(Response.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(Response.TransferReader); - pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{}); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP request originating from a client. pub const Request = struct { - pub const ParseError = Allocator.Error || error{ - UnknownHttpMethod, - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionNotSupported, + server: *Server, + /// Index into Server's read_buffer. + head_end: usize, + head: Head, + reader_state: union { + remaining_content_length: u64, + chunk_parser: http.ChunkParser, + }, + + pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); + pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); + pub const ZstdDecompressor = std.compress.zstd.Decompressor(std.io.AnyReader); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, }; - pub fn parse(req: *Request, bytes: []const u8) ParseError!void { - var it = mem.tokenizeAny(u8, bytes, "\r\n"); - - const first_line = it.next() orelse return error.HttpHeadersInvalid; - if (first_line.len < 10) - return error.HttpHeadersInvalid; - - const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); - - const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (version_start == method_end) return error.HttpHeadersInvalid; - - const version_str = first_line[version_start + 1 ..]; - if (version_str.len != 8) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(version_str[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, + pub const Head = struct { + method: http.Method, + target: []const u8, + version: http.Version, + expect: ?[]const u8, + content_type: ?[]const u8, + content_length: ?u64, + transfer_encoding: http.TransferEncoding, + transfer_compression: http.ContentEncoding, + keep_alive: bool, + compression: Compression, + + pub const ParseError = error{ + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidContentLength, + CompressionUnsupported, + MissingFinalNewline, }; - const target = first_line[method_end + 1 .. version_start]; - - req.method = method; - req.target = target; - req.version = version; - - while (it.next()) |line| { - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.tokenizeAny(u8, line, ": "); - const header_name = line_it.next() orelse return error.HttpHeadersInvalid; - const header_value = line_it.rest(); - - try req.headers.append(header_name, header_value); - - if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (req.content_length != null) return error.HttpHeadersInvalid; - req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (req.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - req.transfer_encoding = transfer; - - next = iter.next(); + pub fn parse(bytes: []const u8) ParseError!Head { + var it = mem.splitSequence(u8, bytes, "\r\n"); + + const first_line = it.next().?; + if (first_line.len < 10) + return error.HttpHeadersInvalid; + + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (method_end > 24) return error.HttpHeadersInvalid; + + const method_str = first_line[0..method_end]; + const method: http.Method = @enumFromInt(http.Method.parse(method_str)); + + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse + return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; + + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.HttpHeadersInvalid, + }; + + const target = first_line[method_end + 1 .. version_start]; + + var head: Head = .{ + .method = method, + .target = target, + .version = version, + .expect = null, + .content_type = null, + .content_length = null, + .transfer_encoding = .none, + .transfer_compression = .identity, + .keep_alive = false, + .compression = .none, + }; + + while (it.next()) |line| { + if (line.len == 0) return head; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, } - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported - req.transfer_compression = transfer; + var line_it = mem.splitSequence(u8, line, ": "); + const header_name = line_it.next().?; + const header_value = line_it.rest(); + if (header_value.len == 0) return error.HttpHeadersInvalid; + + if (std.ascii.eqlIgnoreCase(header_name, "connection")) { + head.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); + } else if (std.ascii.eqlIgnoreCase(header_name, "expect")) { + head.expect = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { + head.content_type = header_value; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (head.content_length != null) return error.HttpHeadersInvalid; + head.content_length = std.fmt.parseInt(u64, header_value, 10) catch + return error.InvalidContentLength; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; + + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + head.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; } - } + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwardsScalar(u8, header_value, ','); - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (req.transfer_compression != .identity) return error.HttpHeadersInvalid; + const first = iter.first(); + const trimmed_first = mem.trim(u8, first, " "); - const trimmed = mem.trim(u8, header_value, " "); + var next: ?[]const u8 = first; + if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { + if (head.transfer_encoding != .none) + return error.HttpHeadersInvalid; // we already have a transfer encoding + head.transfer_encoding = transfer; - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - req.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } - } - } - - inline fn int64(array: *const [8]u8) u64 { - return @as(u64, @bitCast(array.*)); - } - - /// The HTTP request method. - method: http.Method, - - /// The HTTP request target. - target: []const u8, - - /// The HTTP version of this request. - version: http.Version, - - /// The length of the request body, if known. - content_length: ?u64 = null, - - /// The transfer encoding of the request body, or .none if not present. - transfer_encoding: http.TransferEncoding = .none, - - /// The compression of the request body, or .identity (no compression) if not present. - transfer_compression: http.ContentEncoding = .identity, - - /// The list of HTTP request headers - headers: http.Headers, - - parser: proto.HeadersParser, - compression: Compression = .none, -}; - -/// A HTTP response waiting to be sent. -/// -/// Order of operations: -/// ``` -/// [/ <--------------------------------------- \] -/// accept -> wait -> send [ -> write -> finish][ -> reset /] -/// \ -> read / -/// ``` -pub const Response = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, - - transfer_encoding: ResponseTransfer = .none, - - /// The allocator responsible for allocating memory for this response. - allocator: Allocator, - - /// The peer's address - address: net.Address, - - /// The underlying connection for this response. - connection: Connection, + next = iter.next(); + } - /// The HTTP response headers - headers: http.Headers, + if (next) |second| { + const trimmed_second = mem.trim(u8, second, " "); - /// The HTTP request that this response is responding to. - /// - /// This field is only valid after calling `wait`. - request: Request, + if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (head.transfer_compression != .identity) + return error.HttpHeadersInvalid; // double compression is not supported + head.transfer_compression = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } - state: State = .first, + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } + } + return error.MissingFinalNewline; + } - const State = enum { - first, - start, - waited, - responded, - finished, + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } }; - /// Free all resources associated with this response. - pub fn deinit(res: *Response) void { - res.connection.close(); - - res.headers.deinit(); - res.request.headers.deinit(); - - if (res.request.parser.header_bytes_owned) { - res.request.parser.header_bytes.deinit(res.allocator); - } + pub fn iterateHeaders(r: *Request) http.HeaderIterator { + return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); } - pub const ResetState = enum { reset, closing }; + pub const RespondOptions = struct { + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, + keep_alive: bool = true, + extra_headers: []const http.Header = &.{}, + transfer_encoding: ?http.TransferEncoding = null, + }; - /// Reset this response to its initial state. This must be called before handling a second request on the same connection. - pub fn reset(res: *Response) ResetState { - if (res.state == .first) { - res.state = .start; - return .reset; + /// Send an entire HTTP response to the client, including headers and body. + /// + /// Automatically handles HEAD requests by omitting the body. + /// + /// Unless `transfer_encoding` is specified, uses the "content-length" + /// header. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// Asserts status is not `continue`. + /// Asserts there are at most 25 extra_headers. + /// Asserts that "\r\n" does not occur in any header name or value. + pub fn respond( + request: *Request, + content: []const u8, + options: RespondOptions, + ) Response.WriteError!void { + const max_extra_headers = 25; + assert(options.status != .@"continue"); + assert(options.extra_headers.len <= max_extra_headers); + if (std.debug.runtime_safety) { + for (options.extra_headers) |header| { + assert(std.mem.indexOfScalar(u8, header.name, ':') == null); + assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null); + assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); + } } - if (!res.request.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - res.connection.closing = true; - return .closing; + const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and options.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + + const phrase = options.reason orelse options.status.phrase() orelse ""; + + var first_buffer: [500]u8 = undefined; + var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); + if (request.head.expect != null) { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + try request.server.connection.stream.writeAll(h.items); + return; } + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(options.version), @intFromEnum(options.status), phrase, + }) catch unreachable; - // A connection is only keep-alive if the Connection header is present and it's value is not "close". - // The server and client must both agree - // - // send() defaults to using keep-alive if the client requests it. - const res_connection = res.headers.getFirstValue("connection"); - const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (req_keepalive and (res_keepalive or res_connection == null)) { - res.connection.closing = false; - } else { - res.connection.closing = true; - } + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); - switch (res.request.compression) { + if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { .none => {}, - .deflate => {}, - .gzip => {}, - .zstd => |*zstd| zstd.deinit(), - } - - res.state = .start; - res.version = .@"HTTP/1.1"; - res.status = .ok; - res.reason = null; - - res.transfer_encoding = .none; - - res.headers.clearRetainingCapacity(); - - res.request.headers.clearAndFree(); // FIXME: figure out why `clearRetainingCapacity` causes a leak in hash_map here - res.request.parser.reset(); - - res.request = Request{ - .version = undefined, - .method = undefined, - .target = undefined, - .headers = res.request.headers, - .parser = res.request.parser, - }; - - if (res.connection.closing) { - return .closing; + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), } else { - return .reset; + h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; } - } - pub const SendError = Connection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + var chunk_header_buffer: [18]u8 = undefined; + var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; - /// Send the HTTP response headers to the client. - pub fn send(res: *Response) SendError!void { - switch (res.state) { - .waited => res.state = .responded, - .first, .start, .responded, .finished => unreachable, + iovecs[iovecs_len] = .{ + .iov_base = h.items.ptr, + .iov_len = h.items.len, + }; + iovecs_len += 1; + + for (options.extra_headers) |header| { + iovecs[iovecs_len] = .{ + .iov_base = header.name.ptr, + .iov_len = header.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = ": ", + .iov_len = 2, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = header.value.ptr, + .iov_len = header.value.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; } - var buffered = std.io.bufferedWriter(res.connection.writer()); - const w = buffered.writer(); - - try w.writeAll(@tagName(res.version)); - try w.writeByte(' '); - try w.print("{d}", .{@intFromEnum(res.status)}); - try w.writeByte(' '); - if (res.reason) |reason| { - try w.writeAll(reason); - } else if (res.status.phrase()) |phrase| { - try w.writeAll(phrase); - } - try w.writeAll("\r\n"); + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + + if (request.head.method != .HEAD) { + const is_chunked = (options.transfer_encoding orelse .none) == .chunked; + if (is_chunked) { + if (content.len > 0) { + const chunk_header = std.fmt.bufPrint( + &chunk_header_buffer, + "{x}\r\n", + .{content.len}, + ) catch unreachable; + + iovecs[iovecs_len] = .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = content.ptr, + .iov_len = content.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } - if (res.status == .@"continue") { - res.state = .waited; // we still need to send another request after this - } else { - if (!res.headers.contains("server")) { - try w.writeAll("Server: zig (std.http)\r\n"); + iovecs[iovecs_len] = .{ + .iov_base = "0\r\n\r\n", + .iov_len = 5, + }; + iovecs_len += 1; + } else if (content.len > 0) { + iovecs[iovecs_len] = .{ + .iov_base = content.ptr, + .iov_len = content.len, + }; + iovecs_len += 1; } + } - if (!res.headers.contains("connection")) { - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - - if (req_keepalive) { - try w.writeAll("Connection: keep-alive\r\n"); - } else { - try w.writeAll("Connection: close\r\n"); - } - } + try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + } - const has_transfer_encoding = res.headers.contains("transfer-encoding"); - const has_content_length = res.headers.contains("content-length"); + pub const RespondStreamingOptions = struct { + /// An externally managed slice of memory used to batch bytes before + /// sending. `respondStreaming` asserts this is large enough to store + /// the full HTTP response head. + /// + /// Must outlive the returned Response. + send_buffer: []u8, + /// If provided, the response will use the content-length header; + /// otherwise it will use transfer-encoding: chunked. + content_length: ?u64 = null, + /// Options that are shared with the `respond` method. + respond_options: RespondOptions = .{}, + }; - if (!has_transfer_encoding and !has_content_length) { - switch (res.transfer_encoding) { - .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), - .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), - .none => {}, - } + /// The header is buffered but not sent until Response.flush is called. + /// + /// If the request contains a body and the connection is to be reused, + /// discards the request body, leaving the Server in the `ready` state. If + /// this discarding fails, the connection is marked as not to be reused and + /// no error is surfaced. + /// + /// HEAD requests are handled transparently by setting a flag on the + /// returned Response to omit the body. However it may be worth noticing + /// that flag and skipping any expensive work that would otherwise need to + /// be done to satisfy the request. + /// + /// Asserts `send_buffer` is large enough to store the entire response header. + /// Asserts status is not `continue`. + pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { + const o = options.respond_options; + assert(o.status != .@"continue"); + const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; + const server_keep_alive = !transfer_encoding_none and o.keep_alive; + const keep_alive = request.discardBody(server_keep_alive); + const phrase = o.reason orelse o.status.phrase() orelse ""; + + var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); + + const elide_body = if (request.head.expect != null) eb: { + // reader() and hence discardBody() above sets expect to null if it + // is handled. So the fact that it is not null here means unhandled. + h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); + break :eb true; + } else eb: { + h.fixedWriter().print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }) catch unreachable; + if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"); + + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { + h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; } else { - if (has_content_length) { - const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; - - res.transfer_encoding = .{ .content_length = content_length }; - } else if (has_transfer_encoding) { - const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?; - if (std.mem.eql(u8, transfer_encoding, "chunked")) { - res.transfer_encoding = .chunked; - } else { - return error.UnsupportedTransferEncoding; - } - } else { - res.transfer_encoding = .none; - } + h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); } - try w.print("{}", .{res.headers}); - } - - if (res.request.method == .HEAD) { - res.transfer_encoding = .none; - } + for (o.extra_headers) |header| { + h.appendSliceAssumeCapacity(header.name); + h.appendSliceAssumeCapacity(": "); + h.appendSliceAssumeCapacity(header.value); + h.appendSliceAssumeCapacity("\r\n"); + } - try w.writeAll("\r\n"); + h.appendSliceAssumeCapacity("\r\n"); + break :eb request.head.method == .HEAD; + }; - try buffered.flush(); + return .{ + .stream = request.server.connection.stream, + .send_buffer = options.send_buffer, + .send_buffer_start = 0, + .send_buffer_end = h.items.len, + .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { + .chunked => .chunked, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .chunked, + .elide_body = elide_body, + .chunk_len = 0, + }; } - const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; + pub const ReadError = net.Stream.ReadError || error{ + HttpChunkInvalid, + HttpHeadersOversize, + }; - const TransferReader = std.io.Reader(*Response, TransferReadError, transferRead); + fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; - fn transferReader(res: *Response) TransferReader { - return .{ .context = res }; + const remaining_content_length = &request.reader_state.remaining_content_length; + if (remaining_content_length.* == 0) { + s.state = .ready; + return 0; + } + assert(s.state == .receiving_body); + const available = try fill(s, request.head_end); + const len = @min(remaining_content_length.*, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + remaining_content_length.* -= len; + s.next_request_start += len; + if (remaining_content_length.* == 0) + s.state = .ready; + return len; } - fn transferRead(res: *Response, buf: []u8) TransferReadError!usize { - if (res.request.parser.done) return 0; + fn fill(s: *Server, head_end: usize) ReadError![]u8 { + const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; + if (available.len > 0) return available; + s.next_request_start = head_end; + s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); + return s.read_buffer[head_end..s.read_buffer_len]; + } - var index: usize = 0; - while (index == 0) { - const amt = try res.request.parser.read(&res.connection, buf[index..], false); - if (amt == 0 and res.request.parser.done) break; - index += amt; + fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + const request: *Request = @constCast(@alignCast(@ptrCast(context))); + const s = request.server; + + const cp = &request.reader_state.chunk_parser; + const head_end = request.head_end; + + // Protect against returning 0 before the end of stream. + var out_end: usize = 0; + while (out_end == 0) { + switch (cp.state) { + .invalid => return 0, + .data => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const len = @min(cp.chunk_len, available.len, buffer.len); + @memcpy(buffer[0..len], available[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += len; + continue; + }, + else => { + assert(s.state == .receiving_body); + const available = try fill(s, head_end); + const n = cp.feed(available); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + if (cp.chunk_len == 0) { + // The next bytes in the stream are trailers, + // or \r\n to indicate end of chunked body. + // + // This function must append the trailers at + // head_end so that headers and trailers are + // together. + // + // Since returning 0 would indicate end of + // stream, this function must read all the + // trailers before returning. + if (s.next_request_start > head_end) rebase(s, head_end); + var hp: http.HeadParser = .{}; + { + const bytes = s.read_buffer[head_end..s.read_buffer_len]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + while (true) { + const buf = s.read_buffer[s.read_buffer_len..]; + if (buf.len == 0) + return error.HttpHeadersOversize; + const read_n = try s.connection.stream.read(buf); + s.read_buffer_len += read_n; + const bytes = buf[0..read_n]; + const end = hp.feed(bytes); + if (hp.state == .finished) { + cp.state = .invalid; + s.state = .ready; + s.next_request_start = s.read_buffer_len - bytes.len + end; + return out_end; + } + } + } + const data = available[n..]; + const len = @min(cp.chunk_len, data.len, buffer.len); + @memcpy(buffer[0..len], data[0..len]); + cp.chunk_len -= len; + if (cp.chunk_len == 0) + cp.state = .data_suffix; + out_end += len; + s.next_request_start += n + len; + continue; + }, + else => continue, + } + }, + } } - - return index; + return out_end; } - pub const WaitError = Connection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; + pub const ReaderError = Response.WriteError || error{ + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; - /// Wait for the client to send a complete request head. + /// In the case that the request contains "expect: 100-continue", this + /// function writes the continuation header, which means it can fail with a + /// write error. After sending the continuation header, it sets the + /// request's expect field to `null`. /// - /// For correct behavior, the following rules must be followed: - /// - /// * If this returns any error in `Connection.ReadError`, you MUST immediately close the connection by calling `deinit`. - /// * If this returns `error.HttpHeadersInvalid`, you MAY immediately close the connection by calling `deinit`. - /// * If this returns `error.HttpHeadersExceededSizeLimit`, you MUST respond with a 431 status code and then call `deinit`. - /// * If this returns any error in `Request.ParseError`, you MUST respond with a 400 status code and then call `deinit`. - /// * If this returns any other error, you MUST respond with a 400 status code and then call `deinit`. - /// * If the request has an Expect header containing 100-continue, you MUST either: - /// * Respond with a 100 status code, then call `wait` again. - /// * Respond with a 417 status code. - pub fn wait(res: *Response) WaitError!void { - switch (res.state) { - .first, .start => res.state = .waited, - .waited, .responded, .finished => unreachable, + /// Asserts that this function is only called once. + pub fn reader(request: *Request) ReaderError!std.io.AnyReader { + const s = request.server; + assert(s.state == .received_head); + s.state = .receiving_body; + s.next_request_start = request.head_end; + + if (request.head.expect) |expect| { + if (mem.eql(u8, expect, "100-continue")) { + try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; + } else { + return error.HttpExpectationFailed; + } } - while (true) { - try res.connection.fill(); - - const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); - res.connection.drop(@as(u16, @intCast(nchecked))); - - if (res.request.parser.state.isContent()) break; + switch (request.head.transfer_encoding) { + .chunked => { + request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; + return .{ + .readFn = read_chunked, + .context = request, + }; + }, + .none => { + request.reader_state = .{ + .remaining_content_length = request.head.content_length orelse 0, + }; + return .{ + .readFn = read_cl, + .context = request, + }; + }, } + } - res.request.headers = .{ .allocator = res.allocator, .owned = true }; - try res.request.parse(res.request.parser.header_bytes.items); - - if (res.request.transfer_encoding != .none) { - switch (res.request.transfer_encoding) { - .none => unreachable, - .chunked => { - res.request.parser.next_chunk_length = 0; - res.request.parser.state = .chunk_head_size; - }, - } - } else if (res.request.content_length) |cl| { - res.request.parser.next_chunk_length = cl; - - if (cl == 0) res.request.parser.done = true; + /// Returns whether the connection: keep-alive header should be sent to the client. + /// If it would fail, it instead sets the Server state to `receiving_body` + /// and returns false. + fn discardBody(request: *Request, keep_alive: bool) bool { + // Prepare to receive another request on the same connection. + // There are two factors to consider: + // * Any body the client sent must be discarded. + // * The Server's read_buffer may already have some bytes in it from + // whatever came after the head, which may be the next HTTP request + // or the request body. + // If the connection won't be kept alive, then none of this matters + // because the connection will be severed after the response is sent. + const s = request.server; + if (keep_alive and request.head.keep_alive) switch (s.state) { + .received_head => { + const r = request.reader() catch return false; + _ = r.discard() catch return false; + assert(s.state == .ready); + return true; + }, + .receiving_body, .ready => return true, + else => unreachable, } else { - res.request.parser.done = true; + s.state = .closing; + return false; } + } +}; - if (!res.request.parser.done) { - switch (res.request.transfer_compression) { - .identity => res.request.compression = .none, - .compress, .@"x-compress" => return error.CompressionNotSupported, - .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.decompressor(res.transferReader()), - }, - .gzip, .@"x-gzip" => res.request.compression = .{ - .gzip = std.compress.gzip.decompressor(res.transferReader()), - }, - .zstd => res.request.compression = .{ - .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), - }, - } +pub const Response = struct { + stream: net.Stream, + send_buffer: []u8, + /// Index of the first byte in `send_buffer`. + /// This is 0 unless a short write happens in `write`. + send_buffer_start: usize, + /// Index of the last byte + 1 in `send_buffer`. + send_buffer_end: usize, + /// `null` means transfer-encoding: chunked. + /// As a debugging utility, counts down to zero as bytes are written. + transfer_encoding: TransferEncoding, + elide_body: bool, + /// Indicates how much of the end of the `send_buffer` corresponds to a + /// chunk. This amount of data will be wrapped by an HTTP chunk header. + chunk_len: usize, + + pub const TransferEncoding = union(enum) { + /// End of connection signals the end of the stream. + none, + /// As a debugging utility, counts down to zero as bytes are written. + content_length: u64, + /// Each chunk is wrapped in a header and trailer. + chunked, + }; + + pub const WriteError = net.Stream.WriteError; + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header, then calls `flush`. + /// Otherwise, transfer-encoding: chunked is being used, and it writes the + /// end-of-stream message, then flushes the stream to the system. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn end(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + try flush_cl(r); + }, + .none => { + try flush_cl(r); + }, + .chunked => { + try flush_chunked(r, &.{}); + }, } + r.* = undefined; } - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; - - pub const Reader = std.io.Reader(*Response, ReadError, read); + pub const EndChunkedOptions = struct { + trailers: []const http.Header = &.{}, + }; - pub fn reader(res: *Response) Reader { - return .{ .context = res }; + /// Asserts that the Response is using transfer-encoding: chunked. + /// Writes the end-of-stream message and any optional trailers, then + /// flushes the stream to the system. + /// Respects the value of `elide_body` to omit all data after the headers. + /// Asserts there are at most 25 trailers. + pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + assert(r.transfer_encoding == .chunked); + try flush_chunked(r, options.trailers); + r.* = undefined; } - /// Reads data from the response body. Must be called after `wait`. - pub fn read(res: *Response, buffer: []u8) ReadError!usize { - switch (res.state) { - .waited, .responded, .finished => {}, - .first, .start => unreachable, + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + /// May return 0, which does not indicate end of stream. The caller decides + /// when the end of stream occurs by calling `end`. + pub fn write(r: *Response, bytes: []const u8) WriteError!usize { + switch (r.transfer_encoding) { + .content_length, .none => return write_cl(r, bytes), + .chunked => return write_chunked(r, bytes), } + } - const out_index = switch (res.request.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try res.transferRead(buffer), - }; + fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); - if (out_index == 0) { - const has_trail = !res.request.parser.state.isContent(); + var trash: u64 = std.math.maxInt(u64); + const len = switch (r.transfer_encoding) { + .content_length => |*len| len, + else => &trash, + }; - while (!res.request.parser.state.isContent()) { // read trailing headers - try res.connection.fill(); + if (r.elide_body) { + len.* -= bytes.len; + return bytes.len; + } - const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); - res.connection.drop(@as(u16, @intCast(nchecked))); + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + var iovecs: [2]std.posix.iovec_const = .{ + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_start, + .iov_len = send_buffer_len, + }, + .{ + .iov_base = bytes.ptr, + .iov_len = bytes.len, + }, + }; + const n = try r.stream.writev(&iovecs); + + if (n >= send_buffer_len) { + // It was enough to reset the buffer. + r.send_buffer_start = 0; + r.send_buffer_end = 0; + const bytes_n = n - send_buffer_len; + len.* -= bytes_n; + return bytes_n; } - if (has_trail) { - res.request.headers = http.Headers{ .allocator = res.allocator, .owned = false }; - - // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. - // This will *only* fail for a malformed trailer. - res.request.parse(res.request.parser.header_bytes.items) catch return error.InvalidTrailers; - } + // It didn't even make it through the existing buffer, let + // alone the new bytes provided. + r.send_buffer_start += n; + return 0; } - return out_index; - } - - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(res: *Response, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(res, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + len.* -= bytes.len; + return bytes.len; } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; + fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + const r: *Response = @constCast(@alignCast(@ptrCast(context))); + assert(r.transfer_encoding == .chunked); - pub const Writer = std.io.Writer(*Response, WriteError, write); + if (r.elide_body) + return bytes.len; - pub fn writer(res: *Response) Writer { - return .{ .context = res }; - } + if (bytes.len + r.send_buffer_end > r.send_buffer.len) { + const send_buffer_len = r.send_buffer_end - r.send_buffer_start; + const chunk_len = r.chunk_len + bytes.len; + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable; - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn write(res: *Response, bytes: []const u8) WriteError!usize { - switch (res.state) { - .responded => {}, - .first, .waited, .start, .finished => unreachable, + var iovecs: [5]std.posix.iovec_const = .{ + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_start, + .iov_len = send_buffer_len - r.chunk_len, + }, + .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }, + .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .iov_len = r.chunk_len, + }, + .{ + .iov_base = bytes.ptr, + .iov_len = bytes.len, + }, + .{ + .iov_base = "\r\n", + .iov_len = 2, + }, + }; + // TODO make this writev instead of writevAll, which involves + // complicating the logic of this function. + try r.stream.writevAll(&iovecs); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return bytes.len; } - switch (res.transfer_encoding) { - .chunked => { - try res.connection.writer().print("{x}\r\n", .{bytes.len}); - try res.connection.writeAll(bytes); - try res.connection.writeAll("\r\n"); - - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; - - const amt = try res.connection.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, - } + // All bytes can be stored in the remaining space of the buffer. + @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); + r.send_buffer_end += bytes.len; + r.chunk_len += bytes.len; + return bytes.len; } - /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { + /// If using content-length, asserts that writing these bytes to the client + /// would not exceed the content-length value sent in the HTTP header. + pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { - index += try write(req, bytes[index..]); + index += try write(r, bytes[index..]); } } - pub const FinishError = WriteError || error{MessageNotCompleted}; - - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(res: *Response) FinishError!void { - switch (res.state) { - .responded => res.state = .finished, - .first, .waited, .start, .finished => unreachable, + /// Sends all buffered data to the client. + /// This is redundant after calling `end`. + /// Respects the value of `elide_body` to omit all data after the headers. + pub fn flush(r: *Response) WriteError!void { + switch (r.transfer_encoding) { + .none, .content_length => return flush_cl(r), + .chunked => return flush_chunked(r, null), } + } - switch (res.transfer_encoding) { - .chunked => try res.connection.writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, - } + fn flush_cl(r: *Response) WriteError!void { + try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; } -}; -/// Create a new HTTP server. -pub fn init(options: net.StreamServer.Options) Server { - return .{ - .socket = net.StreamServer.init(options), - }; -} + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + const max_trailers = 25; + if (end_trailers) |trailers| assert(trailers.len <= max_trailers); + assert(r.transfer_encoding == .chunked); -/// Free all resources associated with this server. -pub fn deinit(server: *Server) void { - server.socket.deinit(); -} + const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; -pub const ListenError = std.os.SocketError || std.os.BindError || std.os.ListenError || std.os.SetSockOptError || std.os.GetSockNameError; + if (r.elide_body) { + try r.stream.writeAll(http_headers); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; + return; + } -/// Start the HTTP server listening on the given address. -pub fn listen(server: *Server, address: net.Address) ListenError!void { - try server.socket.listen(address); -} + var header_buf: [18]u8 = undefined; + const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; -pub const AcceptError = net.StreamServer.AcceptError || Allocator.Error; - -pub const HeaderStrategy = union(enum) { - /// In this case, the client's Allocator will be used to store the - /// entire HTTP header. This value is the maximum total size of - /// HTTP headers allowed, otherwise - /// error.HttpHeadersExceededSizeLimit is returned from read(). - dynamic: usize, - /// This is used to store the entire HTTP header. If the HTTP - /// header is too big to fit, `error.HttpHeadersExceededSizeLimit` - /// is returned from read(). When this is used, `error.OutOfMemory` - /// cannot be returned from `read()`. - static: []u8, -}; + var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; + var iovecs_len: usize = 0; -pub const AcceptOptions = struct { - allocator: Allocator, - header_strategy: HeaderStrategy = .{ .dynamic = 8192 }, -}; + iovecs[iovecs_len] = .{ + .iov_base = http_headers.ptr, + .iov_len = http_headers.len, + }; + iovecs_len += 1; + + if (r.chunk_len > 0) { + iovecs[iovecs_len] = .{ + .iov_base = chunk_header.ptr, + .iov_len = chunk_header.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, + .iov_len = r.chunk_len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } -/// Accept a new connection. -pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { - const in = try server.socket.accept(); - - return Response{ - .allocator = options.allocator, - .address = in.address, - .connection = .{ - .stream = in.stream, - .protocol = .plain, - }, - .headers = .{ .allocator = options.allocator }, - .request = .{ - .version = undefined, - .method = undefined, - .target = undefined, - .headers = .{ .allocator = options.allocator, .owned = false }, - .parser = switch (options.header_strategy) { - .dynamic => |max| proto.HeadersParser.initDynamic(max), - .static => |buf| proto.HeadersParser.initStatic(buf), - }, - }, - }; -} + if (end_trailers) |trailers| { + iovecs[iovecs_len] = .{ + .iov_base = "0\r\n", + .iov_len = 3, + }; + iovecs_len += 1; + + for (trailers) |trailer| { + iovecs[iovecs_len] = .{ + .iov_base = trailer.name.ptr, + .iov_len = trailer.name.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = ": ", + .iov_len = 2, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = trailer.value.ptr, + .iov_len = trailer.value.len, + }; + iovecs_len += 1; + + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } -test "HTTP server handles a chunked transfer coding request" { - const builtin = @import("builtin"); + iovecs[iovecs_len] = .{ + .iov_base = "\r\n", + .iov_len = 2, + }; + iovecs_len += 1; + } - // This test requires spawning threads. - if (builtin.single_threaded) { - return error.SkipZigTest; + try r.stream.writevAll(iovecs[0..iovecs_len]); + r.send_buffer_start = 0; + r.send_buffer_end = 0; + r.chunk_len = 0; } - const native_endian = comptime builtin.cpu.arch.endian(); - if (builtin.zig_backend == .stage2_llvm and native_endian == .big) { - // https://github.com/ziglang/zig/issues/13782 - return error.SkipZigTest; + pub fn writer(r: *Response) std.io.AnyWriter { + return .{ + .writeFn = switch (r.transfer_encoding) { + .none, .content_length => write_cl, + .chunked => write_chunked, + }, + .context = r, + }; } +}; - if (builtin.os.tag == .wasi) return error.SkipZigTest; - - const allocator = std.testing.allocator; - const expect = std.testing.expect; - - const max_header_size = 8192; - var server = std.http.Server.init(.{ .reuse_address = true }); - defer server.deinit(); - - const address = try std.net.Address.parseIp("127.0.0.1", 0); - try server.listen(address); - const server_port = server.socket.listen_address.in.getPort(); - - const server_thread = try std.Thread.spawn(.{}, (struct { - fn apply(s: *std.http.Server) !void { - var res = try s.accept(.{ - .allocator = allocator, - .header_strategy = .{ .dynamic = max_header_size }, - }); - defer res.deinit(); - defer _ = res.reset(); - try res.wait(); - - try expect(res.request.transfer_encoding == .chunked); - - const server_body: []const u8 = "message from server!\n"; - res.transfer_encoding = .{ .content_length = server_body.len }; - try res.headers.append("content-type", "text/plain"); - try res.headers.append("connection", "close"); - try res.send(); - - var buf: [128]u8 = undefined; - const n = try res.readAll(&buf); - try expect(std.mem.eql(u8, buf[0..n], "ABCD")); - _ = try res.writer().writeAll(server_body); - try res.finish(); - } - }).apply, .{&server}); - - const request_bytes = - "POST / HTTP/1.1\r\n" ++ - "Content-Type: text/plain\r\n" ++ - "Transfer-Encoding: chunked\r\n" ++ - "\r\n" ++ - "1\r\n" ++ - "A\r\n" ++ - "1\r\n" ++ - "B\r\n" ++ - "2\r\n" ++ - "CD\r\n" ++ - "0\r\n" ++ - "\r\n"; - - const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port); - defer stream.close(); - _ = try stream.writeAll(request_bytes[0..]); - - server_thread.join(); +fn rebase(s: *Server, index: usize) void { + const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; + const dest = s.read_buffer[index..][0..leftover.len]; + if (leftover.len <= s.next_request_start - index) { + @memcpy(dest, leftover); + } else { + mem.copyBackwards(u8, dest, leftover); + } + s.read_buffer_len = index + leftover.len; } + +const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const assert = std.debug.assert; + +const Server = @This(); |
