| author | Andrew Kelley <andrew@ziglang.org> | 2023-01-05 19:42:59 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-05 19:42:59 -0700 |
| commit | b3e495a38a5e334f5e30e255592f810e0017919c (patch) | |
| tree | a7960dd1f6a8feb849452dfa69151eedac836722 /lib/std/http/Client.zig | |
| parent | 6ad92108e2cbba06064724d8d91abaede20f355a (diff) | |
| parent | 3055ab7f8639deca318f238f21680776a7149acb (diff) | |
| download | zig-b3e495a38a5e334f5e30e255592f810e0017919c.tar.gz zig-b3e495a38a5e334f5e30e255592f810e0017919c.zip | |
Merge pull request #14202 from ziglang/std.http
std.http.Client: support HTTP redirects
Diffstat (limited to 'lib/std/http/Client.zig')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | lib/std/http/Client.zig | 617 |
1 file changed, 514 insertions, 103 deletions
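
Based on the API introduced by this change (see the full diff below), a minimal usage sketch might look like the following. The `std.http.Client`, `std.Url.parse`, `request`, `readAll`, `deinit`, and `max_redirects` names come from the diff itself; the example URL, buffer size, and allocator setup are illustrative placeholders only.

```zig
const std = @import("std");

pub fn main() !void {
    var gpa_state = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa_state.deinit();
    const gpa = gpa_state.allocator();

    // The client only needs an allocator; the CA bundle starts out empty.
    var client = std.http.Client{ .allocator = gpa };
    defer client.deinit(gpa);

    // request() now takes the request headers and per-request options;
    // redirects are followed automatically, up to `max_redirects`.
    const url = try std.Url.parse("http://example.com/");
    var req = try client.request(url, .{ .method = .GET }, .{ .max_redirects = 3 });
    defer req.deinit();

    var buf: [4096]u8 = undefined;
    const n = try req.readAll(&buf);
    std.debug.print("{s}\n", .{buf[0..n]});
}
```

When the redirect budget is exhausted, `readAdvanced` returns `error.TooManyHttpRedirects`, and a redirect response without a `Location` header yields `error.HttpRedirectMissingLocation`.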
diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig
index 8a4a771416..c6262f4706 100644
--- a/lib/std/http/Client.zig
+++ b/lib/std/http/Client.zig
@@ -1,62 +1,447 @@
 //! This API is a barely-touched, barely-functional http client, just the
 //! absolute minimum thing I needed in order to test `std.crypto.tls`. Bear
 //! with me and I promise the API will become useful and streamlined.
+//!
+//! TODO: send connection: keep-alive and LRU cache a configurable number of
+//! open connections to skip DNS and TLS handshake for subsequent requests.
 const std = @import("../std.zig");
+const mem = std.mem;
 const assert = std.debug.assert;
 const http = std.http;
 const net = std.net;
 const Client = @This();
 const Url = std.Url;
+const Allocator = std.mem.Allocator;
+const testing = std.testing;
 
-allocator: std.mem.Allocator,
-headers: std.ArrayListUnmanaged(u8) = .{},
-active_requests: usize = 0,
+/// Used for tcpConnectToHost and storing HTTP headers when an externally
+/// managed buffer is not provided.
+allocator: Allocator,
 ca_bundle: std.crypto.Certificate.Bundle = .{},
 
+pub const Connection = struct {
+    stream: net.Stream,
+    /// undefined unless protocol is tls.
+    tls_client: std.crypto.tls.Client,
+    protocol: Protocol,
+
+    pub const Protocol = enum { plain, tls };
+
+    pub fn read(conn: *Connection, buffer: []u8) !usize {
+        switch (conn.protocol) {
+            .plain => return conn.stream.read(buffer),
+            .tls => return conn.tls_client.read(conn.stream, buffer),
+        }
+    }
+
+    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
+        switch (conn.protocol) {
+            .plain => return conn.stream.readAtLeast(buffer, len),
+            .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
+        }
+    }
+
+    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
+        switch (conn.protocol) {
+            .plain => return conn.stream.writeAll(buffer),
+            .tls => return conn.tls_client.writeAll(conn.stream, buffer),
+        }
+    }
+
+    pub fn write(conn: *Connection, buffer: []const u8) !usize {
+        switch (conn.protocol) {
+            .plain => return conn.stream.write(buffer),
+            .tls => return conn.tls_client.write(conn.stream, buffer),
+        }
+    }
+};
+
 /// TODO: emit error.UnexpectedEndOfStream or something like that when the read
 /// data does not match the content length. This is necessary since HTTPS disables
 /// close_notify protection on underlying TLS streams.
 pub const Request = struct {
     client: *Client,
-    stream: net.Stream,
-    headers: std.ArrayListUnmanaged(u8) = .{},
-    tls_client: std.crypto.tls.Client,
-    protocol: Protocol,
-    response_headers: http.Headers = .{},
+    connection: Connection,
+    redirects_left: u32,
+    response: Response,
+    /// These are stored in Request so that they are available when following
+    /// redirects.
+    headers: Headers,
 
-    pub const Protocol = enum { http, https };
+    pub const Response = struct {
+        headers: Response.Headers,
+        state: State,
+        header_bytes_owned: bool,
+        /// This could either be a fixed buffer provided by the API user or it
+        /// could be our own array list.
+        header_bytes: std.ArrayListUnmanaged(u8),
+        max_header_bytes: usize,
 
-    pub const Options = struct {
+        pub const Headers = struct {
+            location: ?[]const u8 = null,
+            status: http.Status,
+            version: http.Version,
+            content_length: ?u64 = null,
+
+            pub fn parse(bytes: []const u8) !Response.Headers {
+                var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
+
+                const first_line = it.first();
+                if (first_line.len < 12)
+                    return error.ShortHttpStatusLine;
+
+                const version: http.Version = switch (int64(first_line[0..8])) {
+                    int64("HTTP/1.0") => .@"HTTP/1.0",
+                    int64("HTTP/1.1") => .@"HTTP/1.1",
+                    else => return error.BadHttpVersion,
+                };
+                if (first_line[8] != ' ') return error.HttpHeadersInvalid;
+                const status = @intToEnum(http.Status, parseInt3(first_line[9..12].*));
+
+                var headers: Response.Headers = .{
+                    .version = version,
+                    .status = status,
+                };
+
+                while (it.next()) |line| {
+                    if (line.len == 0) return error.HttpHeadersInvalid;
+                    switch (line[0]) {
+                        ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+                        else => {},
+                    }
+                    var line_it = mem.split(u8, line, ": ");
+                    const header_name = line_it.first();
+                    const header_value = line_it.rest();
+                    if (std.ascii.eqlIgnoreCase(header_name, "location")) {
+                        if (headers.location != null) return error.HttpHeadersInvalid;
+                        headers.location = header_value;
+                    } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
+                        if (headers.content_length != null) return error.HttpHeadersInvalid;
+                        headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
+                    }
+                }
+
+                return headers;
+            }
+
+            test "parse headers" {
+                const example =
+                    "HTTP/1.1 301 Moved Permanently\r\n" ++
+                    "Location: https://www.example.com/\r\n" ++
+                    "Content-Type: text/html; charset=UTF-8\r\n" ++
+                    "Content-Length: 220\r\n\r\n";
+                const parsed = try Response.Headers.parse(example);
+                try testing.expectEqual(http.Version.@"HTTP/1.1", parsed.version);
+                try testing.expectEqual(http.Status.moved_permanently, parsed.status);
+                try testing.expectEqualStrings("https://www.example.com/", parsed.location orelse
+                    return error.TestFailed);
+                try testing.expectEqual(@as(?u64, 220), parsed.content_length);
+            }
+
+            test "header continuation" {
+                const example =
+                    "HTTP/1.0 200 OK\r\n" ++
+                    "Content-Type: text/html;\r\n charset=UTF-8\r\n" ++
+                    "Content-Length: 220\r\n\r\n";
+                try testing.expectError(
+                    error.HttpHeaderContinuationsUnsupported,
+                    Response.Headers.parse(example),
+                );
+            }
+
+            test "extra content length" {
+                const example =
+                    "HTTP/1.0 200 OK\r\n" ++
+                    "Content-Length: 220\r\n" ++
+                    "Content-Type: text/html; charset=UTF-8\r\n" ++
+                    "content-length: 220\r\n\r\n";
+                try testing.expectError(
+                    error.HttpHeadersInvalid,
+                    Response.Headers.parse(example),
+                );
+            }
+        };
+
+        pub const State = enum {
+            invalid,
+            finished,
+            start,
+            seen_r,
+            seen_rn,
+            seen_rnr,
+        };
+
+        pub fn initDynamic(max: usize) Response {
+            return .{
+                .state = .start,
+                .headers = undefined,
+                .header_bytes = .{},
+                .max_header_bytes = max,
+                .header_bytes_owned = true,
+            };
+        }
+
+        pub fn initStatic(buf: []u8) Response {
+            return .{
+                .state = .start,
+                .headers = undefined,
+                .header_bytes = .{ .items = buf[0..0], .capacity = buf.len },
+                .max_header_bytes = buf.len,
+                .header_bytes_owned = false,
+            };
+        }
+
+        /// Returns how many bytes are part of HTTP headers. Always less than or
+        /// equal to bytes.len. If the amount returned is less than bytes.len, it
+        /// means the headers ended and the first byte after the double \r\n\r\n is
+        /// located at `bytes[result]`.
+        pub fn findHeadersEnd(r: *Response, bytes: []const u8) usize {
+            var index: usize = 0;
+
+            // TODO: https://github.com/ziglang/zig/issues/8220
+            state: while (true) {
+                switch (r.state) {
+                    .invalid => unreachable,
+                    .finished => unreachable,
+                    .start => while (true) {
+                        switch (bytes.len - index) {
+                            0 => return index,
+                            1 => {
+                                if (bytes[index] == '\r')
+                                    r.state = .seen_r;
+                                return index + 1;
+                            },
+                            2 => {
+                                if (int16(bytes[index..][0..2]) == int16("\r\n")) {
+                                    r.state = .seen_rn;
+                                } else if (bytes[index + 1] == '\r') {
+                                    r.state = .seen_r;
+                                }
+                                return index + 2;
+                            },
+                            3 => {
+                                if (int16(bytes[index..][0..2]) == int16("\r\n") and
+                                    bytes[index + 2] == '\r')
+                                {
+                                    r.state = .seen_rnr;
+                                } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n")) {
+                                    r.state = .seen_rn;
+                                } else if (bytes[index + 2] == '\r') {
+                                    r.state = .seen_r;
+                                }
+                                return index + 3;
+                            },
+                            4...15 => {
+                                if (int32(bytes[index..][0..4]) == int32("\r\n\r\n")) {
+                                    r.state = .finished;
+                                    return index + 4;
+                                } else if (int16(bytes[index + 1 ..][0..2]) == int16("\r\n") and
+                                    bytes[index + 3] == '\r')
+                                {
+                                    r.state = .seen_rnr;
+                                    index += 4;
+                                    continue :state;
+                                } else if (int16(bytes[index + 2 ..][0..2]) == int16("\r\n")) {
+                                    r.state = .seen_rn;
+                                    index += 4;
+                                    continue :state;
+                                } else if (bytes[index + 3] == '\r') {
+                                    r.state = .seen_r;
+                                    index += 4;
+                                    continue :state;
+                                }
+                                index += 4;
+                                continue;
+                            },
+                            else => {
+                                const chunk = bytes[index..][0..16];
+                                const v: @Vector(16, u8) = chunk.*;
+                                const matches_r = v == @splat(16, @as(u8, '\r'));
+                                const iota = std.simd.iota(u8, 16);
+                                const default = @splat(16, @as(u8, 16));
+                                const sub_index = @reduce(.Min, @select(u8, matches_r, iota, default));
+                                switch (sub_index) {
+                                    0...12 => {
+                                        index += sub_index + 4;
+                                        if (int32(chunk[sub_index..][0..4]) == int32("\r\n\r\n")) {
+                                            r.state = .finished;
+                                            return index;
+                                        }
+                                        continue;
+                                    },
+                                    13 => {
+                                        index += 16;
+                                        if (int16(chunk[14..][0..2]) == int16("\n\r")) {
+                                            r.state = .seen_rnr;
+                                            continue :state;
+                                        }
+                                        continue;
+                                    },
+                                    14 => {
+                                        index += 16;
+                                        if (chunk[15] == '\n') {
+                                            r.state = .seen_rn;
+                                            continue :state;
+                                        }
+                                        continue;
+                                    },
+                                    15 => {
+                                        r.state = .seen_r;
+                                        index += 16;
+                                        continue :state;
+                                    },
+                                    16 => {
+                                        index += 16;
+                                        continue;
+                                    },
+                                    else => unreachable,
+                                }
+                            },
+                        }
+                    },
+
+                    .seen_r => switch (bytes.len - index) {
+                        0 => return index,
+                        1 => {
+                            switch (bytes[index]) {
+                                '\n' => r.state = .seen_rn,
+                                '\r' => r.state = .seen_r,
+                                else => r.state = .start,
+                            }
+                            return index + 1;
+                        },
+                        2 => {
+                            if (int16(bytes[index..][0..2]) == int16("\n\r")) {
+                                r.state = .seen_rnr;
+                                return index + 2;
+                            }
+                            r.state = .start;
+                            return index + 2;
+                        },
+                        else => {
+                            if (int16(bytes[index..][0..2]) == int16("\n\r") and
+                                bytes[index + 2] == '\n')
+                            {
+                                r.state = .finished;
+                                return index + 3;
+                            }
+                            index += 3;
+                            r.state = .start;
+                            continue :state;
+                        },
+                    },
+                    .seen_rn => switch (bytes.len - index) {
+                        0 => return index,
+                        1 => {
+                            switch (bytes[index]) {
+                                '\r' => r.state = .seen_rnr,
+                                else => r.state = .start,
+                            }
+                            return index + 1;
+                        },
+                        else => {
+                            if (int16(bytes[index..][0..2]) == int16("\r\n")) {
+                                r.state = .finished;
+                                return index + 2;
+                            }
+                            index += 2;
+                            r.state = .start;
+                            continue :state;
+                        },
+                    },
+                    .seen_rnr => switch (bytes.len - index) {
+                        0 => return index,
+                        else => {
+                            if (bytes[index] == '\n') {
+                                r.state = .finished;
+                                return index + 1;
+                            }
+                            index += 1;
+                            r.state = .start;
+                            continue :state;
+                        },
+                    },
+                }
+
+                return index;
+            }
+        }
+
+        fn parseInt3(nnn: @Vector(3, u8)) u10 {
+            const zero: @Vector(3, u8) = .{ '0', '0', '0' };
+            const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
+            return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
+        }
+
+        test parseInt3 {
+            const expectEqual = std.testing.expectEqual;
+            try expectEqual(@as(u10, 0), parseInt3("000".*));
+            try expectEqual(@as(u10, 418), parseInt3("418".*));
+            try expectEqual(@as(u10, 999), parseInt3("999".*));
+        }
+
+        inline fn int16(array: *const [2]u8) u16 {
+            return @bitCast(u16, array.*);
+        }
+
+        inline fn int32(array: *const [4]u8) u32 {
+            return @bitCast(u32, array.*);
+        }
+
+        inline fn int64(array: *const [8]u8) u64 {
+            return @bitCast(u64, array.*);
+        }
+
+        test "find headers end basic" {
+            var buffer: [1]u8 = undefined;
+            var r = Response.initStatic(&buffer);
+            try testing.expectEqual(@as(usize, 10), r.findHeadersEnd("HTTP/1.1 4"));
+            try testing.expectEqual(@as(usize, 2), r.findHeadersEnd("18"));
+            try testing.expectEqual(@as(usize, 8), r.findHeadersEnd(" lol\r\n\r\nblah blah"));
+        }
+
+        test "find headers end vectorized" {
+            var buffer: [1]u8 = undefined;
+            var r = Response.initStatic(&buffer);
+            const example =
+                "HTTP/1.1 301 Moved Permanently\r\n" ++
+                "Location: https://www.example.com/\r\n" ++
+                "Content-Type: text/html; charset=UTF-8\r\n" ++
+                "Content-Length: 220\r\n" ++
+                "\r\ncontent";
+            try testing.expectEqual(@as(usize, 131), r.findHeadersEnd(example));
+        }
+    };
+
+    pub const Headers = struct {
         method: http.Method = .GET,
     };
 
-    pub fn deinit(req: *Request) void {
-        req.client.active_requests -= 1;
-        req.headers.deinit(req.client.allocator);
-        req.* = undefined;
-    }
+    pub const Options = struct {
+        max_redirects: u32 = 3,
+        header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
 
-    pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void {
-        const gpa = req.client.allocator;
-        // Ensure an extra +2 for the \r\n in end()
-        try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6);
-        req.headers.appendSliceAssumeCapacity(name);
-        req.headers.appendSliceAssumeCapacity(": ");
-        req.headers.appendSliceAssumeCapacity(value);
-        req.headers.appendSliceAssumeCapacity("\r\n");
-    }
+        pub const HeaderStrategy = union(enum) {
+            /// In this case, the client's Allocator will be used to store the
+            /// entire HTTP header. This value is the maximum total size of
+            /// HTTP headers allowed, otherwise
+            /// error.HttpHeadersExceededSizeLimit is returned from read().
+            dynamic: usize,
+            /// This is used to store the entire HTTP header. If the HTTP
+            /// header is too big to fit, `error.HttpHeadersExceededSizeLimit`
+            /// is returned from read(). When this is used, `error.OutOfMemory`
+            /// cannot be returned from `read()`.
+            static: []u8,
+        };
+    };
 
-    pub fn end(req: *Request) !void {
-        req.headers.appendSliceAssumeCapacity("\r\n");
-        switch (req.protocol) {
-            .http => {
-                try req.stream.writeAll(req.headers.items);
-            },
-            .https => {
-                try req.tls_client.writeAll(req.stream, req.headers.items);
-            },
+    /// May be skipped if header strategy is buffer.
+    pub fn deinit(req: *Request) void {
+        if (req.response.header_bytes_owned) {
+            req.response.header_bytes.deinit(req.client.allocator);
         }
+        req.* = undefined;
     }
 
     pub fn readAll(req: *Request, buffer: []u8) !usize {
@@ -71,7 +456,7 @@ pub const Request = struct {
         assert(len <= buffer.len);
         var index: usize = 0;
         while (index < len) {
-            const headers_finished = req.response_headers.state == .finished;
+            const headers_finished = req.response.state == .finished;
             const amt = try readAdvanced(req, buffer[index..]);
             if (amt == 0 and headers_finished) break;
             index += amt;
@@ -82,100 +467,126 @@ pub const Request = struct {
     /// This one can return 0 without meaning EOF.
     /// TODO change to readvAdvanced
     pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
-        if (req.response_headers.state == .finished) return readRaw(req, buffer);
+        if (req.response.state == .finished) return req.connection.read(buffer);
 
-        const amt = try readRaw(req, buffer);
+        const amt = try req.connection.read(buffer);
         const data = buffer[0..amt];
-        const i = req.response_headers.feed(data);
-        if (req.response_headers.state == .invalid) return error.InvalidHttpHeaders;
-        if (i < data.len) {
-            const rest = data[i..];
-            std.mem.copy(u8, buffer, rest);
-            return rest.len;
+        const i = req.response.findHeadersEnd(data);
+        if (req.response.state == .invalid) return error.HttpHeadersInvalid;
+
+        const headers_data = data[0..i];
+        if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
+            return error.HttpHeadersExceededSizeLimit;
        }
-        return 0;
-    }
+        try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
 
-    /// Only abstracts over http/https.
-    fn readRaw(req: *Request, buffer: []u8) !usize {
-        switch (req.protocol) {
-            .http => return req.stream.read(buffer),
-            .https => return req.tls_client.read(req.stream, buffer),
+        if (req.response.state == .finished) {
+            req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
+        }
+
+        if (req.response.headers.status.class() == .redirect) {
+            if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+            const location = req.response.headers.location orelse
+                return error.HttpRedirectMissingLocation;
+            const new_url = try std.Url.parse(location);
+            const new_req = try req.client.request(new_url, req.headers, .{
+                .max_redirects = req.redirects_left - 1,
+                .header_strategy = if (req.response.header_bytes_owned) .{
+                    .dynamic = req.response.max_header_bytes,
+                } else .{
+                    .static = req.response.header_bytes.unusedCapacitySlice(),
+                },
+            });
+            req.deinit();
+            req.* = new_req;
+            return readAdvanced(req, buffer);
         }
-    }
 
-    /// Only abstracts over http/https.
-    fn readAtLeastRaw(req: *Request, buffer: []u8, len: usize) !usize {
-        switch (req.protocol) {
-            .http => return req.stream.readAtLeast(buffer, len),
-            .https => return req.tls_client.readAtLeast(req.stream, buffer, len),
+        const body_data = data[i..];
+        if (body_data.len > 0) {
+            mem.copy(u8, buffer, body_data);
+            return body_data.len;
        }
+        return 0;
+    }
+
+    test {
+        _ = Response;
    }
 };
 
-pub fn deinit(client: *Client) void {
-    assert(client.active_requests == 0);
-    client.headers.deinit(client.allocator);
+pub fn deinit(client: *Client, gpa: Allocator) void {
+    client.ca_bundle.deinit(gpa);
     client.* = undefined;
 }
 
-pub fn request(client: *Client, url: Url, options: Request.Options) !Request {
-    const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse
-        return error.UnsupportedUrlScheme;
-    const port: u16 = url.port orelse switch (protocol) {
-        .http => 80,
-        .https => 443,
-    };
-
-    var req: Request = .{
-        .client = client,
-        .stream = try net.tcpConnectToHost(client.allocator, url.host, port),
-        .protocol = protocol,
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection {
+    var conn: Connection = .{
+        .stream = try net.tcpConnectToHost(client.allocator, host, port),
         .tls_client = undefined,
+        .protocol = protocol,
     };
-    client.active_requests += 1;
-    errdefer req.deinit();
 
     switch (protocol) {
-        .http => {},
-        .https => {
-            req.tls_client = try std.crypto.tls.Client.init(req.stream, client.ca_bundle, url.host);
+        .plain => {},
+        .tls => {
+            conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host);
             // This is appropriate for HTTPS because the HTTP headers contain
             // the content length which is used to detect truncation attacks.
-            req.tls_client.allow_truncation_attacks = true;
+            conn.tls_client.allow_truncation_attacks = true;
         },
     }
 
-    try req.headers.ensureUnusedCapacity(
-        client.allocator,
-        @tagName(options.method).len +
-            1 +
-            url.path.len +
-            " HTTP/1.1\r\nHost: ".len +
-            url.host.len +
-            "\r\nUpgrade-Insecure-Requests: 1\r\n".len +
-            client.headers.items.len +
-            2, // for the \r\n at the end of headers
-    );
-    req.headers.appendSliceAssumeCapacity(@tagName(options.method));
-    req.headers.appendSliceAssumeCapacity(" ");
-    req.headers.appendSliceAssumeCapacity(url.path);
-    req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
-    req.headers.appendSliceAssumeCapacity(url.host);
-    switch (protocol) {
-        .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),
-        .http => req.headers.appendSliceAssumeCapacity("\r\n"),
+    return conn;
+}
+
+pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request {
+    const protocol: Connection.Protocol = if (mem.eql(u8, url.scheme, "http"))
+        .plain
+    else if (mem.eql(u8, url.scheme, "https"))
+        .tls
+    else
+        return error.UnsupportedUrlScheme;
+
+    const port: u16 = url.port orelse switch (protocol) {
+        .plain => 80,
+        .tls => 443,
+    };
+
+    var req: Request = .{
+        .client = client,
+        .headers = headers,
+        .connection = try client.connect(url.host, port, protocol),
+        .redirects_left = options.max_redirects,
+        .response = switch (options.header_strategy) {
+            .dynamic => |max| Request.Response.initDynamic(max),
+            .static => |buf| Request.Response.initStatic(buf),
+        },
+    };
+
+    {
+        var h = try std.BoundedArray(u8, 1000).init(0);
+        try h.appendSlice(@tagName(headers.method));
+        try h.appendSlice(" ");
+        try h.appendSlice(url.path);
+        try h.appendSlice(" HTTP/1.1\r\nHost: ");
+        try h.appendSlice(url.host);
+        try h.appendSlice("\r\nConnection: close\r\n\r\n");
+
+        const header_bytes = h.slice();
+        try req.connection.writeAll(header_bytes);
    }
-    req.headers.appendSliceAssumeCapacity(client.headers.items);
 
     return req;
 }
 
-pub fn addHeader(client: *Client, name: []const u8, value: []const u8) !void {
-    const gpa = client.allocator;
-    try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4);
-    client.headers.appendSliceAssumeCapacity(name);
-    client.headers.appendSliceAssumeCapacity(": ");
-    client.headers.appendSliceAssumeCapacity(value);
-    client.headers.appendSliceAssumeCapacity("\r\n");
+test {
+    const builtin = @import("builtin");
+    const native_endian = comptime builtin.cpu.arch.endian();
+    if (builtin.zig_backend == .stage2_llvm and native_endian == .Big) {
+        // https://github.com/ziglang/zig/issues/13782
+        return error.SkipZigTest;
+    }
+
+    _ = Request;
 }
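
The new `Request.Options.header_strategy` field also lets a caller keep response-header storage out of the allocator entirely. Below is a small sketch under the same assumptions as the example above; `fetchWithStaticHeaders` is a hypothetical helper name and the buffer sizes are arbitrary.

```zig
const std = @import("std");

/// Sketch: issue a GET whose response headers must fit in a caller-provided
/// buffer, so no allocator memory is used for header storage.
fn fetchWithStaticHeaders(client: *std.http.Client, url: std.Url) !void {
    var header_buf: [16 * 1024]u8 = undefined;
    var req = try client.request(url, .{}, .{
        .header_strategy = .{ .static = &header_buf },
    });
    defer req.deinit();

    var body: [4096]u8 = undefined;
    // Oversized headers surface here as error.HttpHeadersExceededSizeLimit,
    // never as error.OutOfMemory.
    _ = try req.readAll(&body);
}
```

With the `.static` strategy, the redirect path in the diff hands the remaining `unusedCapacitySlice()` of the same buffer to the follow-up request, so one buffer serves the whole redirect chain.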
