diff options
| author | Andrew Kelley <andrew@ziglang.org> | 2023-05-11 08:37:42 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-11 08:37:42 -0700 |
| commit | 3bb3d39fb4158ba4b811bcae7e7a897febf07e17 (patch) | |
| tree | 51d408085e1e9c60003c02960e509e3498807ab5 /lib/std/http/Server.zig | |
| parent | 5569e6b49d9b421d35e3175df36eb9fe7e4e8084 (diff) | |
| parent | 9017d758b96c1f296249670dffb774a598bc8598 (diff) | |
| download | zig-3bb3d39fb4158ba4b811bcae7e7a897febf07e17.tar.gz zig-3bb3d39fb4158ba4b811bcae7e7a897febf07e17.zip | |
Merge pull request #15487 from truemedian/http-tests
std.http: more http fixes, add standalone http server test
Diffstat (limited to 'lib/std/http/Server.zig')
| -rw-r--r-- | lib/std/http/Server.zig | 236 |
1 files changed, 167 insertions, 69 deletions
diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index c7f2a86c27..6b5db6725f 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -95,47 +95,50 @@ pub const Connection = struct { /// A buffered (and peekable) Connection. pub const BufferedConnection = struct { - pub const buffer_size = 0x2000; + pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; conn: Connection, - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, + read_buf: [buffer_size]u8 = undefined, + read_start: u16 = 0, + read_end: u16 = 0, + + write_buf: [buffer_size]u8 = undefined, + write_end: u16 = 0, pub fn fill(bconn: *BufferedConnection) ReadError!void { - if (bconn.end != bconn.start) return; + if (bconn.read_end != bconn.read_start) return; - const nread = try bconn.conn.read(bconn.buf[0..]); + const nread = try bconn.conn.read(bconn.read_buf[0..]); if (nread == 0) return error.EndOfStream; - bconn.start = 0; - bconn.end = @truncate(u16, nread); + bconn.read_start = 0; + bconn.read_end = @intCast(u16, nread); } pub fn peek(bconn: *BufferedConnection) []const u8 { - return bconn.buf[bconn.start..bconn.end]; + return bconn.read_buf[bconn.read_start..bconn.read_end]; } pub fn clear(bconn: *BufferedConnection, num: u16) void { - bconn.start += num; + bconn.read_start += num; } pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize { var out_index: u16 = 0; while (out_index < len) { - const available = bconn.end - bconn.start; + const available = bconn.read_end - bconn.read_start; const left = buffer.len - out_index; if (available > 0) { - const can_read = @truncate(u16, @min(available, left)); + const can_read = @intCast(u16, @min(available, left)); - @memcpy(buffer[out_index..][0..can_read], bconn.buf[bconn.start..][0..can_read]); + @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]); out_index += can_read; - bconn.start += can_read; + bconn.read_start += can_read; continue; } - if (left > bconn.buf.len) { + if (left > bconn.read_buf.len) { // skip the buffer if the output is large enough return bconn.conn.read(buffer[out_index..]); } @@ -158,11 +161,30 @@ pub const BufferedConnection = struct { } pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void { - return bconn.conn.writeAll(buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); + bconn.write_end += @intCast(u16, buffer.len); + } else { + try bconn.flush(); + try bconn.conn.writeAll(buffer); + } } pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize { - return bconn.conn.write(buffer); + if (bconn.write_buf.len - bconn.write_end >= buffer.len) { + @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer); + bconn.write_end += @intCast(u16, buffer.len); + + return buffer.len; + } else { + try bconn.flush(); + return try bconn.conn.write(buffer); + } + } + + pub fn flush(bconn: *BufferedConnection) WriteError!void { + defer bconn.write_end = 0; + return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]); } pub const WriteError = Connection.WriteError; @@ -199,8 +221,6 @@ pub const Compression = union(enum) { /// A HTTP request originating from a client. pub const Request = struct { pub const ParseError = Allocator.Error || error{ - ShortHttpStatusLine, - BadHttpVersion, UnknownHttpMethod, HttpHeadersInvalid, HttpHeaderContinuationsUnsupported, @@ -215,7 +235,7 @@ pub const Request = struct { const first_line = it.next() orelse return error.HttpHeadersInvalid; if (first_line.len < 10) - return error.ShortHttpStatusLine; + return error.HttpHeadersInvalid; const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; const method_str = first_line[0..method_end]; @@ -229,7 +249,7 @@ pub const Request = struct { const version: http.Version = switch (int64(version_str[0..8])) { int64("HTTP/1.0") => .@"HTTP/1.0", int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, + else => return error.HttpHeadersInvalid, }; const target = first_line[method_end + 1 .. version_start]; @@ -312,7 +332,7 @@ pub const Request = struct { transfer_encoding: ?http.TransferEncoding = null, transfer_compression: ?http.ContentEncoding = null, - headers: http.Headers = undefined, + headers: http.Headers, parser: proto.HeadersParser, compression: Compression = .none, }; @@ -329,21 +349,63 @@ pub const Response = struct { transfer_encoding: ResponseTransfer = .none, - server: *Server, + allocator: Allocator, address: net.Address, connection: BufferedConnection, headers: http.Headers, request: Request, + state: State = .first, + + const State = enum { + first, + start, + waited, + responded, + finished, + }; + pub fn deinit(res: *Response) void { - res.server.allocator.destroy(res); + res.connection.close(); + + res.headers.deinit(); + res.request.headers.deinit(); + + if (res.request.parser.header_bytes_owned) { + res.request.parser.header_bytes.deinit(res.allocator); + } } + pub const ResetState = enum { reset, closing }; + /// Reset this response to its initial state. This must be called before handling a second request on the same connection. - pub fn reset(res: *Response) void { - res.request.headers.deinit(); - res.headers.deinit(); + pub fn reset(res: *Response) ResetState { + if (res.state == .first) { + res.state = .start; + return .reset; + } + + if (!res.request.parser.done) { + // If the response wasn't fully read, then we need to close the connection. + res.connection.conn.closing = true; + return .closing; + } + + // A connection is only keep-alive if the Connection header is present and it's value is not "close". + // The server and client must both agree + // + // do() defaults to using keep-alive if the client requests it. + const res_connection = res.headers.getFirstValue("connection"); + const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); + + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + if (req_keepalive and (res_keepalive or res_connection == null)) { + res.connection.conn.closing = false; + } else { + res.connection.conn.closing = true; + } switch (res.request.compression) { .none => {}, @@ -352,19 +414,30 @@ pub const Response = struct { .zstd => |*zstd| zstd.deinit(), } - if (!res.request.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - res.connection.conn.closing = true; - } + res.state = .start; + res.version = .@"HTTP/1.1"; + res.status = .ok; + res.reason = null; - if (res.connection.conn.closing) { - res.connection.close(); + res.transfer_encoding = .none; - if (res.request.parser.header_bytes_owned) { - res.request.parser.header_bytes.deinit(res.server.allocator); - } + res.headers.clearRetainingCapacity(); + + res.request.headers.clearAndFree(); // FIXME: figure out why `clearRetainingCapacity` causes a leak in hash_map here + res.request.parser.reset(); + + res.request = Request{ + .version = undefined, + .method = undefined, + .target = undefined, + .headers = res.request.headers, + .parser = res.request.parser, + }; + + if (res.connection.conn.closing) { + return .closing; } else { - res.request.parser.reset(); + return .reset; } } @@ -372,8 +445,12 @@ pub const Response = struct { /// Send the response headers. pub fn do(res: *Response) !void { - var buffered = std.io.bufferedWriter(res.connection.writer()); - const w = buffered.writer(); + switch (res.state) { + .waited => res.state = .responded, + .first, .start, .responded, .finished => unreachable, + } + + const w = res.connection.writer(); try w.writeAll(@tagName(res.version)); try w.writeByte(' '); @@ -391,7 +468,14 @@ pub const Response = struct { } if (!res.headers.contains("connection")) { - try w.writeAll("Connection: keep-alive\r\n"); + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + + if (req_keepalive) { + try w.writeAll("Connection: keep-alive\r\n"); + } else { + try w.writeAll("Connection: close\r\n"); + } } const has_transfer_encoding = res.headers.contains("transfer-encoding"); @@ -424,7 +508,7 @@ pub const Response = struct { try w.writeAll("\r\n"); - try buffered.flush(); + try res.connection.flush(); } pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError; @@ -452,29 +536,23 @@ pub const Response = struct { /// Wait for the client to send a complete request head. pub fn wait(res: *Response) WaitError!void { + switch (res.state) { + .first, .start => res.state = .waited, + .waited, .responded, .finished => unreachable, + } + while (true) { try res.connection.fill(); - const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); if (res.request.parser.state.isContent()) break; } - res.request.headers = .{ .allocator = res.server.allocator, .owned = true }; + res.request.headers = .{ .allocator = res.allocator, .owned = true }; try res.request.parse(res.request.parser.header_bytes.items); - const res_connection = res.headers.getFirstValue("connection"); - const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); - - const req_connection = res.request.headers.getFirstValue("connection"); - const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); - if (res_keepalive and req_keepalive) { - res.connection.conn.closing = false; - } else { - res.connection.conn.closing = true; - } - if (res.request.transfer_encoding) |te| { switch (te) { .chunked => { @@ -494,13 +572,13 @@ pub const Response = struct { if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, .deflate => res.request.compression = .{ - .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, + .deflate = std.compress.zlib.zlibStream(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .gzip => res.request.compression = .{ - .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, + .gzip = std.compress.gzip.decompress(res.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => res.request.compression = .{ - .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()), + .zstd = std.compress.zstd.decompressStream(res.allocator, res.transferReader()), }, }; } @@ -515,6 +593,11 @@ pub const Response = struct { } pub fn read(res: *Response, buffer: []u8) ReadError!usize { + switch (res.state) { + .waited, .responded, .finished => {}, + .first, .start => unreachable, + } + const out_index = switch (res.request.compression) { .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, @@ -528,12 +611,12 @@ pub const Response = struct { while (!res.request.parser.state.isContent()) { // read trailing headers try res.connection.fill(); - const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + const nchecked = try res.request.parser.checkCompleteHead(res.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); } if (has_trail) { - res.request.headers = http.Headers{ .allocator = res.server.allocator, .owned = false }; + res.request.headers = http.Headers{ .allocator = res.allocator, .owned = false }; // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. // This will *only* fail for a malformed trailer. @@ -564,6 +647,11 @@ pub const Response = struct { /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. pub fn write(res: *Response, bytes: []const u8) WriteError!usize { + switch (res.state) { + .responded => {}, + .first, .waited, .start, .finished => unreachable, + } + switch (res.transfer_encoding) { .chunked => { try res.connection.writer().print("{x}\r\n", .{bytes.len}); @@ -583,7 +671,7 @@ pub const Response = struct { } } - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { + pub fn writeAll(req: *Response, bytes: []const u8) WriteError!void { var index: usize = 0; while (index < bytes.len) { index += try write(req, bytes[index..]); @@ -594,11 +682,18 @@ pub const Response = struct { /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(res: *Response) FinishError!void { + switch (res.state) { + .responded => res.state = .finished, + .first, .waited, .start, .finished => unreachable, + } + switch (res.transfer_encoding) { .chunked => try res.connection.writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } + + try res.connection.flush(); } }; @@ -635,31 +730,34 @@ pub const HeaderStrategy = union(enum) { static: []u8, }; -/// Accept a new connection and allocate a Response for it. -pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response { +pub const AcceptOptions = struct { + allocator: Allocator, + header_strategy: HeaderStrategy = .{ .dynamic = 8192 }, +}; + +/// Accept a new connection. +pub fn accept(server: *Server, options: AcceptOptions) AcceptError!Response { const in = try server.socket.accept(); - const res = try server.allocator.create(Response); - res.* = .{ - .server = server, + return Response{ + .allocator = options.allocator, .address = in.address, .connection = .{ .conn = .{ .stream = in.stream, .protocol = .plain, } }, - .headers = .{ .allocator = server.allocator }, + .headers = .{ .allocator = options.allocator }, .request = .{ .version = undefined, .method = undefined, .target = undefined, - .parser = switch (options) { + .headers = .{ .allocator = options.allocator, .owned = false }, + .parser = switch (options.header_strategy) { .dynamic => |max| proto.HeadersParser.initDynamic(max), .static => |buf| proto.HeadersParser.initStatic(buf), }, }, }; - - return res; } test "HTTP server handles a chunked transfer coding request" { |
