From bef2b5ce4a969e436ca1d266ca47fc22a6c241ea Mon Sep 17 00:00:00 2001 From: Adam Harrison Date: Tue, 23 Jul 2024 19:08:28 -0400 Subject: Made C mutli-threading possible. --- src/lpm.c | 824 ++++++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 536 insertions(+), 288 deletions(-) diff --git a/src/lpm.c b/src/lpm.c index 15af8c2..afc2948 100644 --- a/src/lpm.c +++ b/src/lpm.c @@ -4,6 +4,7 @@ #include #include #else + #include #include #include #include @@ -14,6 +15,7 @@ #define MAX_PATH PATH_MAX #endif +#include #include #include #include @@ -55,6 +57,94 @@ #define HTTPS_RESPONSE_HEADER_BUFFER_LENGTH 8192 +typedef struct { + #if _WIN32 + HANDLE thread; + void* (*func)(void*); + void* data; + #else + pthread_t thread; + #endif +} thread_t; + +typedef struct { + #if _WIN32 + HANDLE mutex; + #else + pthread_mutex_t mutex; + #endif +} mutex_t; + +static mutex_t* new_mutex() { + mutex_t* mutex = malloc(sizeof(mutex_t)); + #if _WIN32 + mutex->mutex = CreateMutex(NULL, FALSE, NULL); + #else + pthread_mutex_init(&mutex->mutex, NULL); + #endif + return mutex; +} + +static void free_mutex(mutex_t* mutex) { + #if _WIN32 + CloseHandle(mutex->mutex); + #else + pthread_mutex_destroy(&mutex->mutex); + #endif + free(mutex); +} + +static void lock_mutex(mutex_t* mutex) { + #if _WIN32 + WaitForSingleObject(mutex->mutex, INFINITE); + #else + pthread_mutex_lock(&mutex->mutex); + #endif +} + +static void unlock_mutex(mutex_t* mutex) { + #if _WIN32 + ReleaseMutex(mutex->mutex); + #else + pthread_mutex_unlock(&mutex->mutex); + #endif +} + + +#if _WIN32 +static DWORD windows_thread_callback(void* data) { + thread_t* thread = data; + thread->data = thread->func(thread->data); + return 0; +} +#endif + +static thread_t* create_thread(void* (*func)(void*), void* data) { + thread_t* thread = malloc(sizeof(thread_t)); + #if _WIN32 + thread->func = func; + thread->data = data; + thread->thread = CreateThread(NULL, 0, windows_thread_callback, thread, 0, NULL); + #else + pthread_create(&thread->thread, NULL, func, data); + #endif + return thread; +} + +static void* join_thread(thread_t* thread) { + if (!thread) + return NULL; + void* retval; + #if _WIN32 + WaitForSingleObject(thread->thread, INFINITE); + #else + pthread_join(thread->thread, &retval); + #endif + free(thread); + return retval; +} + + #if _WIN32 static LPCWSTR lua_toutf16(lua_State* L, const char* str) { if (str && str[0] == 0) @@ -95,17 +185,17 @@ static const char* lua_toutf8(lua_State* L, LPCWSTR str) { } static const int luaL_win32_error(lua_State* L, DWORD error_id, const char* message, ...) { - va_list va; - va_start(va, message); - lua_pushvfstring(L, message, va); - va_end(va); - wchar_t message_buffer[2048]; - size_t size = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, error_id, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), message_buffer, 2048, NULL); - lua_pushliteral(L, ": "); - lua_toutf8(L, message_buffer); - lua_concat(L, 3); - return lua_error(L); + va_list va; + va_start(va, message); + lua_pushvfstring(L, message, va); + va_end(va); + wchar_t message_buffer[2048]; + size_t size = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_id, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), message_buffer, 2048, NULL); + lua_pushliteral(L, ": "); + lua_toutf8(L, message_buffer); + lua_concat(L, 3); + return lua_error(L); } #endif @@ -459,9 +549,23 @@ static int lpm_git_transport_certificate_check_cb(struct git_cert *cert, int val return 0; // If no_verify_ssl is enabled, basically always return 0 when this is set as callback. } -static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void *payload) { - lua_State* L = payload; - lua_pushvalue(L, 2); + +typedef struct { + git_repository* repository; + lua_State* L; + char refspec[512]; + int depth; + int threaded; + int callback_function; + git_transfer_progress progress; + int progress_update; + int complete; + int error_code; + char data[512]; + thread_t* thread; +} fetch_context_t; + +static int lpm_fetch_callback(lua_State* L, const git_transfer_progress *stats) { lua_pushinteger(L, stats->received_bytes); lua_pushinteger(L, stats->total_objects); lua_pushinteger(L, stats->indexed_objects); @@ -469,86 +573,126 @@ static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void lua_pushinteger(L, stats->local_objects); lua_pushinteger(L, stats->total_deltas); lua_pushinteger(L, stats->indexed_deltas); - lua_call(L, 7, 1); - int value = lua_tointeger(L, -1); + return lua_pcall(L, 7, 0, 0); +} + +static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void *payload) { + fetch_context_t* context = (fetch_context_t*)payload; + if (!context->threaded) { + if (context->callback_function) { + lua_rawgeti(context->L, LUA_REGISTRYINDEX, context->callback_function); + lpm_fetch_callback(context->L, stats); + } + } else { + context->progress = *stats; + context->progress_update = 1; + } + return 0; +} + +static int lua_is_main_thread(lua_State* L) { + int is_main = lua_pushthread(L); lua_pop(L, 1); - return value; + return is_main; } -static int lpm_fetch(lua_State* L) { - git_init(); - git_repository* repository = luaL_checkgitrepo(L, 1); +static void* lpm_fetch_thread(void* ctx) { git_remote* remote; - if (git_remote_lookup(&remote, repository, "origin")) { - git_repository_free(repository); - return luaL_error(L, "git remote fetch error: %s", git_error_last_string()); + fetch_context_t* context = (fetch_context_t*)ctx; + int error = git_remote_lookup(&remote, context->repository, "origin"); + if (error && !context->error_code) { + snprintf(context->data, sizeof(context->data), "git remote fetch error: %s", git_error_last_string()); + context->error_code = error; + return NULL; } - const char* refspec = luaL_optstring(L, 3, NULL); git_fetch_options fetch_opts = GIT_FETCH_OPTIONS_INIT; fetch_opts.download_tags = GIT_REMOTE_DOWNLOAD_TAGS_ALL; - fetch_opts.callbacks.payload = L; + fetch_opts.callbacks.payload = context; #if (LIBGIT2_VER_MAJOR == 1 && LIBGIT2_VER_MINOR >= 7) || LIBGIT2_VER_MAJOR > 1 - fetch_opts.depth = lua_toboolean(L, 4) ? GIT_FETCH_DEPTH_FULL : 1; + fetch_opts.depth = context->depth; #endif if (no_verify_ssl) fetch_opts.callbacks.certificate_check = lpm_git_transport_certificate_check_cb; - if (lua_type(L, 2) == LUA_TFUNCTION) - fetch_opts.callbacks.transfer_progress = lpm_git_transfer_progress_cb; - git_strarray array = { (char**)&refspec, 1 }; - int error = git_remote_connect(remote, GIT_DIRECTION_FETCH, &fetch_opts.callbacks, NULL, NULL) || - git_remote_download(remote, refspec ? &array : NULL, &fetch_opts) || + fetch_opts.callbacks.transfer_progress = lpm_git_transfer_progress_cb; + char* strings[] = { context->refspec }; + git_strarray array = { strings, 1 }; + + error = git_remote_connect(remote, GIT_DIRECTION_FETCH, &fetch_opts.callbacks, NULL, NULL) || + git_remote_download(remote, context->refspec[0] ? &array : NULL, &fetch_opts) || git_remote_update_tips(remote, &fetch_opts.callbacks, fetch_opts.update_fetchhead, fetch_opts.download_tags, NULL); - if (!error) { + if (!error && !context->error_code) { git_buf branch_name = {0}; if (!git_remote_default_branch(&branch_name, remote)) { - lua_pushlstring(L, branch_name.ptr, branch_name.size); + strncpy(context->data, branch_name.ptr, sizeof(context->data)); git_buf_dispose(&branch_name); - } else { - lua_pushnil(L); } } git_remote_disconnect(remote); git_remote_free(remote); - git_repository_free(repository); - if (error) - return luaL_error(L, "git remote fetch error: %s", git_error_last_string()); - if (lua_type(L, 2) == LUA_TFUNCTION) { - lua_pushvalue(L, 2); - lua_pushboolean(L, 1); - lua_pushvalue(L, -3); - lua_call(L, 2, 0); + if (error && !context->error_code) { + snprintf(context->data, sizeof(context->data), "git remote fetch error: %s", git_error_last_string()); + context->error_code = error; } - return 1; + context->complete = 1; + return NULL; } -static int mbedtls_snprintf(int mbedtls, char* buffer, int len, int status, const char* str, ...) { - char mbed_buffer[256]; - mbedtls_strerror(status, mbed_buffer, sizeof(mbed_buffer)); - int error_len = mbedtls ? strlen(mbed_buffer) : strlen(strerror(status)); - va_list va; - int offset = 0; - va_start(va, str); - offset = vsnprintf(buffer, len, str, va); - va_end(va); - if (offset < len - 2) { - strcat(buffer, ": "); - if (offset < len - error_len - 2) - strcat(buffer, mbedtls ? mbed_buffer : strerror(status)); + +static int lpm_fetchk(lua_State* L, int status, lua_KContext ctx) { + lua_rawgeti(L, LUA_REGISTRYINDEX, (int)ctx); + fetch_context_t* context = lua_touserdata(L, -1); + lua_pop(L, 1); + if (context->threaded && !context->error_code && context->callback_function && context->progress_update) { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function); + context->error_code = lpm_fetch_callback(L, &context->progress); + if (context->error_code) + strncpy(context->data, lua_tostring(L, -1), sizeof(context->data)); } - return strlen(buffer); + if (context->complete || context->error_code) { + join_thread(context->thread); + git_repository_free(context->repository); + lua_pushstring(L, context->data[0] == 0 ? NULL : context->data); + if (context->callback_function) + luaL_unref(L, LUA_REGISTRYINDEX, context->callback_function); + luaL_unref(L, LUA_REGISTRYINDEX, (int)ctx); + if (context->error_code) { + return lua_error(L); + } + return 1; + } + assert(context->threaded); + return lua_yieldk(L, 0, (lua_KContext)ctx, lpm_fetchk); } -static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) { - char vsnbuffer[1024]; - char mbed_buffer[128]; - mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer)); - va_list va; - va_start(va, str); - vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va); - va_end(va); - return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer); + +static int lpm_fetch(lua_State* L) { + git_init(); + int args = lua_gettop(L); + fetch_context_t* context = lua_newuserdata(L, sizeof(fetch_context_t)); + memset(context, 0, sizeof(fetch_context_t)); + context->repository = luaL_checkgitrepo(L, 1); + const char* refspec = args >= 3 ? luaL_optstring(L, 3, NULL) : NULL; + context->depth = args >= 4 && lua_toboolean(L, 4) ? GIT_FETCH_DEPTH_FULL : 1; + context->L = L; + context->threaded = !lua_is_main_thread(L); + if (refspec) + strncpy(context->refspec, refspec, sizeof(context->refspec)); + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + context->callback_function = luaL_ref(L, LUA_REGISTRYINDEX); + } + int ctx = luaL_ref(L, LUA_REGISTRYINDEX); + if (lua_is_main_thread(L)) { + lpm_fetch_thread(context); + lpm_fetchk(L, 0, ctx); + return 0; + } else { + context->thread = create_thread(lpm_fetch_thread, context); + return lua_yieldk(L, 0, (lua_KContext)ctx, lpm_fetchk); + } } + static void lpm_tls_debug(void *ctx, int level, const char *file, int line, const char *str) { fprintf(stderr, "%s:%04d: |%d| %s", file, line, level, str); fflush(stderr); @@ -564,6 +708,19 @@ static int lpm_trace(lua_State* L) { return 0; } + +static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) { + char vsnbuffer[1024]; + char mbed_buffer[128]; + mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer)); + va_list va; + va_start(va, str); + vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va); + va_end(va); + return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer); +} + + static int lpm_certs(lua_State* L) { const char* type = luaL_checkstring(L, 1); int status; @@ -998,18 +1155,6 @@ static int lpm_extract(lua_State* L) { } -static int lpm_socket_write(int fd, const char* buf, int len, mbedtls_ssl_context* ctx) { - if (ctx) - return mbedtls_ssl_write(ctx, buf, len); - return write(fd, buf, len); -} - -static int lpm_socket_read(int fd, char* buf, int len, mbedtls_ssl_context* ctx) { - if (ctx) - return mbedtls_ssl_read(ctx, buf, len); - return read(fd, buf, len); -} - static int strncicmp(const char* a, const char* b, int n) { for (int i = 0; i < n; ++i) { if (a[i] == 0 && b[i] != 0) return -1; @@ -1052,240 +1197,343 @@ static const char* get_header(const char* buffer, const char* header, int* len) static int imin(int a, int b) { return a < b ? a : b; } static int imax(int a, int b) { return a > b ? a : b; } -static int lpm_get(lua_State* L) { - long response_code; - char err[1024] = {0}; - const char* protocol = luaL_checkstring(L, 1); - const char* hostname = luaL_checkstring(L, 2); - - int s = -2; - mbedtls_net_context net_context; - mbedtls_ssl_context ssl_context; - mbedtls_ssl_context* ssl_ctx = NULL; - mbedtls_net_context* net_ctx = NULL; - FILE* file = NULL; - if (strcmp(protocol, "https") == 0) { - int status; - const char* port = lua_tostring(L, 3); - // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30 - ssl_ctx = &ssl_context; - mbedtls_ssl_init(&ssl_context); - if ((status = mbedtls_ssl_setup(&ssl_context, &ssl_config)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't set up ssl for %s: %d", hostname, status); goto cleanup; - } - net_ctx = &net_context; - mbedtls_net_init(&net_context); - mbedtls_net_set_block(&net_context); - mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout); - if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't connect to hostname %s", hostname); goto cleanup; - } else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't set hostname %s", hostname); goto cleanup; - } else if ((status = mbedtls_ssl_handshake(&ssl_context)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't handshake with %s", hostname); goto cleanup; - } else if (((status = mbedtls_ssl_get_verify_result(&ssl_context)) != 0) && !no_verify_ssl) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't verify result for %s", hostname); goto cleanup; - } - } else { - int port = luaL_checkinteger(L, 3); - struct hostent *host = gethostbyname(hostname); - struct sockaddr_in dest_addr = {0}; - if (!host) - return luaL_error(L, "can't resolve hostname %s", hostname); - s = socket(AF_INET, SOCK_STREAM, 0); - #ifdef _WIN32 - DWORD timeout = 5 * 1000; - setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout); - #else - struct timeval tv; - tv.tv_sec = 5; - tv.tv_usec = 0; - setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); - #endif - dest_addr.sin_family = AF_INET; - dest_addr.sin_port = htons(port); - dest_addr.sin_addr.s_addr = *(long*)(host->h_addr); - const char* ip = inet_ntoa(dest_addr.sin_addr); - if (connect(s, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)) == -1 ) { - close(s); - return luaL_error(L, "can't connect to host %s [%s] on port %d", hostname, ip, port); - } - } +typedef enum { + STATE_CONNECT, + STATE_HANDSHAKE, + STATE_SEND, + STATE_RECV_HEADER, + STATE_RECV_BODY +} get_state_e; + +typedef struct { + get_state_e state; + int s; + int is_ssl; + mbedtls_ssl_context ssl; + mbedtls_net_context net; + int lua_buffer; + FILE* file; + char address[1024]; + int error_code; + char error[256]; + char hostname[256]; + char rest[2048]; + int callback_function; - const char* rest = luaL_checkstring(L, 4); char buffer[HTTPS_RESPONSE_HEADER_BUFFER_LENGTH]; - int buffer_length = snprintf(buffer, sizeof(buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", rest, hostname); - buffer_length = lpm_socket_write(s, buffer, buffer_length, ssl_ctx); - if (buffer_length < 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't write to socket %s", hostname); goto cleanup; - } - const char* header_end = NULL; - - - buffer_length = 0; - while (!header_end && buffer_length < sizeof(buffer) - 1) { - int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length - 1, ssl_ctx); - if (length < 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't read from socket %s", hostname); goto cleanup; - } else if (length > 0) { - buffer_length += length; - buffer[buffer_length] = 0; - header_end = strstr(buffer, "\r\n\r\n"); - } - } - if (!header_end) { - snprintf(err, sizeof(err), "can't parse response headers for %s://%s%s: %s", protocol, hostname, rest, buffer_length >= sizeof(buffer) - 1 ? "response header buffer length exceeded" : "malformed response"); - goto cleanup; - } - header_end += 4; - const char* protocol_end = strstr(buffer, " "); - int code = atoi(protocol_end + 1); - if (code != 200) { - if (code >= 301 && code <= 303) { - int len; - const char* location = get_header(buffer, "location", &len); - if (location) { - lua_pushnil(L); - lua_newtable(L); - lua_pushlstring(L, location, len); - lua_setfield(L, -2, "location"); - } else - snprintf(err, sizeof(err), "received invalid %d-response from %s://%s%s: %d", code, protocol, hostname, rest, code); - goto cleanup; - } else { - snprintf(err, sizeof(err), "received non 200-response from %s://%s%s: %d", protocol, hostname, rest, code); goto cleanup; + int buffer_length; + + int content_length; + int chunk_length; + int chunked; + int chunk_written; + int total_downloaded; +} get_context_t; + + +static int lpm_socket_write(get_context_t* context, int len) { + return context->is_ssl ? mbedtls_ssl_write(&context->ssl, context->buffer, len) : write(context->s, context->buffer, len); +} + +static int lpm_socket_read(get_context_t* context, int len) { + if (len == -1) + len = sizeof(context->buffer) - context->buffer_length; + if (len == 0) + return len; + len = context->is_ssl ? mbedtls_ssl_read(&context->ssl, &context->buffer[context->buffer_length], len) : read(context->s, &context->buffer[context->buffer_length], len); + if (len > 0) + context->buffer_length += len; + return len; +} + + +static int lpm_get_error(get_context_t* context, int error_code, const char* str, ...) { + if (error_code) { + context->error_code = error_code; + char mbed_buffer[256]; + mbedtls_strerror(error_code, mbed_buffer, sizeof(mbed_buffer)); + int error_len = context->is_ssl ? strlen(mbed_buffer) : strlen(strerror(error_code)); + va_list va; + int offset = 0; + va_start(va, str); + offset = vsnprintf(context->buffer, sizeof(context->buffer), str, va); + va_end(va); + if (offset < sizeof(context->buffer) - 2) { + strcat(context->buffer, ": "); + if (offset < sizeof(context->buffer) - error_len - 2) + strcat(context->buffer, context->is_ssl ? mbed_buffer : strerror(error_code)); } } - const char* transfer_encoding = get_header(buffer, "transfer-encoding", NULL); - int chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0; - const char* content_length_value = get_header(buffer, "content-length", NULL); - int content_length = content_length_value ? atoi(content_length_value) : -1; - const char* path = luaL_optstring(L, 5, NULL); - int callback_function = lua_type(L, 6) == LUA_TFUNCTION ? 6 : 0; + return error_code; +} - buffer_length -= (header_end - buffer); - if (buffer_length > 0) - memmove(buffer, header_end, buffer_length); +static int lpm_set_error(get_context_t* context, const char* str, ...) { + va_list va; + int offset = 0; + va_start(va, str); + offset = vsnprintf(context->error, sizeof(context->error), str, va); + va_end(va); + context->error_code = -1; + return offset; +} - int total_downloaded = 0; - int chunk_length = !chunked && content_length == -1 ? INT_MAX : content_length; - int chunk_written = 0; - luaL_Buffer B; - if (path) { - file = lua_fopen(L, path, "wb"); - if (!file) { - snprintf(err, sizeof(err), "can't open file %s: %s", path, strerror(errno)); - goto cleanup; +static int lpm_getk(lua_State* L, int status, lua_KContext ctx) { + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx); + get_context_t* context = (get_context_t*)lua_touserdata(L, -1); + lua_pop(L,1); + switch (context->state) { + case STATE_HANDSHAKE: { + int status = mbedtls_ssl_handshake(&context->ssl); + if (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE) + return lua_yieldk(L, 0, ctx, lpm_getk); + if ( + lpm_get_error(context, status, "can't handshake") || + lpm_get_error(context, mbedtls_ssl_get_verify_result(&context->ssl), "can't verify result") + ) + goto cleanup; + context->state = STATE_SEND; } - } else - luaL_buffinit(L, &B); - while (1) { - // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk. - while (chunk_length == -1) { - char* newline = (char*)strnstr_local(buffer, "\r\n", buffer_length); - if (newline) { - *newline = '\0'; - if (sscanf(buffer, "%x", &chunk_length) != 1) { - snprintf(err, sizeof(err), "error retrieving chunk length for %s://%s%s", protocol, hostname, rest); - goto cleanup; - } - if (chunk_length == 0) - goto finish; - buffer_length -= (newline + 2 - buffer); - if (buffer_length > 0) - memmove(buffer, newline + 2, buffer_length); - chunk_written = 0; - } else if (buffer_length >= sizeof(buffer)) { - snprintf(err, sizeof(err), "can't find chunk length for %s://%s%s", protocol, hostname, rest); + case STATE_SEND: { + context->buffer_length = snprintf(context->buffer, sizeof(context->buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", context->rest, context->hostname); + int length = lpm_socket_write(context, context->buffer_length); + if (length < context->buffer_length && lpm_get_error(context, length, "can't write to socket")) goto cleanup; - } else { - int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length, ssl_ctx); - if (length <= 0 || (ssl_ctx && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s://%s%s", protocol, hostname, rest); + context->buffer_length = 0; + context->buffer[0] = 0; + context->state = STATE_RECV_HEADER; + } + case STATE_RECV_HEADER: { + const char* header_end; + while (1) { + header_end = strstr(context->buffer, "\r\n\r\n"); + if (!header_end && context->buffer_length >= sizeof(context->buffer) - 1 && lpm_set_error(context, "response header buffer length exceeded")) goto cleanup; + if (!header_end) { + int length = lpm_socket_read(context, -1); + if (length < 0 && lpm_get_error(context, length, "can't read from socket")) + goto cleanup; + if (length == 0) + return lua_yieldk(L, 0, ctx, lpm_getk); + } else { + header_end += 4; + const char* protocol_end = strnstr_local(context->buffer, " ", context->buffer_length); + int code = atoi(protocol_end + 1); + if (code != 200) { + if (code >= 301 && code <= 303) { + const char* location = get_header(context->buffer, "location", &context->buffer_length); + if (location) { + lua_pushnil(L); + lua_newtable(L); + lua_pushlstring(L, location, context->buffer_length); + lua_setfield(L, -2, "location"); + } else + lpm_set_error(context, "received invalid %d-response", code); + } else + lpm_set_error(context, "received non 200-response of %d", code); + goto report; + } + const char* transfer_encoding = get_header(context->buffer, "transfer-encoding", NULL); + context->chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0; + const char* content_length_value = get_header(context->buffer, "content-length", NULL); + context->content_length = content_length_value ? atoi(content_length_value) : -1; + context->buffer_length -= (header_end - context->buffer); + if (context->buffer_length > 0) + memmove(context->buffer, header_end, context->buffer_length); + context->chunk_length = !context->chunked && context->content_length == -1 ? INT_MAX : context->content_length; + context->state = STATE_RECV_BODY; + break; } - buffer_length += length; } } - if (buffer_length > 0) { - int to_write = imin(chunk_length - chunk_written, buffer_length); - if (to_write > 0) { - total_downloaded += to_write; - chunk_written += to_write; - if (callback_function) { - lua_pushvalue(L, callback_function); - lua_pushinteger(L, total_downloaded); - if (content_length == -1) - lua_pushnil(L); - else - lua_pushinteger(L, content_length); - lua_call(L, 2, 0); - } - if (file) - fwrite(buffer, sizeof(char), to_write, file); - else - luaL_addlstring(&B, buffer, to_write); - buffer_length -= to_write; - if (buffer_length > 0) - memmove(buffer, &buffer[to_write], buffer_length); - } - if (chunk_written == chunk_length) { - if (!chunked) - goto finish; - if (buffer_length >= 2) { - if (!strnstr_local(buffer, "\r\n", 2)) { - snprintf(err, sizeof(err), "invalid end to chunk for %s://%s%s", protocol, hostname, rest); + case STATE_RECV_BODY: { + while (1) { + // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk. + while (context->chunk_length == -1) { + char* newline = (char*)strnstr_local(context->buffer, "\r\n", context->buffer_length); + if (newline) { + *newline = '\0'; + if ((sscanf(context->buffer, "%x", &context->chunk_length) != 1 && lpm_set_error(context, "error retrieving chunk length"))) + goto cleanup; + else if (context->chunk_length == 0) + goto finish; + context->buffer_length -= (newline + 2 - context->buffer); + if (context->buffer_length > 0) + memmove(context->buffer, newline + 2, context->buffer_length); + } else if (context->buffer_length >= sizeof(context->buffer) && lpm_set_error(context, "can't find chunk length")) { goto cleanup; + } else { + int length = lpm_socket_read(context, -1); + if ((length <= 0 || (context->is_ssl && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) && lpm_get_error(context, length, "error retrieving full repsonse")) + goto cleanup; + if (length == 0) + return lua_yieldk(L, 0, ctx, lpm_getk); } - memmove(buffer, &buffer[2], buffer_length - 2); - buffer_length -= 2; - chunk_length = -1; + } + if (context->buffer_length > 0) { + int to_write = imin(context->chunk_length - context->chunk_written, context->buffer_length); + if (to_write > 0) { + context->total_downloaded += to_write; + context->chunk_written += to_write; + if (context->callback_function) { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function); + lua_pushinteger(L, context->total_downloaded); + if (context->content_length == -1) + lua_pushnil(L); + else + lua_pushinteger(L, context->content_length); + lua_call(L, 2, 0); + } + if (context->file) + fwrite(context->buffer, sizeof(char), to_write, context->file); + else { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer); + lua_pushlstring(L, context->buffer, to_write); + lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); + lua_pop(L, 1); + } + context->buffer_length -= to_write; + if (context->buffer_length > 0) + memmove(context->buffer, &context->buffer[to_write], context->buffer_length); + } + if (context->chunk_written == context->chunk_length) { + if (!context->chunked) + goto finish; + if (context->buffer_length >= 2) { + if (!strnstr_local(context->buffer, "\r\n", 2) && lpm_set_error(context, "invalid end to chunk")) + goto cleanup; + memmove(context->buffer, &context->buffer[2], context->buffer_length - 2); + context->buffer_length -= 2; + context->chunk_length = -1; + } + } + } + if (context->chunk_length > 0) { + int length = lpm_socket_read(context, imin(sizeof(context->buffer) - context->buffer_length, context->chunk_length - context->chunk_written + (context->chunked ? 2 : 0))); + if ((!context->is_ssl && length == 0) || (context->is_ssl && context->content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) + goto finish; + if (length < 0 && lpm_get_error(context, length, "error retrieving full chunk")) + goto cleanup; + if (length == 0) + return lua_yieldk(L, 0, ctx, lpm_getk); } } } - if (chunk_length > 0) { - int length = lpm_socket_read(s, &buffer[buffer_length], imin(sizeof(buffer) - buffer_length, chunk_length - chunk_written + (chunked ? 2 : 0)), ssl_ctx); - if ((!ssl_ctx && length == 0) || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) - goto finish; - if (length <= 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full chunk for %s://%s%s", protocol, hostname, rest); - goto cleanup; - } - buffer_length += length; - } } finish: - if (file) { - fclose(file); - file = NULL; - lua_pushnil(L); - } else { - luaL_pushresult(&B); - } - if (content_length != -1 && total_downloaded != content_length) { - snprintf(err, sizeof(err), "error retrieving full response for %s://%s%s", protocol, hostname, rest); - goto cleanup; - } - if (callback_function) { - lua_pushvalue(L, callback_function); - lua_pushboolean(L, 1); - lua_call(L, 1, 0); + if (context->file) { + lua_pushnil(L); + lua_newtable(L); + } else { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer); + size_t len = lua_rawlen(L, -1); + luaL_Buffer b; + int table = lua_gettop(L); + luaL_buffinit(L, &b); + for (int i = 1; i <= len; ++i) { + lua_rawgeti(L, table, i); + size_t str_len; + const char* str = lua_tolstring(L, -1, &str_len); + lua_pop(L, 1); + luaL_addlstring(&b, str, str_len); } + lua_pop(L, 1); + luaL_pushresult(&b); lua_newtable(L); + } + if (context->content_length != -1 && context->total_downloaded != context->content_length && lpm_set_error(context, "error retrieving full response")) + goto cleanup; + report: + if (context->callback_function && !context->error_code) { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function); + lua_pushboolean(L, 1); + lua_call(L, 1, 0); + } cleanup: - if (ssl_ctx) - mbedtls_ssl_free(ssl_ctx); - if (net_ctx) - mbedtls_net_free(net_ctx); - if (file) - fclose(file); - if (s != -2) - close(s); - if (err[0]) - return luaL_error(L, "%s", err); + if (context->is_ssl) { + mbedtls_ssl_free(&context->ssl); + mbedtls_net_free(&context->net); + } else { + close(context->s); + } + if (context->callback_function) + luaL_unref(L, LUA_REGISTRYINDEX, context->callback_function); + if (context->file) + fclose(context->file); + else + luaL_unref(L, LUA_REGISTRYINDEX, context->lua_buffer); + if (context->error_code) + return luaL_error(L, "%s", context->error); return 2; } +static int lpm_get(lua_State* L) { + get_context_t* context = lua_newuserdata(L, sizeof(get_context_t)); + memset(context, 0, sizeof(get_context_t)); + int threaded = !lua_is_main_thread(L); + + const char* protocol = luaL_checkstring(L, 1); + strncpy(context->hostname, luaL_checkstring(L, 2), sizeof(context->hostname)); + strncpy(context->rest, luaL_checkstring(L, 4), sizeof(context->rest)); + const char* path = luaL_optstring(L, 5, NULL); + if (path) { + if ((context->file = lua_fopen(L, path, "wb")) == NULL) + return luaL_error(L, "can't open file %s: %s", path, strerror(errno)); + } else { + lua_newtable(L); + context->lua_buffer = luaL_ref(L, LUA_REGISTRYINDEX); + } + if (lua_type(L, 6) == LUA_TFUNCTION) { + lua_pushvalue(L, 6); + context->callback_function = luaL_ref(L, LUA_REGISTRYINDEX); + } + context->state = STATE_CONNECT; + + if (strcmp(protocol, "https") == 0) { + const char* port = lua_tostring(L, 3); + // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30 + mbedtls_ssl_init(&context->ssl); + mbedtls_net_init(&context->net); + if (threaded) + mbedtls_net_set_nonblock(&context->net); + else + mbedtls_net_set_block(&context->net); + mbedtls_ssl_set_bio(&context->ssl, &context->net, mbedtls_net_send, mbedtls_net_recv, NULL); + if ( + lpm_get_error(context, mbedtls_ssl_setup(&context->ssl, &ssl_config), "can't set up ssl") || + lpm_get_error(context, mbedtls_net_connect(&context->net, context->hostname, port, MBEDTLS_NET_PROTO_TCP), "can't set hostname") || + lpm_get_error(context, mbedtls_ssl_set_hostname(&context->ssl, context->hostname), "can't set hostname") + ) { + mbedtls_ssl_free(&context->ssl); + mbedtls_net_free(&context->net); + return luaL_error(L, "%s", context->error); + } + context->is_ssl = 1; + context->state = STATE_HANDSHAKE; + } else { + int port = luaL_checkinteger(L, 3); + struct hostent *host = gethostbyname(context->hostname); + struct sockaddr_in dest_addr = {0}; + if (!host) + return luaL_error(L, "can't resolve hostname %s", context->hostname); + context->s = socket(AF_INET, SOCK_STREAM, 0); + if (threaded) + fcntl(context->s, F_SETFL, fcntl(context->s, F_GETFL, 0) | O_NONBLOCK); + dest_addr.sin_family = AF_INET; + dest_addr.sin_port = htons(port); + dest_addr.sin_addr.s_addr = *(long*)(host->h_addr); + const char* ip = inet_ntoa(dest_addr.sin_addr); + if (connect(context->s, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)) == -1 ) { + close(context->s); + return luaL_error(L, "can't connect to host %s [%s] on port %d", context->hostname, ip, port); + } + context->state = STATE_SEND; + } + if (!threaded) + return lpm_getk(L, 0, luaL_ref(L, LUA_REGISTRYINDEX)); + return lua_yieldk(L, 0, luaL_ref(L, LUA_REGISTRYINDEX), lpm_getk); +} + + static int lpm_chdir(lua_State* L) { #ifdef _WIN32 if (_wchdir(lua_toutf16(L, luaL_checkstring(L, 1)))) -- cgit v1.2.3