aboutsummaryrefslogtreecommitdiff
path: root/src/lpm.c
diff options
context:
space:
mode:
authorAdam Harrison <adamdharrison@gmail.com>2024-07-23 19:08:28 -0400
committerAdam Harrison <adamdharrison@gmail.com>2024-07-23 19:11:18 -0400
commitbef2b5ce4a969e436ca1d266ca47fc22a6c241ea (patch)
treec5f0c06adc30416daa5c28eb6854501cddb06dfd /src/lpm.c
parent1564bb1ca6570868bbf45c896d679ac45c73fbdc (diff)
downloadlite-xl-plugin-manager-bef2b5ce4a969e436ca1d266ca47fc22a6c241ea.tar.gz
lite-xl-plugin-manager-bef2b5ce4a969e436ca1d266ca47fc22a6c241ea.zip
Made C mutli-threading possible.
Diffstat (limited to 'src/lpm.c')
-rw-r--r--src/lpm.c824
1 files changed, 536 insertions, 288 deletions
diff --git a/src/lpm.c b/src/lpm.c
index 15af8c2..afc2948 100644
--- a/src/lpm.c
+++ b/src/lpm.c
@@ -4,6 +4,7 @@
#include <windows.h>
#include <fileapi.h>
#else
+ #include <pthread.h>
#include <netdb.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
@@ -14,6 +15,7 @@
#define MAX_PATH PATH_MAX
#endif
+#include <assert.h>
#include <git2.h>
#include <string.h>
#include <stdio.h>
@@ -55,6 +57,94 @@
#define HTTPS_RESPONSE_HEADER_BUFFER_LENGTH 8192
+typedef struct {
+ #if _WIN32
+ HANDLE thread;
+ void* (*func)(void*);
+ void* data;
+ #else
+ pthread_t thread;
+ #endif
+} thread_t;
+
+typedef struct {
+ #if _WIN32
+ HANDLE mutex;
+ #else
+ pthread_mutex_t mutex;
+ #endif
+} mutex_t;
+
+static mutex_t* new_mutex() {
+ mutex_t* mutex = malloc(sizeof(mutex_t));
+ #if _WIN32
+ mutex->mutex = CreateMutex(NULL, FALSE, NULL);
+ #else
+ pthread_mutex_init(&mutex->mutex, NULL);
+ #endif
+ return mutex;
+}
+
+static void free_mutex(mutex_t* mutex) {
+ #if _WIN32
+ CloseHandle(mutex->mutex);
+ #else
+ pthread_mutex_destroy(&mutex->mutex);
+ #endif
+ free(mutex);
+}
+
+static void lock_mutex(mutex_t* mutex) {
+ #if _WIN32
+ WaitForSingleObject(mutex->mutex, INFINITE);
+ #else
+ pthread_mutex_lock(&mutex->mutex);
+ #endif
+}
+
+static void unlock_mutex(mutex_t* mutex) {
+ #if _WIN32
+ ReleaseMutex(mutex->mutex);
+ #else
+ pthread_mutex_unlock(&mutex->mutex);
+ #endif
+}
+
+
+#if _WIN32
+static DWORD windows_thread_callback(void* data) {
+ thread_t* thread = data;
+ thread->data = thread->func(thread->data);
+ return 0;
+}
+#endif
+
+static thread_t* create_thread(void* (*func)(void*), void* data) {
+ thread_t* thread = malloc(sizeof(thread_t));
+ #if _WIN32
+ thread->func = func;
+ thread->data = data;
+ thread->thread = CreateThread(NULL, 0, windows_thread_callback, thread, 0, NULL);
+ #else
+ pthread_create(&thread->thread, NULL, func, data);
+ #endif
+ return thread;
+}
+
+static void* join_thread(thread_t* thread) {
+ if (!thread)
+ return NULL;
+ void* retval;
+ #if _WIN32
+ WaitForSingleObject(thread->thread, INFINITE);
+ #else
+ pthread_join(thread->thread, &retval);
+ #endif
+ free(thread);
+ return retval;
+}
+
+
#if _WIN32
static LPCWSTR lua_toutf16(lua_State* L, const char* str) {
if (str && str[0] == 0)
@@ -95,17 +185,17 @@ static const char* lua_toutf8(lua_State* L, LPCWSTR str) {
}
static const int luaL_win32_error(lua_State* L, DWORD error_id, const char* message, ...) {
- va_list va;
- va_start(va, message);
- lua_pushvfstring(L, message, va);
- va_end(va);
- wchar_t message_buffer[2048];
- size_t size = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
- NULL, error_id, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), message_buffer, 2048, NULL);
- lua_pushliteral(L, ": ");
- lua_toutf8(L, message_buffer);
- lua_concat(L, 3);
- return lua_error(L);
+ va_list va;
+ va_start(va, message);
+ lua_pushvfstring(L, message, va);
+ va_end(va);
+ wchar_t message_buffer[2048];
+ size_t size = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL, error_id, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), message_buffer, 2048, NULL);
+ lua_pushliteral(L, ": ");
+ lua_toutf8(L, message_buffer);
+ lua_concat(L, 3);
+ return lua_error(L);
}
#endif
@@ -459,9 +549,23 @@ static int lpm_git_transport_certificate_check_cb(struct git_cert *cert, int val
return 0; // If no_verify_ssl is enabled, basically always return 0 when this is set as callback.
}
-static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void *payload) {
- lua_State* L = payload;
- lua_pushvalue(L, 2);
+
+typedef struct {
+ git_repository* repository;
+ lua_State* L;
+ char refspec[512];
+ int depth;
+ int threaded;
+ int callback_function;
+ git_transfer_progress progress;
+ int progress_update;
+ int complete;
+ int error_code;
+ char data[512];
+ thread_t* thread;
+} fetch_context_t;
+
+static int lpm_fetch_callback(lua_State* L, const git_transfer_progress *stats) {
lua_pushinteger(L, stats->received_bytes);
lua_pushinteger(L, stats->total_objects);
lua_pushinteger(L, stats->indexed_objects);
@@ -469,86 +573,126 @@ static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void
lua_pushinteger(L, stats->local_objects);
lua_pushinteger(L, stats->total_deltas);
lua_pushinteger(L, stats->indexed_deltas);
- lua_call(L, 7, 1);
- int value = lua_tointeger(L, -1);
+ return lua_pcall(L, 7, 0, 0);
+}
+
+static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void *payload) {
+ fetch_context_t* context = (fetch_context_t*)payload;
+ if (!context->threaded) {
+ if (context->callback_function) {
+ lua_rawgeti(context->L, LUA_REGISTRYINDEX, context->callback_function);
+ lpm_fetch_callback(context->L, stats);
+ }
+ } else {
+ context->progress = *stats;
+ context->progress_update = 1;
+ }
+ return 0;
+}
+
+static int lua_is_main_thread(lua_State* L) {
+ int is_main = lua_pushthread(L);
lua_pop(L, 1);
- return value;
+ return is_main;
}
-static int lpm_fetch(lua_State* L) {
- git_init();
- git_repository* repository = luaL_checkgitrepo(L, 1);
+static void* lpm_fetch_thread(void* ctx) {
git_remote* remote;
- if (git_remote_lookup(&remote, repository, "origin")) {
- git_repository_free(repository);
- return luaL_error(L, "git remote fetch error: %s", git_error_last_string());
+ fetch_context_t* context = (fetch_context_t*)ctx;
+ int error = git_remote_lookup(&remote, context->repository, "origin");
+ if (error && !context->error_code) {
+ snprintf(context->data, sizeof(context->data), "git remote fetch error: %s", git_error_last_string());
+ context->error_code = error;
+ return NULL;
}
- const char* refspec = luaL_optstring(L, 3, NULL);
git_fetch_options fetch_opts = GIT_FETCH_OPTIONS_INIT;
fetch_opts.download_tags = GIT_REMOTE_DOWNLOAD_TAGS_ALL;
- fetch_opts.callbacks.payload = L;
+ fetch_opts.callbacks.payload = context;
#if (LIBGIT2_VER_MAJOR == 1 && LIBGIT2_VER_MINOR >= 7) || LIBGIT2_VER_MAJOR > 1
- fetch_opts.depth = lua_toboolean(L, 4) ? GIT_FETCH_DEPTH_FULL : 1;
+ fetch_opts.depth = context->depth;
#endif
if (no_verify_ssl)
fetch_opts.callbacks.certificate_check = lpm_git_transport_certificate_check_cb;
- if (lua_type(L, 2) == LUA_TFUNCTION)
- fetch_opts.callbacks.transfer_progress = lpm_git_transfer_progress_cb;
- git_strarray array = { (char**)&refspec, 1 };
- int error = git_remote_connect(remote, GIT_DIRECTION_FETCH, &fetch_opts.callbacks, NULL, NULL) ||
- git_remote_download(remote, refspec ? &array : NULL, &fetch_opts) ||
+ fetch_opts.callbacks.transfer_progress = lpm_git_transfer_progress_cb;
+ char* strings[] = { context->refspec };
+ git_strarray array = { strings, 1 };
+
+ error = git_remote_connect(remote, GIT_DIRECTION_FETCH, &fetch_opts.callbacks, NULL, NULL) ||
+ git_remote_download(remote, context->refspec[0] ? &array : NULL, &fetch_opts) ||
git_remote_update_tips(remote, &fetch_opts.callbacks, fetch_opts.update_fetchhead, fetch_opts.download_tags, NULL);
- if (!error) {
+ if (!error && !context->error_code) {
git_buf branch_name = {0};
if (!git_remote_default_branch(&branch_name, remote)) {
- lua_pushlstring(L, branch_name.ptr, branch_name.size);
+ strncpy(context->data, branch_name.ptr, sizeof(context->data));
git_buf_dispose(&branch_name);
- } else {
- lua_pushnil(L);
}
}
git_remote_disconnect(remote);
git_remote_free(remote);
- git_repository_free(repository);
- if (error)
- return luaL_error(L, "git remote fetch error: %s", git_error_last_string());
- if (lua_type(L, 2) == LUA_TFUNCTION) {
- lua_pushvalue(L, 2);
- lua_pushboolean(L, 1);
- lua_pushvalue(L, -3);
- lua_call(L, 2, 0);
+ if (error && !context->error_code) {
+ snprintf(context->data, sizeof(context->data), "git remote fetch error: %s", git_error_last_string());
+ context->error_code = error;
}
- return 1;
+ context->complete = 1;
+ return NULL;
}
-static int mbedtls_snprintf(int mbedtls, char* buffer, int len, int status, const char* str, ...) {
- char mbed_buffer[256];
- mbedtls_strerror(status, mbed_buffer, sizeof(mbed_buffer));
- int error_len = mbedtls ? strlen(mbed_buffer) : strlen(strerror(status));
- va_list va;
- int offset = 0;
- va_start(va, str);
- offset = vsnprintf(buffer, len, str, va);
- va_end(va);
- if (offset < len - 2) {
- strcat(buffer, ": ");
- if (offset < len - error_len - 2)
- strcat(buffer, mbedtls ? mbed_buffer : strerror(status));
+
+static int lpm_fetchk(lua_State* L, int status, lua_KContext ctx) {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, (int)ctx);
+ fetch_context_t* context = lua_touserdata(L, -1);
+ lua_pop(L, 1);
+ if (context->threaded && !context->error_code && context->callback_function && context->progress_update) {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function);
+ context->error_code = lpm_fetch_callback(L, &context->progress);
+ if (context->error_code)
+ strncpy(context->data, lua_tostring(L, -1), sizeof(context->data));
}
- return strlen(buffer);
+ if (context->complete || context->error_code) {
+ join_thread(context->thread);
+ git_repository_free(context->repository);
+ lua_pushstring(L, context->data[0] == 0 ? NULL : context->data);
+ if (context->callback_function)
+ luaL_unref(L, LUA_REGISTRYINDEX, context->callback_function);
+ luaL_unref(L, LUA_REGISTRYINDEX, (int)ctx);
+ if (context->error_code) {
+ return lua_error(L);
+ }
+ return 1;
+ }
+ assert(context->threaded);
+ return lua_yieldk(L, 0, (lua_KContext)ctx, lpm_fetchk);
}
-static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) {
- char vsnbuffer[1024];
- char mbed_buffer[128];
- mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer));
- va_list va;
- va_start(va, str);
- vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va);
- va_end(va);
- return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer);
+
+static int lpm_fetch(lua_State* L) {
+ git_init();
+ int args = lua_gettop(L);
+ fetch_context_t* context = lua_newuserdata(L, sizeof(fetch_context_t));
+ memset(context, 0, sizeof(fetch_context_t));
+ context->repository = luaL_checkgitrepo(L, 1);
+ const char* refspec = args >= 3 ? luaL_optstring(L, 3, NULL) : NULL;
+ context->depth = args >= 4 && lua_toboolean(L, 4) ? GIT_FETCH_DEPTH_FULL : 1;
+ context->L = L;
+ context->threaded = !lua_is_main_thread(L);
+ if (refspec)
+ strncpy(context->refspec, refspec, sizeof(context->refspec));
+ if (lua_type(L, 2) == LUA_TFUNCTION) {
+ lua_pushvalue(L, 2);
+ context->callback_function = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
+ int ctx = luaL_ref(L, LUA_REGISTRYINDEX);
+ if (lua_is_main_thread(L)) {
+ lpm_fetch_thread(context);
+ lpm_fetchk(L, 0, ctx);
+ return 0;
+ } else {
+ context->thread = create_thread(lpm_fetch_thread, context);
+ return lua_yieldk(L, 0, (lua_KContext)ctx, lpm_fetchk);
+ }
}
+
static void lpm_tls_debug(void *ctx, int level, const char *file, int line, const char *str) {
fprintf(stderr, "%s:%04d: |%d| %s", file, line, level, str);
fflush(stderr);
@@ -564,6 +708,19 @@ static int lpm_trace(lua_State* L) {
return 0;
}
+
+static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) {
+ char vsnbuffer[1024];
+ char mbed_buffer[128];
+ mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer));
+ va_list va;
+ va_start(va, str);
+ vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va);
+ va_end(va);
+ return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer);
+}
+
+
static int lpm_certs(lua_State* L) {
const char* type = luaL_checkstring(L, 1);
int status;
@@ -998,18 +1155,6 @@ static int lpm_extract(lua_State* L) {
}
-static int lpm_socket_write(int fd, const char* buf, int len, mbedtls_ssl_context* ctx) {
- if (ctx)
- return mbedtls_ssl_write(ctx, buf, len);
- return write(fd, buf, len);
-}
-
-static int lpm_socket_read(int fd, char* buf, int len, mbedtls_ssl_context* ctx) {
- if (ctx)
- return mbedtls_ssl_read(ctx, buf, len);
- return read(fd, buf, len);
-}
-
static int strncicmp(const char* a, const char* b, int n) {
for (int i = 0; i < n; ++i) {
if (a[i] == 0 && b[i] != 0) return -1;
@@ -1052,240 +1197,343 @@ static const char* get_header(const char* buffer, const char* header, int* len)
static int imin(int a, int b) { return a < b ? a : b; }
static int imax(int a, int b) { return a > b ? a : b; }
-static int lpm_get(lua_State* L) {
- long response_code;
- char err[1024] = {0};
- const char* protocol = luaL_checkstring(L, 1);
- const char* hostname = luaL_checkstring(L, 2);
-
- int s = -2;
- mbedtls_net_context net_context;
- mbedtls_ssl_context ssl_context;
- mbedtls_ssl_context* ssl_ctx = NULL;
- mbedtls_net_context* net_ctx = NULL;
- FILE* file = NULL;
- if (strcmp(protocol, "https") == 0) {
- int status;
- const char* port = lua_tostring(L, 3);
- // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30
- ssl_ctx = &ssl_context;
- mbedtls_ssl_init(&ssl_context);
- if ((status = mbedtls_ssl_setup(&ssl_context, &ssl_config)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't set up ssl for %s: %d", hostname, status); goto cleanup;
- }
- net_ctx = &net_context;
- mbedtls_net_init(&net_context);
- mbedtls_net_set_block(&net_context);
- mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout);
- if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't connect to hostname %s", hostname); goto cleanup;
- } else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't set hostname %s", hostname); goto cleanup;
- } else if ((status = mbedtls_ssl_handshake(&ssl_context)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't handshake with %s", hostname); goto cleanup;
- } else if (((status = mbedtls_ssl_get_verify_result(&ssl_context)) != 0) && !no_verify_ssl) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't verify result for %s", hostname); goto cleanup;
- }
- } else {
- int port = luaL_checkinteger(L, 3);
- struct hostent *host = gethostbyname(hostname);
- struct sockaddr_in dest_addr = {0};
- if (!host)
- return luaL_error(L, "can't resolve hostname %s", hostname);
- s = socket(AF_INET, SOCK_STREAM, 0);
- #ifdef _WIN32
- DWORD timeout = 5 * 1000;
- setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout);
- #else
- struct timeval tv;
- tv.tv_sec = 5;
- tv.tv_usec = 0;
- setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);
- #endif
- dest_addr.sin_family = AF_INET;
- dest_addr.sin_port = htons(port);
- dest_addr.sin_addr.s_addr = *(long*)(host->h_addr);
- const char* ip = inet_ntoa(dest_addr.sin_addr);
- if (connect(s, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)) == -1 ) {
- close(s);
- return luaL_error(L, "can't connect to host %s [%s] on port %d", hostname, ip, port);
- }
- }
+typedef enum {
+ STATE_CONNECT,
+ STATE_HANDSHAKE,
+ STATE_SEND,
+ STATE_RECV_HEADER,
+ STATE_RECV_BODY
+} get_state_e;
+
+typedef struct {
+ get_state_e state;
+ int s;
+ int is_ssl;
+ mbedtls_ssl_context ssl;
+ mbedtls_net_context net;
+ int lua_buffer;
+ FILE* file;
+ char address[1024];
+ int error_code;
+ char error[256];
+ char hostname[256];
+ char rest[2048];
+ int callback_function;
- const char* rest = luaL_checkstring(L, 4);
char buffer[HTTPS_RESPONSE_HEADER_BUFFER_LENGTH];
- int buffer_length = snprintf(buffer, sizeof(buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", rest, hostname);
- buffer_length = lpm_socket_write(s, buffer, buffer_length, ssl_ctx);
- if (buffer_length < 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't write to socket %s", hostname); goto cleanup;
- }
- const char* header_end = NULL;
-
-
- buffer_length = 0;
- while (!header_end && buffer_length < sizeof(buffer) - 1) {
- int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length - 1, ssl_ctx);
- if (length < 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't read from socket %s", hostname); goto cleanup;
- } else if (length > 0) {
- buffer_length += length;
- buffer[buffer_length] = 0;
- header_end = strstr(buffer, "\r\n\r\n");
- }
- }
- if (!header_end) {
- snprintf(err, sizeof(err), "can't parse response headers for %s://%s%s: %s", protocol, hostname, rest, buffer_length >= sizeof(buffer) - 1 ? "response header buffer length exceeded" : "malformed response");
- goto cleanup;
- }
- header_end += 4;
- const char* protocol_end = strstr(buffer, " ");
- int code = atoi(protocol_end + 1);
- if (code != 200) {
- if (code >= 301 && code <= 303) {
- int len;
- const char* location = get_header(buffer, "location", &len);
- if (location) {
- lua_pushnil(L);
- lua_newtable(L);
- lua_pushlstring(L, location, len);
- lua_setfield(L, -2, "location");
- } else
- snprintf(err, sizeof(err), "received invalid %d-response from %s://%s%s: %d", code, protocol, hostname, rest, code);
- goto cleanup;
- } else {
- snprintf(err, sizeof(err), "received non 200-response from %s://%s%s: %d", protocol, hostname, rest, code); goto cleanup;
+ int buffer_length;
+
+ int content_length;
+ int chunk_length;
+ int chunked;
+ int chunk_written;
+ int total_downloaded;
+} get_context_t;
+
+
+static int lpm_socket_write(get_context_t* context, int len) {
+ return context->is_ssl ? mbedtls_ssl_write(&context->ssl, context->buffer, len) : write(context->s, context->buffer, len);
+}
+
+static int lpm_socket_read(get_context_t* context, int len) {
+ if (len == -1)
+ len = sizeof(context->buffer) - context->buffer_length;
+ if (len == 0)
+ return len;
+ len = context->is_ssl ? mbedtls_ssl_read(&context->ssl, &context->buffer[context->buffer_length], len) : read(context->s, &context->buffer[context->buffer_length], len);
+ if (len > 0)
+ context->buffer_length += len;
+ return len;
+}
+
+
+static int lpm_get_error(get_context_t* context, int error_code, const char* str, ...) {
+ if (error_code) {
+ context->error_code = error_code;
+ char mbed_buffer[256];
+ mbedtls_strerror(error_code, mbed_buffer, sizeof(mbed_buffer));
+ int error_len = context->is_ssl ? strlen(mbed_buffer) : strlen(strerror(error_code));
+ va_list va;
+ int offset = 0;
+ va_start(va, str);
+ offset = vsnprintf(context->buffer, sizeof(context->buffer), str, va);
+ va_end(va);
+ if (offset < sizeof(context->buffer) - 2) {
+ strcat(context->buffer, ": ");
+ if (offset < sizeof(context->buffer) - error_len - 2)
+ strcat(context->buffer, context->is_ssl ? mbed_buffer : strerror(error_code));
}
}
- const char* transfer_encoding = get_header(buffer, "transfer-encoding", NULL);
- int chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0;
- const char* content_length_value = get_header(buffer, "content-length", NULL);
- int content_length = content_length_value ? atoi(content_length_value) : -1;
- const char* path = luaL_optstring(L, 5, NULL);
- int callback_function = lua_type(L, 6) == LUA_TFUNCTION ? 6 : 0;
+ return error_code;
+}
- buffer_length -= (header_end - buffer);
- if (buffer_length > 0)
- memmove(buffer, header_end, buffer_length);
+static int lpm_set_error(get_context_t* context, const char* str, ...) {
+ va_list va;
+ int offset = 0;
+ va_start(va, str);
+ offset = vsnprintf(context->error, sizeof(context->error), str, va);
+ va_end(va);
+ context->error_code = -1;
+ return offset;
+}
- int total_downloaded = 0;
- int chunk_length = !chunked && content_length == -1 ? INT_MAX : content_length;
- int chunk_written = 0;
- luaL_Buffer B;
- if (path) {
- file = lua_fopen(L, path, "wb");
- if (!file) {
- snprintf(err, sizeof(err), "can't open file %s: %s", path, strerror(errno));
- goto cleanup;
+static int lpm_getk(lua_State* L, int status, lua_KContext ctx) {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx);
+ get_context_t* context = (get_context_t*)lua_touserdata(L, -1);
+ lua_pop(L,1);
+ switch (context->state) {
+ case STATE_HANDSHAKE: {
+ int status = mbedtls_ssl_handshake(&context->ssl);
+ if (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
+ return lua_yieldk(L, 0, ctx, lpm_getk);
+ if (
+ lpm_get_error(context, status, "can't handshake") ||
+ lpm_get_error(context, mbedtls_ssl_get_verify_result(&context->ssl), "can't verify result")
+ )
+ goto cleanup;
+ context->state = STATE_SEND;
}
- } else
- luaL_buffinit(L, &B);
- while (1) {
- // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk.
- while (chunk_length == -1) {
- char* newline = (char*)strnstr_local(buffer, "\r\n", buffer_length);
- if (newline) {
- *newline = '\0';
- if (sscanf(buffer, "%x", &chunk_length) != 1) {
- snprintf(err, sizeof(err), "error retrieving chunk length for %s://%s%s", protocol, hostname, rest);
- goto cleanup;
- }
- if (chunk_length == 0)
- goto finish;
- buffer_length -= (newline + 2 - buffer);
- if (buffer_length > 0)
- memmove(buffer, newline + 2, buffer_length);
- chunk_written = 0;
- } else if (buffer_length >= sizeof(buffer)) {
- snprintf(err, sizeof(err), "can't find chunk length for %s://%s%s", protocol, hostname, rest);
+ case STATE_SEND: {
+ context->buffer_length = snprintf(context->buffer, sizeof(context->buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", context->rest, context->hostname);
+ int length = lpm_socket_write(context, context->buffer_length);
+ if (length < context->buffer_length && lpm_get_error(context, length, "can't write to socket"))
goto cleanup;
- } else {
- int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length, ssl_ctx);
- if (length <= 0 || (ssl_ctx && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s://%s%s", protocol, hostname, rest);
+ context->buffer_length = 0;
+ context->buffer[0] = 0;
+ context->state = STATE_RECV_HEADER;
+ }
+ case STATE_RECV_HEADER: {
+ const char* header_end;
+ while (1) {
+ header_end = strstr(context->buffer, "\r\n\r\n");
+ if (!header_end && context->buffer_length >= sizeof(context->buffer) - 1 && lpm_set_error(context, "response header buffer length exceeded"))
goto cleanup;
+ if (!header_end) {
+ int length = lpm_socket_read(context, -1);
+ if (length < 0 && lpm_get_error(context, length, "can't read from socket"))
+ goto cleanup;
+ if (length == 0)
+ return lua_yieldk(L, 0, ctx, lpm_getk);
+ } else {
+ header_end += 4;
+ const char* protocol_end = strnstr_local(context->buffer, " ", context->buffer_length);
+ int code = atoi(protocol_end + 1);
+ if (code != 200) {
+ if (code >= 301 && code <= 303) {
+ const char* location = get_header(context->buffer, "location", &context->buffer_length);
+ if (location) {
+ lua_pushnil(L);
+ lua_newtable(L);
+ lua_pushlstring(L, location, context->buffer_length);
+ lua_setfield(L, -2, "location");
+ } else
+ lpm_set_error(context, "received invalid %d-response", code);
+ } else
+ lpm_set_error(context, "received non 200-response of %d", code);
+ goto report;
+ }
+ const char* transfer_encoding = get_header(context->buffer, "transfer-encoding", NULL);
+ context->chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0;
+ const char* content_length_value = get_header(context->buffer, "content-length", NULL);
+ context->content_length = content_length_value ? atoi(content_length_value) : -1;
+ context->buffer_length -= (header_end - context->buffer);
+ if (context->buffer_length > 0)
+ memmove(context->buffer, header_end, context->buffer_length);
+ context->chunk_length = !context->chunked && context->content_length == -1 ? INT_MAX : context->content_length;
+ context->state = STATE_RECV_BODY;
+ break;
}
- buffer_length += length;
}
}
- if (buffer_length > 0) {
- int to_write = imin(chunk_length - chunk_written, buffer_length);
- if (to_write > 0) {
- total_downloaded += to_write;
- chunk_written += to_write;
- if (callback_function) {
- lua_pushvalue(L, callback_function);
- lua_pushinteger(L, total_downloaded);
- if (content_length == -1)
- lua_pushnil(L);
- else
- lua_pushinteger(L, content_length);
- lua_call(L, 2, 0);
- }
- if (file)
- fwrite(buffer, sizeof(char), to_write, file);
- else
- luaL_addlstring(&B, buffer, to_write);
- buffer_length -= to_write;
- if (buffer_length > 0)
- memmove(buffer, &buffer[to_write], buffer_length);
- }
- if (chunk_written == chunk_length) {
- if (!chunked)
- goto finish;
- if (buffer_length >= 2) {
- if (!strnstr_local(buffer, "\r\n", 2)) {
- snprintf(err, sizeof(err), "invalid end to chunk for %s://%s%s", protocol, hostname, rest);
+ case STATE_RECV_BODY: {
+ while (1) {
+ // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk.
+ while (context->chunk_length == -1) {
+ char* newline = (char*)strnstr_local(context->buffer, "\r\n", context->buffer_length);
+ if (newline) {
+ *newline = '\0';
+ if ((sscanf(context->buffer, "%x", &context->chunk_length) != 1 && lpm_set_error(context, "error retrieving chunk length")))
+ goto cleanup;
+ else if (context->chunk_length == 0)
+ goto finish;
+ context->buffer_length -= (newline + 2 - context->buffer);
+ if (context->buffer_length > 0)
+ memmove(context->buffer, newline + 2, context->buffer_length);
+ } else if (context->buffer_length >= sizeof(context->buffer) && lpm_set_error(context, "can't find chunk length")) {
goto cleanup;
+ } else {
+ int length = lpm_socket_read(context, -1);
+ if ((length <= 0 || (context->is_ssl && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) && lpm_get_error(context, length, "error retrieving full repsonse"))
+ goto cleanup;
+ if (length == 0)
+ return lua_yieldk(L, 0, ctx, lpm_getk);
}
- memmove(buffer, &buffer[2], buffer_length - 2);
- buffer_length -= 2;
- chunk_length = -1;
+ }
+ if (context->buffer_length > 0) {
+ int to_write = imin(context->chunk_length - context->chunk_written, context->buffer_length);
+ if (to_write > 0) {
+ context->total_downloaded += to_write;
+ context->chunk_written += to_write;
+ if (context->callback_function) {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function);
+ lua_pushinteger(L, context->total_downloaded);
+ if (context->content_length == -1)
+ lua_pushnil(L);
+ else
+ lua_pushinteger(L, context->content_length);
+ lua_call(L, 2, 0);
+ }
+ if (context->file)
+ fwrite(context->buffer, sizeof(char), to_write, context->file);
+ else {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer);
+ lua_pushlstring(L, context->buffer, to_write);
+ lua_rawseti(L, -2, lua_rawlen(L, -2) + 1);
+ lua_pop(L, 1);
+ }
+ context->buffer_length -= to_write;
+ if (context->buffer_length > 0)
+ memmove(context->buffer, &context->buffer[to_write], context->buffer_length);
+ }
+ if (context->chunk_written == context->chunk_length) {
+ if (!context->chunked)
+ goto finish;
+ if (context->buffer_length >= 2) {
+ if (!strnstr_local(context->buffer, "\r\n", 2) && lpm_set_error(context, "invalid end to chunk"))
+ goto cleanup;
+ memmove(context->buffer, &context->buffer[2], context->buffer_length - 2);
+ context->buffer_length -= 2;
+ context->chunk_length = -1;
+ }
+ }
+ }
+ if (context->chunk_length > 0) {
+ int length = lpm_socket_read(context, imin(sizeof(context->buffer) - context->buffer_length, context->chunk_length - context->chunk_written + (context->chunked ? 2 : 0)));
+ if ((!context->is_ssl && length == 0) || (context->is_ssl && context->content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY))
+ goto finish;
+ if (length < 0 && lpm_get_error(context, length, "error retrieving full chunk"))
+ goto cleanup;
+ if (length == 0)
+ return lua_yieldk(L, 0, ctx, lpm_getk);
}
}
}
- if (chunk_length > 0) {
- int length = lpm_socket_read(s, &buffer[buffer_length], imin(sizeof(buffer) - buffer_length, chunk_length - chunk_written + (chunked ? 2 : 0)), ssl_ctx);
- if ((!ssl_ctx && length == 0) || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY))
- goto finish;
- if (length <= 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full chunk for %s://%s%s", protocol, hostname, rest);
- goto cleanup;
- }
- buffer_length += length;
- }
}
finish:
- if (file) {
- fclose(file);
- file = NULL;
- lua_pushnil(L);
- } else {
- luaL_pushresult(&B);
- }
- if (content_length != -1 && total_downloaded != content_length) {
- snprintf(err, sizeof(err), "error retrieving full response for %s://%s%s", protocol, hostname, rest);
- goto cleanup;
- }
- if (callback_function) {
- lua_pushvalue(L, callback_function);
- lua_pushboolean(L, 1);
- lua_call(L, 1, 0);
+ if (context->file) {
+ lua_pushnil(L);
+ lua_newtable(L);
+ } else {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer);
+ size_t len = lua_rawlen(L, -1);
+ luaL_Buffer b;
+ int table = lua_gettop(L);
+ luaL_buffinit(L, &b);
+ for (int i = 1; i <= len; ++i) {
+ lua_rawgeti(L, table, i);
+ size_t str_len;
+ const char* str = lua_tolstring(L, -1, &str_len);
+ lua_pop(L, 1);
+ luaL_addlstring(&b, str, str_len);
}
+ lua_pop(L, 1);
+ luaL_pushresult(&b);
lua_newtable(L);
+ }
+ if (context->content_length != -1 && context->total_downloaded != context->content_length && lpm_set_error(context, "error retrieving full response"))
+ goto cleanup;
+ report:
+ if (context->callback_function && !context->error_code) {
+ lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function);
+ lua_pushboolean(L, 1);
+ lua_call(L, 1, 0);
+ }
cleanup:
- if (ssl_ctx)
- mbedtls_ssl_free(ssl_ctx);
- if (net_ctx)
- mbedtls_net_free(net_ctx);
- if (file)
- fclose(file);
- if (s != -2)
- close(s);
- if (err[0])
- return luaL_error(L, "%s", err);
+ if (context->is_ssl) {
+ mbedtls_ssl_free(&context->ssl);
+ mbedtls_net_free(&context->net);
+ } else {
+ close(context->s);
+ }
+ if (context->callback_function)
+ luaL_unref(L, LUA_REGISTRYINDEX, context->callback_function);
+ if (context->file)
+ fclose(context->file);
+ else
+ luaL_unref(L, LUA_REGISTRYINDEX, context->lua_buffer);
+ if (context->error_code)
+ return luaL_error(L, "%s", context->error);
return 2;
}
+static int lpm_get(lua_State* L) {
+ get_context_t* context = lua_newuserdata(L, sizeof(get_context_t));
+ memset(context, 0, sizeof(get_context_t));
+ int threaded = !lua_is_main_thread(L);
+
+ const char* protocol = luaL_checkstring(L, 1);
+ strncpy(context->hostname, luaL_checkstring(L, 2), sizeof(context->hostname));
+ strncpy(context->rest, luaL_checkstring(L, 4), sizeof(context->rest));
+ const char* path = luaL_optstring(L, 5, NULL);
+ if (path) {
+ if ((context->file = lua_fopen(L, path, "wb")) == NULL)
+ return luaL_error(L, "can't open file %s: %s", path, strerror(errno));
+ } else {
+ lua_newtable(L);
+ context->lua_buffer = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
+ if (lua_type(L, 6) == LUA_TFUNCTION) {
+ lua_pushvalue(L, 6);
+ context->callback_function = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
+ context->state = STATE_CONNECT;
+
+ if (strcmp(protocol, "https") == 0) {
+ const char* port = lua_tostring(L, 3);
+ // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30
+ mbedtls_ssl_init(&context->ssl);
+ mbedtls_net_init(&context->net);
+ if (threaded)
+ mbedtls_net_set_nonblock(&context->net);
+ else
+ mbedtls_net_set_block(&context->net);
+ mbedtls_ssl_set_bio(&context->ssl, &context->net, mbedtls_net_send, mbedtls_net_recv, NULL);
+ if (
+ lpm_get_error(context, mbedtls_ssl_setup(&context->ssl, &ssl_config), "can't set up ssl") ||
+ lpm_get_error(context, mbedtls_net_connect(&context->net, context->hostname, port, MBEDTLS_NET_PROTO_TCP), "can't set hostname") ||
+ lpm_get_error(context, mbedtls_ssl_set_hostname(&context->ssl, context->hostname), "can't set hostname")
+ ) {
+ mbedtls_ssl_free(&context->ssl);
+ mbedtls_net_free(&context->net);
+ return luaL_error(L, "%s", context->error);
+ }
+ context->is_ssl = 1;
+ context->state = STATE_HANDSHAKE;
+ } else {
+ int port = luaL_checkinteger(L, 3);
+ struct hostent *host = gethostbyname(context->hostname);
+ struct sockaddr_in dest_addr = {0};
+ if (!host)
+ return luaL_error(L, "can't resolve hostname %s", context->hostname);
+ context->s = socket(AF_INET, SOCK_STREAM, 0);
+ if (threaded)
+ fcntl(context->s, F_SETFL, fcntl(context->s, F_GETFL, 0) | O_NONBLOCK);
+ dest_addr.sin_family = AF_INET;
+ dest_addr.sin_port = htons(port);
+ dest_addr.sin_addr.s_addr = *(long*)(host->h_addr);
+ const char* ip = inet_ntoa(dest_addr.sin_addr);
+ if (connect(context->s, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)) == -1 ) {
+ close(context->s);
+ return luaL_error(L, "can't connect to host %s [%s] on port %d", context->hostname, ip, port);
+ }
+ context->state = STATE_SEND;
+ }
+ if (!threaded)
+ return lpm_getk(L, 0, luaL_ref(L, LUA_REGISTRYINDEX));
+ return lua_yieldk(L, 0, luaL_ref(L, LUA_REGISTRYINDEX), lpm_getk);
+}
+
+
static int lpm_chdir(lua_State* L) {
#ifdef _WIN32
if (_wchdir(lua_toutf16(L, luaL_checkstring(L, 1))))