aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/lpm.c167
1 files changed, 116 insertions, 51 deletions
diff --git a/src/lpm.c b/src/lpm.c
index 78e3b3e..477dc98 100644
--- a/src/lpm.c
+++ b/src/lpm.c
@@ -1014,6 +1014,15 @@ static int strncicmp(const char* a, const char* b, int n) {
return 0;
}
+static const char* strnstr(const char* haystack, const char* needle, int n) {
+ int len = strlen(needle);
+ for (int i = 0; i <= n - len; ++i) {
+ if (strncmp(&haystack[i], needle, len) == 0)
+ return &haystack[i];
+ }
+ return NULL;
+}
+
static const char* get_header(const char* buffer, const char* header, int* len) {
const char* line_end = strstr(buffer, "\r\n");
const char* header_end = strstr(buffer, "\r\n\r\n");
@@ -1032,6 +1041,9 @@ static const char* get_header(const char* buffer, const char* header, int* len)
return NULL;
}
+static int imin(int a, int b) { return a < b ? a : b; }
+static int imax(int a, int b) { return a > b ? a : b; }
+
static int lpm_get(lua_State* L) {
long response_code;
char err[1024] = {0};
@@ -1043,6 +1055,7 @@ static int lpm_get(lua_State* L) {
mbedtls_ssl_context ssl_context;
mbedtls_ssl_context* ssl_ctx = NULL;
mbedtls_net_context* net_ctx = NULL;
+ FILE* file = NULL;
if (strcmp(protocol, "https") == 0) {
int status;
const char* port = lua_tostring(L, 3);
@@ -1100,6 +1113,8 @@ static int lpm_get(lua_State* L) {
}
int bytes_read = 0;
const char* header_end = NULL;
+
+
while (!header_end && bytes_read < sizeof(buffer) - 1) {
buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx);
if (buffer_length < 0) {
@@ -1133,75 +1148,125 @@ static int lpm_get(lua_State* L) {
snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); goto cleanup;
}
}
+ const char* transfer_encoding = get_header(buffer, "transfer-encoding", NULL);
+ int chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0;
const char* content_length_value = get_header(buffer, "content-length", NULL);
- int content_length = -1;
- if (content_length_value)
- content_length = atoi(content_length_value);
+ int content_length = content_length_value ? atoi(content_length_value) : -1;
const char* path = luaL_optstring(L, 5, NULL);
int callback_function = lua_type(L, 6) == LUA_TFUNCTION ? 6 : 0;
- int body_length = buffer_length - (header_end - buffer);
- int total_downloaded = body_length;
- int remaining = content_length - body_length;
+ buffer_length -= (header_end - buffer);
+ if (buffer_length > 0)
+ memmove(buffer, header_end, buffer_length);
+
+ int total_downloaded = 0;
+ int chunk_length = !chunked && content_length == -1 ? INT_MAX : content_length;
+ int chunk_written = 0;
+ luaL_Buffer B;
if (path) {
- FILE* file = lua_fopen(L, path, "wb");
+ file = lua_fopen(L, path, "wb");
if (!file) {
- snprintf(err, sizeof(err), "can't open file %s: %s", path, strerror(errno)); goto cleanup;
+ snprintf(err, sizeof(err), "can't open file %s: %s", path, strerror(errno));
+ goto cleanup;
}
- fwrite(header_end, sizeof(char), body_length, file);
- while (content_length == -1 || remaining > 0) {
- int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx);
- if (length == 0 || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) break;
- if (length < 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); goto cleanup;
- }
- total_downloaded += length;
- if (callback_function) {
- lua_pushvalue(L, callback_function);
- lua_pushinteger(L, total_downloaded);
- lua_pushinteger(L, content_length);
- lua_call(L, 2, 0);
+ } else
+ luaL_buffinit(L, &B);
+ while (1) {
+ // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk.
+ while (chunk_length == -1) {
+ const char* newline = strnstr(buffer, "\r\n", buffer_length);
+ if (newline) {
+ if (sscanf(buffer, "%x", &chunk_length) != 1) {
+ snprintf(err, sizeof(err), "error retrieving chunk length %s%s", hostname, rest);
+ goto cleanup;
+ }
+ if (chunk_length == 0)
+ goto finish;
+ buffer_length -= (newline + 2 - buffer);
+ if (buffer_length > 0)
+ memmove(buffer, newline + 2, buffer_length);
+ chunk_written = 0;
+ } else if (buffer_length >= sizeof(buffer)) {
+ snprintf(err, sizeof(err), "can't find chunk length for %s%s", hostname, rest);
+ goto cleanup;
+ } else {
+ int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length, ssl_ctx);
+ if (length <= 0 || (ssl_ctx && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) {
+ mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest);
+ goto cleanup;
+ }
+ buffer_length += length;
}
- fwrite(buffer, sizeof(char), length, file);
- remaining -= length;
}
- fclose(file);
- lua_pushnil(L);
- } else {
- luaL_Buffer B;
- luaL_buffinit(L, &B);
- luaL_addlstring(&B, header_end, body_length);
- while (content_length == -1 || remaining > 0) {
- int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx);
- if (length == 0 || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) break;
- if (length < 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); goto cleanup;
+ if (buffer_length > 0) {
+ int to_write = imin(chunk_length - chunk_written, buffer_length);
+ if (to_write > 0) {
+ total_downloaded += to_write;
+ chunk_written += to_write;
+ if (callback_function) {
+ lua_pushvalue(L, callback_function);
+ lua_pushinteger(L, total_downloaded);
+ lua_pushinteger(L, content_length);
+ lua_call(L, 2, 0);
+ }
+ if (file)
+ fwrite(buffer, sizeof(char), to_write, file);
+ else
+ luaL_addlstring(&B, buffer, to_write);
+ buffer_length -= to_write;
+ if (buffer_length > 0)
+ memmove(buffer, &buffer[to_write], buffer_length);
}
- total_downloaded += length;
- if (callback_function) {
- lua_pushvalue(L, callback_function);
- lua_pushinteger(L, total_downloaded);
- lua_call(L, 1, 0);
+ if (chunk_written == chunk_length) {
+ if (!chunked) {
+ goto finish;
+ }
+ if (buffer_length >= 2) {
+ if (!strnstr(buffer, "\r\n", 2)) {
+ snprintf(err, sizeof(err), "invalid end to chunk for %s%s", hostname, rest);
+ goto cleanup;
+ }
+ memmove(buffer, &buffer[2], buffer_length - 2);
+ buffer_length -= 2;
+ chunk_length = -1;
+ }
+ }
+ }
+ if (chunk_length > 0) {
+ buffer_length = lpm_socket_read(s, buffer, imin(sizeof(buffer), chunk_length - chunk_written + (chunked ? 2 : 0)), ssl_ctx);
+ if ((!ssl_ctx && buffer_length == 0) || (ssl_ctx && content_length == -1 && buffer_length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY))
+ goto finish;
+ if (buffer_length <= 0) {
+ mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "error retrieving full chunk for %s%s", hostname, rest);
+ goto cleanup;
}
- luaL_addlstring(&B, buffer, length);
- remaining -= length;
}
- luaL_pushresult(&B);
- }
- if (content_length != -1 && remaining != 0) {
- snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); goto cleanup;
- }
- if (callback_function) {
- lua_pushvalue(L, callback_function);
- lua_pushboolean(L, 1);
- lua_call(L, 1, 0);
}
- lua_newtable(L);
+ finish:
+ if (file) {
+ fclose(file);
+ file = NULL;
+ lua_pushnil(L);
+ } else {
+ luaL_pushresult(&B);
+ }
+ if (content_length != -1 && total_downloaded != content_length) {
+ snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest);
+ goto cleanup;
+ }
+ if (callback_function) {
+ lua_pushvalue(L, callback_function);
+ lua_pushboolean(L, 1);
+ lua_call(L, 1, 0);
+ }
+ lua_newtable(L);
cleanup:
if (ssl_ctx)
mbedtls_ssl_free(ssl_ctx);
if (net_ctx)
mbedtls_net_free(net_ctx);
+ if (file)
+ fclose(file);
if (s != -2)
close(s);
if (err[0])