From 134294230a08d531afd0a6d823ae3046b1699b0f Mon Sep 17 00:00:00 2001 From: Nameless Date: Fri, 14 Apr 2023 12:37:32 -0500 Subject: std.http: add Headers --- lib/std/http/Server.zig | 295 ++++++++++++++++++++++++------------------------ 1 file changed, 149 insertions(+), 146 deletions(-) (limited to 'lib/std/http/Server.zig') diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 1ecb8fbd69..acf6f3c22d 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -157,134 +157,120 @@ pub const BufferedConnection = struct { /// A HTTP request originating from a client. pub const Request = struct { - pub const Headers = struct { - method: http.Method, - target: []const u8, - version: http.Version, - content_length: ?u64 = null, - transfer_encoding: ?http.TransferEncoding = null, - transfer_compression: ?http.ContentEncoding = null, - connection: http.Connection = .close, - host: ?[]const u8 = null, - - pub const ParseError = error{ - ShortHttpStatusLine, - BadHttpVersion, - UnknownHttpMethod, - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidCharacter, - }; - - pub fn parse(bytes: []const u8) !Headers { - var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n"); - - const first_line = it.next() orelse return error.HttpHeadersInvalid; - if (first_line.len < 10) - return error.ShortHttpStatusLine; - - const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - const method_str = first_line[0..method_end]; - const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod; + pub const ParseError = Allocator.Error || error{ + ShortHttpStatusLine, + BadHttpVersion, + UnknownHttpMethod, + HttpHeadersInvalid, + HttpHeaderContinuationsUnsupported, + HttpTransferEncodingUnsupported, + HttpConnectionHeaderUnsupported, + InvalidCharacter, + }; - const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (version_start == method_end) return error.HttpHeadersInvalid; + pub fn parse(req: *Request, bytes: []const u8) !void { + var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n"); - const version_str = first_line[version_start + 1 ..]; - if (version_str.len != 8) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(version_str[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.BadHttpVersion, - }; + const first_line = it.next() orelse return error.HttpHeadersInvalid; + if (first_line.len < 10) + return error.ShortHttpStatusLine; - const target = first_line[method_end + 1 .. version_start]; + const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; + const method_str = first_line[0..method_end]; + const method = std.meta.stringToEnum(http.Method, method_str) orelse return error.UnknownHttpMethod; - var headers: Headers = .{ - .method = method, - .target = target, - .version = version, - }; + const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; + if (version_start == method_end) return error.HttpHeadersInvalid; - while (it.next()) |line| { - if (line.len == 0) return error.HttpHeadersInvalid; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } + const version_str = first_line[version_start + 1 ..]; + if (version_str.len != 8) return error.HttpHeadersInvalid; + const version: http.Version = switch (int64(version_str[0..8])) { + int64("HTTP/1.0") => .@"HTTP/1.0", + int64("HTTP/1.1") => .@"HTTP/1.1", + else => return error.BadHttpVersion, + }; - var line_it = mem.tokenize(u8, line, ": "); - const header_name = line_it.next() orelse return error.HttpHeadersInvalid; - const header_value = line_it.rest(); - if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - if (headers.content_length != null) return error.HttpHeadersInvalid; - headers.content_length = try std.fmt.parseInt(u64, header_value, 10); - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwards(u8, header_value, ","); - - if (iter.next()) |first| { - const trimmed = mem.trim(u8, first, " "); - - if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { - if (headers.transfer_encoding != null) return error.HttpHeadersInvalid; - headers.transfer_encoding = te; - } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } - } + const target = first_line[method_end + 1 .. version_start]; - if (iter.next()) |second| { - if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported; + req.method = method; + req.target = target; + req.version = version; - const trimmed = mem.trim(u8, second, " "); + while (it.next()) |line| { + if (line.len == 0) return error.HttpHeadersInvalid; + switch (line[0]) { + ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, + else => {}, + } - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - headers.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } + var line_it = mem.tokenize(u8, line, ": "); + const header_name = line_it.next() orelse return error.HttpHeadersInvalid; + const header_value = line_it.rest(); + + try req.headers.append(header_name, header_value); + + if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + if (req.content_length != null) return error.HttpHeadersInvalid; + req.content_length = try std.fmt.parseInt(u64, header_value, 10); + } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { + // Transfer-Encoding: second, first + // Transfer-Encoding: deflate, chunked + var iter = mem.splitBackwards(u8, header_value, ","); + + if (iter.next()) |first| { + const trimmed = mem.trim(u8, first, " "); + + if (std.meta.stringToEnum(http.TransferEncoding, trimmed)) |te| { + if (req.transfer_encoding != null) return error.HttpHeadersInvalid; + req.transfer_encoding = te; + } else if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + if (req.transfer_compression != null) return error.HttpHeadersInvalid; + req.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; } + } - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (headers.transfer_compression != null) return error.HttpHeadersInvalid; + if (iter.next()) |second| { + if (req.transfer_compression != null) return error.HttpTransferEncodingUnsupported; - const trimmed = mem.trim(u8, header_value, " "); + const trimmed = mem.trim(u8, second, " "); if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - headers.transfer_compression = ce; + req.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; } - } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) { - headers.connection = .keep_alive; - } else if (std.ascii.eqlIgnoreCase(header_value, "close")) { - headers.connection = .close; - } else { - return error.HttpConnectionHeaderUnsupported; - } - } else if (std.ascii.eqlIgnoreCase(header_name, "host")) { - headers.host = header_value; } - } - return headers; - } + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (req.transfer_compression != null) return error.HttpHeadersInvalid; - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(u64, array.*); + const trimmed = mem.trim(u8, header_value, " "); + + if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + req.transfer_compression = ce; + } else { + return error.HttpTransferEncodingUnsupported; + } + } } - }; + } + + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(u64, array.*); + } - headers: Headers = undefined, + method: http.Method, + target: []const u8, + version: http.Version, + + content_length: ?u64 = null, + transfer_encoding: ?http.TransferEncoding = null, + transfer_compression: ?http.ContentEncoding = null, + + headers: http.Headers = undefined, parser: proto.HeadersParser, compression: Compression = .none, }; @@ -295,23 +281,17 @@ pub const Request = struct { /// Order of operations: accept -> wait -> do [ -> write -> finish][ -> reset /] /// \ -> read / pub const Response = struct { - pub const Headers = struct { - version: http.Version = .@"HTTP/1.1", - status: http.Status = .ok, - reason: ?[]const u8 = null, + version: http.Version = .@"HTTP/1.1", + status: http.Status = .ok, + reason: ?[]const u8 = null, - server: ?[]const u8 = "zig (std.http)", - connection: http.Connection = .keep_alive, - transfer_encoding: RequestTransfer = .none, - - custom: []const http.CustomHeader = &[_]http.CustomHeader{}, - }; + transfer_encoding: ResponseTransfer = .none, server: *Server, address: net.Address, connection: BufferedConnection, - headers: Headers = .{}, + headers: http.Headers, request: Request, /// Reset this response to its initial state. This must be called before handling a second request on the same connection. @@ -346,41 +326,54 @@ pub const Response = struct { var buffered = std.io.bufferedWriter(res.connection.writer()); const w = buffered.writer(); - try w.writeAll(@tagName(res.headers.version)); + try w.writeAll(@tagName(res.version)); try w.writeByte(' '); - try w.print("{d}", .{@enumToInt(res.headers.status)}); + try w.print("{d}", .{@enumToInt(res.status)}); try w.writeByte(' '); - if (res.headers.reason) |reason| { + if (res.reason) |reason| { try w.writeAll(reason); - } else if (res.headers.status.phrase()) |phrase| { + } else if (res.status.phrase()) |phrase| { try w.writeAll(phrase); } + try w.writeAll("\r\n"); - if (res.headers.server) |server| { - try w.writeAll("\r\nServer: "); - try w.writeAll(server); + if (!res.headers.contains("server")) { + try w.writeAll("Server: zig (std.http)\r\n"); } - if (res.headers.connection == .close) { - try w.writeAll("\r\nConnection: close"); - } else { - try w.writeAll("\r\nConnection: keep-alive"); + if (!res.headers.contains("connection")) { + try w.writeAll("Connection: keep-alive\r\n"); } - switch (res.headers.transfer_encoding) { - .chunked => try w.writeAll("\r\nTransfer-Encoding: chunked"), - .content_length => |content_length| try w.print("\r\nContent-Length: {d}", .{content_length}), - .none => {}, - } + const has_transfer_encoding = res.headers.contains("transfer-encoding"); + const has_content_length = res.headers.contains("content-length"); - for (res.headers.custom) |header| { - try w.writeAll("\r\n"); - try w.writeAll(header.name); - try w.writeAll(": "); - try w.writeAll(header.value); + if (!has_transfer_encoding and !has_content_length) { + switch (res.transfer_encoding) { + .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"), + .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}), + .none => {}, + } + } else { + if (has_content_length) { + const content_length = try std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10); + + res.transfer_encoding = .{ .content_length = content_length }; + } else if (has_transfer_encoding) { + const transfer_encoding = res.headers.getFirstValue("content-length").?; + if (std.mem.eql(u8, transfer_encoding, "chunked")) { + res.transfer_encoding = .chunked; + } else { + return error.UnsupportedTransferEncoding; + } + } else { + res.transfer_encoding = .none; + } } - try w.writeAll("\r\n\r\n"); + try w.print("{}", .{res.headers}); + + try w.writeAll("\r\n"); try buffered.flush(); } @@ -419,22 +412,28 @@ pub const Response = struct { if (res.request.parser.state.isContent()) break; } - res.request.headers = try Request.Headers.parse(res.request.parser.header_bytes.items); + res.request.headers = .{ .allocator = res.server.allocator, .owned = true }; + try res.request.parse(res.request.parser.header_bytes.items); - if (res.headers.connection == .keep_alive and res.request.headers.connection == .keep_alive) { + const res_connection = res.headers.getFirstValue("connection"); + const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?); + + const req_connection = res.request.headers.getFirstValue("connection"); + const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?); + if (res_keepalive and req_keepalive) { res.connection.conn.closing = false; } else { res.connection.conn.closing = true; } - if (res.request.headers.transfer_encoding) |te| { + if (res.request.transfer_encoding) |te| { switch (te) { .chunked => { res.request.parser.next_chunk_length = 0; res.request.parser.state = .chunk_head_size; }, } - } else if (res.request.headers.content_length) |cl| { + } else if (res.request.content_length) |cl| { res.request.parser.next_chunk_length = cl; if (cl == 0) res.request.parser.done = true; @@ -443,7 +442,7 @@ pub const Response = struct { } if (!res.request.parser.done) { - if (res.request.headers.transfer_compression) |tc| switch (tc) { + if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, .deflate => res.request.compression = .{ .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()), @@ -495,7 +494,7 @@ pub const Response = struct { /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent. pub fn write(res: *Response, bytes: []const u8) WriteError!usize { - switch (res.headers.transfer_encoding) { + switch (res.transfer_encoding) { .chunked => { try res.connection.writer().print("{x}\r\n", .{bytes.len}); try res.connection.writeAll(bytes); @@ -525,7 +524,7 @@ pub const Response = struct { }; /// The mode of transport for responses. -pub const RequestTransfer = union(enum) { +pub const ResponseTransfer = union(enum) { content_length: u64, chunked: void, none: void, @@ -588,7 +587,11 @@ pub fn accept(server: *Server, options: HeaderStrategy) AcceptError!*Response { .stream = in.stream, .protocol = .plain, } }, + .headers = .{ .allocator = server.allocator }, .request = .{ + .version = undefined, + .method = undefined, + .target = undefined, .parser = switch (options) { .dynamic => |max| proto.HeadersParser.initDynamic(max), .static => |buf| proto.HeadersParser.initStatic(buf), -- cgit v1.2.3 From 85221b4e977756bf30f7d45a7eb8636ea0d5168a Mon Sep 17 00:00:00 2001 From: Nameless Date: Sun, 16 Apr 2023 16:26:25 -0500 Subject: std.http: curate some Server errors, fix reading chunked bodies --- lib/std/http/Client.zig | 66 ++++++++++++--------- lib/std/http/Server.zig | 155 +++++++++++++++++++++++++++++++----------------- src/Package.zig | 2 + 3 files changed, 140 insertions(+), 83 deletions(-) (limited to 'lib/std/http/Server.zig') diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 614b70b216..1b144b2b18 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -193,7 +193,13 @@ pub const Connection = struct { }; } - pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure }; + pub const ReadError = error{ + TlsFailure, + TlsAlert, + ConnectionTimedOut, + ConnectionResetByPeer, + UnexpectedReadFailure, + }; pub const Reader = std.io.Reader(*Connection, ReadError, read); @@ -518,7 +524,10 @@ pub const Request = struct { req.* = undefined; } - pub fn start(req: *Request, uri: Uri) !void { + pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + + /// Send the request to the server. + pub fn start(req: *Request, uri: Uri) StartError!void { var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer()); const w = buffered.writer(); @@ -575,7 +584,7 @@ pub const Request = struct { } } else { if (has_content_length) { - const content_length = try std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10); + const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; req.transfer_encoding = .{ .content_length = content_length }; } else if (has_transfer_encoding) { @@ -618,7 +627,7 @@ pub const Request = struct { return index; } - pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed }; + pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported }; /// Waits for a response from the server and parses any headers that are sent. /// This function will block until the final response is received. @@ -739,25 +748,23 @@ pub const Request = struct { /// Reads data from the response body. Must be called after `do`. pub fn read(req: *Request, buffer: []u8) ReadError!usize { - while (true) { - const out_index = switch (req.response.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), - }; - - if (out_index == 0) { - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.data.buffered.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); - req.connection.data.buffered.clear(@intCast(u16, nchecked)); - } - } + const out_index = switch (req.response.compression) { + .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, + else => try req.transferRead(buffer), + }; + + if (out_index == 0) { + while (!req.response.parser.state.isContent()) { // read trailing headers + try req.connection.data.buffered.fill(); - return out_index; + const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); + req.connection.data.buffered.clear(@intCast(u16, nchecked)); + } } + + return out_index; } /// Reads data from the response body. Must be called after `do`. @@ -800,15 +807,19 @@ pub const Request = struct { } } + pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(req, bytes[index..]); + } + } + pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. pub fn finish(req: *Request) FinishError!void { switch (req.transfer_encoding) { - .chunked => req.connection.data.conn.writeAll("0\r\n\r\n") catch |err| { - req.client.last_error = .{ .write = err }; - return error.WriteFailed; - }, + .chunked => try req.connection.data.conn.writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, } @@ -923,7 +934,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio } } -pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || std.fmt.ParseIntError || BufferedConnection.WriteError || error{ +pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{ UnsupportedUrlScheme, UriMissingHost, @@ -998,6 +1009,7 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option .handle_redirects = options.handle_redirects, .response = .{ .status = undefined, + .reason = undefined, .version = undefined, .headers = undefined, .parser = switch (options.header_strategy) { @@ -1011,8 +1023,6 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option req.arena = std.heap.ArenaAllocator.init(client.allocator); - try req.start(uri); - return req; } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index acf6f3c22d..779e9fa984 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -23,21 +23,33 @@ pub const Connection = struct { pub const Protocol = enum { plain }; - pub fn read(conn: *Connection, buffer: []u8) !usize { - switch (conn.protocol) { - .plain => return conn.stream.read(buffer), + pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { + return switch (conn.protocol) { + .plain => conn.stream.read(buffer), // .tls => return conn.tls_client.read(conn.stream, buffer), - } + } catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; } - pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize { - switch (conn.protocol) { - .plain => return conn.stream.readAtLeast(buffer, len), + pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize { + return switch (conn.protocol) { + .plain => conn.stream.readAtLeast(buffer, len), // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len), - } + } catch |err| switch (err) { + error.ConnectionTimedOut => return error.ConnectionTimedOut, + error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, + else => return error.UnexpectedReadFailure, + }; } - pub const ReadError = net.Stream.ReadError; + pub const ReadError = error{ + ConnectionTimedOut, + ConnectionResetByPeer, + UnexpectedReadFailure, + }; pub const Reader = std.io.Reader(*Connection, ReadError, read); @@ -45,21 +57,31 @@ pub const Connection = struct { return Reader{ .context = conn }; } - pub fn writeAll(conn: *Connection, buffer: []const u8) !void { - switch (conn.protocol) { - .plain => return conn.stream.writeAll(buffer), + pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void { + return switch (conn.protocol) { + .plain => conn.stream.writeAll(buffer), // .tls => return conn.tls_client.writeAll(conn.stream, buffer), - } + } catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; } - pub fn write(conn: *Connection, buffer: []const u8) !usize { - switch (conn.protocol) { - .plain => return conn.stream.write(buffer), + pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { + return switch (conn.protocol) { + .plain => conn.stream.write(buffer), // .tls => return conn.tls_client.write(conn.stream, buffer), - } + } catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; } - pub const WriteError = net.Stream.WriteError || error{}; + pub const WriteError = error{ + ConnectionResetByPeer, + UnexpectedWriteFailure, + }; + pub const Writer = std.io.Writer(*Connection, WriteError, write); pub fn writer(conn: *Connection) Writer { @@ -155,6 +177,25 @@ pub const BufferedConnection = struct { } }; +/// The mode of transport for responses. +pub const ResponseTransfer = union(enum) { + content_length: u64, + chunked: void, + none: void, +}; + +/// The decompressor for request messages. +pub const Compression = union(enum) { + pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader); + pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader); + pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{}); + + deflate: DeflateDecompressor, + gzip: GzipDecompressor, + zstd: ZstdDecompressor, + none: void, +}; + /// A HTTP request originating from a client. pub const Request = struct { pub const ParseError = Allocator.Error || error{ @@ -165,10 +206,11 @@ pub const Request = struct { HttpHeaderContinuationsUnsupported, HttpTransferEncodingUnsupported, HttpConnectionHeaderUnsupported, - InvalidCharacter, + InvalidContentLength, + CompressionNotSupported, }; - pub fn parse(req: *Request, bytes: []const u8) !void { + pub fn parse(req: *Request, bytes: []const u8) ParseError!void { var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n"); const first_line = it.next() orelse return error.HttpHeadersInvalid; @@ -211,7 +253,7 @@ pub const Request = struct { if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { if (req.content_length != null) return error.HttpHeadersInvalid; - req.content_length = try std.fmt.parseInt(u64, header_value, 10); + req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { // Transfer-Encoding: second, first // Transfer-Encoding: deflate, chunked @@ -321,6 +363,8 @@ pub const Response = struct { } } + pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength }; + /// Send the response headers. pub fn do(res: *Response) !void { var buffered = std.io.bufferedWriter(res.connection.writer()); @@ -356,7 +400,7 @@ pub const Response = struct { } } else { if (has_content_length) { - const content_length = try std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10); + const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength; res.transfer_encoding = .{ .content_length = content_length }; } else if (has_transfer_encoding) { @@ -386,23 +430,23 @@ pub const Response = struct { return .{ .context = res }; } - pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize { - if (res.request.parser.isComplete()) return 0; + fn transferRead(res: *Response, buf: []u8) TransferReadError!usize { + if (res.request.parser.done) return 0; var index: usize = 0; while (index == 0) { const amt = try res.request.parser.read(&res.connection, buf[index..], false); - if (amt == 0 and res.request.parser.isComplete()) break; + if (amt == 0 and res.request.parser.done) break; index += amt; } return index; } - pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported}; + pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported }; /// Wait for the client to send a complete request head. - pub fn wait(res: *Response) !void { + pub fn wait(res: *Response) WaitError!void { while (true) { try res.connection.fill(); @@ -445,10 +489,10 @@ pub const Response = struct { if (res.request.transfer_compression) |tc| switch (tc) { .compress => return error.CompressionNotSupported, .deflate => res.request.compression = .{ - .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()), + .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .gzip => res.request.compression = .{ - .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()), + .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed, }, .zstd => res.request.compression = .{ .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()), @@ -457,7 +501,7 @@ pub const Response = struct { } } - pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError; + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure}; pub const Reader = std.io.Reader(*Response, ReadError, read); @@ -466,12 +510,23 @@ pub const Response = struct { } pub fn read(res: *Response, buffer: []u8) ReadError!usize { - return switch (res.request.compression) { - .deflate => |*deflate| try deflate.read(buffer), - .gzip => |*gzip| try gzip.read(buffer), - .zstd => |*zstd| try zstd.read(buffer), + const out_index = switch (res.request.compression) { + .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, else => try res.transferRead(buffer), }; + + if (out_index == 0) { + while (!res.request.parser.state.isContent()) { // read trailing headers + try res.connection.fill(); + + const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); + res.connection.clear(@intCast(u16, nchecked)); + } + } + + return out_index; } pub fn readAll(res: *Response, buffer: []u8) !usize { @@ -513,9 +568,18 @@ pub const Response = struct { } } + pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try write(req, bytes[index..]); + } + } + + pub const FinishError = WriteError || error{MessageNotCompleted}; + /// Finish the body of a request. This notifies the server that you have no more data to send. - pub fn finish(res: *Response) !void { - switch (res.headers.transfer_encoding) { + pub fn finish(res: *Response) FinishError!void { + switch (res.transfer_encoding) { .chunked => try res.connection.writeAll("0\r\n\r\n"), .content_length => |len| if (len != 0) return error.MessageNotCompleted, .none => {}, @@ -523,25 +587,6 @@ pub const Response = struct { } }; -/// The mode of transport for responses. -pub const ResponseTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for request messages. -pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader); - pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader); - pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{}); - - deflate: DeflateDecompressor, - gzip: GzipDecompressor, - zstd: ZstdDecompressor, - none: void, -}; - pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server { return .{ .allocator = allocator, diff --git a/src/Package.zig b/src/Package.zig index f471e2d606..7c9d2a5f20 100644 --- a/src/Package.zig +++ b/src/Package.zig @@ -485,6 +485,8 @@ fn fetchAndUnpack( var req = try http_client.request(uri, h, .{ .method = .GET }); defer req.deinit(); + try req.start(); + try req.do(); if (mem.endsWith(u8, uri.path, ".tar.gz")) { -- cgit v1.2.3 From a23c8662b41cf6954d8294ea316fb28a88481a7e Mon Sep 17 00:00:00 2001 From: Nameless Date: Mon, 17 Apr 2023 19:37:24 -0500 Subject: std.http: pass Method to request directly, parse trailing headers --- lib/std/http/Client.zig | 32 ++++++++++++++++++++------------ lib/std/http/Server.zig | 10 ++++++++++ src/Package.zig | 2 +- 3 files changed, 31 insertions(+), 13 deletions(-) (limited to 'lib/std/http/Server.zig') diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 1b144b2b18..4ff29a215a 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -527,7 +527,7 @@ pub const Request = struct { pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; /// Send the request to the server. - pub fn start(req: *Request, uri: Uri) StartError!void { + pub fn start(req: *Request) StartError!void { var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer()); const w = buffered.writer(); @@ -535,14 +535,14 @@ pub const Request = struct { try w.writeByte(' '); if (req.method == .CONNECT) { - try w.writeAll(uri.host.?); + try w.writeAll(req.uri.host.?); try w.writeByte(':'); - try w.print("{}", .{uri.port.?}); + try w.print("{}", .{req.uri.port.?}); } else if (req.connection.data.proxied) { // proxied connections require the full uri - try w.print("{+/}", .{uri}); + try w.print("{+/}", .{req.uri}); } else { - try w.print("{/}", .{uri}); + try w.print("{/}", .{req.uri}); } try w.writeByte(' '); @@ -551,7 +551,7 @@ pub const Request = struct { if (!req.headers.contains("host")) { try w.writeAll("Host: "); - try w.writeAll(uri.host.?); + try w.writeAll(req.uri.host.?); try w.writeAll("\r\n"); } @@ -704,8 +704,7 @@ pub const Request = struct { req.arena.deinit(); req.arena = new_arena; - const new_req = try req.client.request(resolved_url, req.headers, .{ - .method = req.method, + const new_req = try req.client.request(req.method, resolved_url, req.headers, .{ .version = req.version, .max_redirects = req.redirects_left - 1, .header_strategy = if (req.response.parser.header_bytes_owned) .{ @@ -738,7 +737,7 @@ pub const Request = struct { } } - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure}; + pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; pub const Reader = std.io.Reader(*Request, ReadError, read); @@ -756,12 +755,22 @@ pub const Request = struct { }; if (out_index == 0) { + const has_trail = !req.response.parser.state.isContent(); + while (!req.response.parser.state.isContent()) { // read trailing headers try req.connection.data.buffered.fill(); const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek()); req.connection.data.buffered.clear(@intCast(u16, nchecked)); } + + if (has_trail) { + req.response.headers = http.Headers{ .allocator = req.client.allocator, .owned = false }; + + // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. + // This will *only* fail for a malformed trailer. + req.response.parse(req.response.parser.header_bytes.items) catch return error.InvalidTrailers; + } } return out_index; @@ -943,7 +952,6 @@ pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request }; pub const Options = struct { - method: http.Method = .GET, version: http.Version = .@"HTTP/1.1", handle_redirects: bool = true, @@ -976,7 +984,7 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{ /// Form and send a http request to a server. /// This function is threadsafe. -pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Options) RequestError!Request { +pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: Options) RequestError!Request { const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; const port: u16 = uri.port orelse switch (protocol) { @@ -1003,7 +1011,7 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option .client = client, .connection = conn, .headers = headers, - .method = options.method, + .method = method, .version = options.version, .redirects_left = options.max_redirects, .handle_redirects = options.handle_redirects, diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 779e9fa984..94efb94d79 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -518,12 +518,22 @@ pub const Response = struct { }; if (out_index == 0) { + const has_trail = !res.request.parser.state.isContent(); + while (!res.request.parser.state.isContent()) { // read trailing headers try res.connection.fill(); const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek()); res.connection.clear(@intCast(u16, nchecked)); } + + if (has_trail) { + res.request.headers = http.Headers{ .allocator = res.server.allocator, .owned = false }; + + // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error. + // This will *only* fail for a malformed trailer. + res.request.parse(res.request.parser.header_bytes.items) catch return error.InvalidTrailers; + } } return out_index; diff --git a/src/Package.zig b/src/Package.zig index 7c9d2a5f20..7d98ddaba3 100644 --- a/src/Package.zig +++ b/src/Package.zig @@ -482,7 +482,7 @@ fn fetchAndUnpack( var h = std.http.Headers{ .allocator = gpa }; defer h.deinit(); - var req = try http_client.request(uri, h, .{ .method = .GET }); + var req = try http_client.request(.GET, uri, h, .{}); defer req.deinit(); try req.start(); -- cgit v1.2.3