From 6c7bd5e6956d70d8b1713af2e389a44c59f3eb0c Mon Sep 17 00:00:00 2001 From: Adam Harrison Date: Mon, 11 Mar 2024 17:10:32 -0400 Subject: Added in chunked transfer encoding. --- src/lpm.c | 167 +++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 116 insertions(+), 51 deletions(-) (limited to 'src/lpm.c') diff --git a/src/lpm.c b/src/lpm.c index 78e3b3e..477dc98 100644 --- a/src/lpm.c +++ b/src/lpm.c @@ -1014,6 +1014,15 @@ static int strncicmp(const char* a, const char* b, int n) { return 0; } +static const char* strnstr(const char* haystack, const char* needle, int n) { + int len = strlen(needle); + for (int i = 0; i <= n - len; ++i) { + if (strncmp(&haystack[i], needle, len) == 0) + return &haystack[i]; + } + return NULL; +} + static const char* get_header(const char* buffer, const char* header, int* len) { const char* line_end = strstr(buffer, "\r\n"); const char* header_end = strstr(buffer, "\r\n\r\n"); @@ -1032,6 +1041,9 @@ static const char* get_header(const char* buffer, const char* header, int* len) return NULL; } +static int imin(int a, int b) { return a < b ? a : b; } +static int imax(int a, int b) { return a > b ? a : b; } + static int lpm_get(lua_State* L) { long response_code; char err[1024] = {0}; @@ -1043,6 +1055,7 @@ static int lpm_get(lua_State* L) { mbedtls_ssl_context ssl_context; mbedtls_ssl_context* ssl_ctx = NULL; mbedtls_net_context* net_ctx = NULL; + FILE* file = NULL; if (strcmp(protocol, "https") == 0) { int status; const char* port = lua_tostring(L, 3); @@ -1100,6 +1113,8 @@ static int lpm_get(lua_State* L) { } int bytes_read = 0; const char* header_end = NULL; + + while (!header_end && bytes_read < sizeof(buffer) - 1) { buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx); if (buffer_length < 0) { @@ -1133,75 +1148,125 @@ static int lpm_get(lua_State* L) { snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); goto cleanup; } } + const char* transfer_encoding = get_header(buffer, "transfer-encoding", NULL); + int chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0; const char* content_length_value = get_header(buffer, "content-length", NULL); - int content_length = -1; - if (content_length_value) - content_length = atoi(content_length_value); + int content_length = content_length_value ? atoi(content_length_value) : -1; const char* path = luaL_optstring(L, 5, NULL); int callback_function = lua_type(L, 6) == LUA_TFUNCTION ? 6 : 0; - int body_length = buffer_length - (header_end - buffer); - int total_downloaded = body_length; - int remaining = content_length - body_length; + buffer_length -= (header_end - buffer); + if (buffer_length > 0) + memmove(buffer, header_end, buffer_length); + + int total_downloaded = 0; + int chunk_length = !chunked && content_length == -1 ? INT_MAX : content_length; + int chunk_written = 0; + luaL_Buffer B; if (path) { - FILE* file = lua_fopen(L, path, "wb"); + file = lua_fopen(L, path, "wb"); if (!file) { - snprintf(err, sizeof(err), "can't open file %s: %s", path, strerror(errno)); goto cleanup; + snprintf(err, sizeof(err), "can't open file %s: %s", path, strerror(errno)); + goto cleanup; } - fwrite(header_end, sizeof(char), body_length, file); - while (content_length == -1 || remaining > 0) { - int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx); - if (length == 0 || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) break; - if (length < 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); goto cleanup; - } - total_downloaded += length; - if (callback_function) { - lua_pushvalue(L, callback_function); - lua_pushinteger(L, total_downloaded); - lua_pushinteger(L, content_length); - lua_call(L, 2, 0); + } else + luaL_buffinit(L, &B); + while (1) { + // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk. + while (chunk_length == -1) { + const char* newline = strnstr(buffer, "\r\n", buffer_length); + if (newline) { + if (sscanf(buffer, "%x", &chunk_length) != 1) { + snprintf(err, sizeof(err), "error retrieving chunk length %s%s", hostname, rest); + goto cleanup; + } + if (chunk_length == 0) + goto finish; + buffer_length -= (newline + 2 - buffer); + if (buffer_length > 0) + memmove(buffer, newline + 2, buffer_length); + chunk_written = 0; + } else if (buffer_length >= sizeof(buffer)) { + snprintf(err, sizeof(err), "can't find chunk length for %s%s", hostname, rest); + goto cleanup; + } else { + int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length, ssl_ctx); + if (length <= 0 || (ssl_ctx && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) { + mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); + goto cleanup; + } + buffer_length += length; } - fwrite(buffer, sizeof(char), length, file); - remaining -= length; } - fclose(file); - lua_pushnil(L); - } else { - luaL_Buffer B; - luaL_buffinit(L, &B); - luaL_addlstring(&B, header_end, body_length); - while (content_length == -1 || remaining > 0) { - int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx); - if (length == 0 || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) break; - if (length < 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); goto cleanup; + if (buffer_length > 0) { + int to_write = imin(chunk_length - chunk_written, buffer_length); + if (to_write > 0) { + total_downloaded += to_write; + chunk_written += to_write; + if (callback_function) { + lua_pushvalue(L, callback_function); + lua_pushinteger(L, total_downloaded); + lua_pushinteger(L, content_length); + lua_call(L, 2, 0); + } + if (file) + fwrite(buffer, sizeof(char), to_write, file); + else + luaL_addlstring(&B, buffer, to_write); + buffer_length -= to_write; + if (buffer_length > 0) + memmove(buffer, &buffer[to_write], buffer_length); } - total_downloaded += length; - if (callback_function) { - lua_pushvalue(L, callback_function); - lua_pushinteger(L, total_downloaded); - lua_call(L, 1, 0); + if (chunk_written == chunk_length) { + if (!chunked) { + goto finish; + } + if (buffer_length >= 2) { + if (!strnstr(buffer, "\r\n", 2)) { + snprintf(err, sizeof(err), "invalid end to chunk for %s%s", hostname, rest); + goto cleanup; + } + memmove(buffer, &buffer[2], buffer_length - 2); + buffer_length -= 2; + chunk_length = -1; + } + } + } + if (chunk_length > 0) { + buffer_length = lpm_socket_read(s, buffer, imin(sizeof(buffer), chunk_length - chunk_written + (chunked ? 2 : 0)), ssl_ctx); + if ((!ssl_ctx && buffer_length == 0) || (ssl_ctx && content_length == -1 && buffer_length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) + goto finish; + if (buffer_length <= 0) { + mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "error retrieving full chunk for %s%s", hostname, rest); + goto cleanup; } - luaL_addlstring(&B, buffer, length); - remaining -= length; } - luaL_pushresult(&B); - } - if (content_length != -1 && remaining != 0) { - snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); goto cleanup; - } - if (callback_function) { - lua_pushvalue(L, callback_function); - lua_pushboolean(L, 1); - lua_call(L, 1, 0); } - lua_newtable(L); + finish: + if (file) { + fclose(file); + file = NULL; + lua_pushnil(L); + } else { + luaL_pushresult(&B); + } + if (content_length != -1 && total_downloaded != content_length) { + snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); + goto cleanup; + } + if (callback_function) { + lua_pushvalue(L, callback_function); + lua_pushboolean(L, 1); + lua_call(L, 1, 0); + } + lua_newtable(L); cleanup: if (ssl_ctx) mbedtls_ssl_free(ssl_ctx); if (net_ctx) mbedtls_net_free(net_ctx); + if (file) + fclose(file); if (s != -2) close(s); if (err[0]) -- cgit v1.2.3