aboutsummaryrefslogtreecommitdiff
path: root/src/lpm.c
diff options
context:
space:
mode:
authorAdam Harrison <adamdharrison@gmail.com>2024-03-13 15:18:21 -0400
committerAdam Harrison <adamdharrison@gmail.com>2024-03-13 15:18:21 -0400
commit2e5b922b88fffa0ca48278388bf323da746bd9c3 (patch)
tree87fe0b2084b9947b932ec682e6a712ad99d3d379 /src/lpm.c
parentf677376167b401b2c8f9154ac30fcd21b8aff238 (diff)
downloadlite-xl-plugin-manager-genericize-sockets.tar.gz
lite-xl-plugin-manager-genericize-sockets.zip
Started genercizing sockets.genericize-sockets
Diffstat (limited to 'src/lpm.c')
-rw-r--r--src/lpm.c329
1 files changed, 179 insertions, 150 deletions
diff --git a/src/lpm.c b/src/lpm.c
index fef74cc..fae2040 100644
--- a/src/lpm.c
+++ b/src/lpm.c
@@ -51,6 +51,29 @@
#define HTTPS_RESPONSE_HEADER_BUFFER_LENGTH 8192
+static int imin(int a, int b) { return a < b ? a : b; }
+static int imax(int a, int b) { return a > b ? a : b; }
+static int strncicmp(const char* a, const char* b, int n) {
+ for (int i = 0; i < n; ++i) {
+ if (a[i] == 0 && b[i] != 0) return -1;
+ if (a[i] != 0 && b[i] == 0) return 1;
+ int lowera = tolower(a[i]), lowerb = tolower(b[i]);
+ if (lowera == lowerb) continue;
+ if (lowera < lowerb) return -1;
+ return 1;
+ }
+ return 0;
+}
+
+static const char* strnstr_local(const char* haystack, const char* needle, int n) {
+ int len = strlen(needle);
+ for (int i = 0; i <= n - len; ++i) {
+ if (strncmp(&haystack[i], needle, len) == 0)
+ return &haystack[i];
+ }
+ return NULL;
+}
+
#if _WIN32
static LPCWSTR lua_toutf16(lua_State* L, const char* str) {
if (str && str[0] == 0)
@@ -444,12 +467,8 @@ static int lpm_init(lua_State* L) {
return 0;
}
-static int no_verify_ssl, has_setup_ssl, print_trace;
-static mbedtls_x509_crt x509_certificate;
-static mbedtls_entropy_context entropy_context;
-static mbedtls_ctr_drbg_context drbg_context;
+static int no_verify_ssl, print_trace_ssl;
static mbedtls_ssl_config ssl_config;
-static mbedtls_ssl_context ssl_context;
static int lpm_git_transport_certificate_check_cb(struct git_cert *cert, int valid, const char *host, void *payload) {
return 0; // If no_verify_ssl is enabled, basically always return 0 when this is set as callback.
@@ -517,34 +536,6 @@ static int lpm_fetch(lua_State* L) {
return 1;
}
-static int mbedtls_snprintf(int mbedtls, char* buffer, int len, int status, const char* str, ...) {
- char mbed_buffer[256];
- mbedtls_strerror(status, mbed_buffer, sizeof(mbed_buffer));
- int error_len = mbedtls ? strlen(mbed_buffer) : strlen(strerror(status));
- va_list va;
- int offset = 0;
- va_start(va, str);
- offset = vsnprintf(buffer, len, str, va);
- va_end(va);
- if (offset < len - 2) {
- strcat(buffer, ": ");
- if (offset < len - error_len - 2)
- strcat(buffer, mbedtls ? mbed_buffer : strerror(status));
- }
- return strlen(buffer);
-}
-
-static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) {
- char vsnbuffer[1024];
- char mbed_buffer[128];
- mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer));
- va_list va;
- va_start(va, str);
- vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va);
- va_end(va);
- return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer);
-}
-
static void lpm_tls_debug(void *ctx, int level, const char *file, int line, const char *str) {
fprintf(stderr, "%s:%04d: |%d| %s", file, line, level, str);
fflush(stderr);
@@ -556,12 +547,30 @@ static void lpm_libgit2_debug(git_trace_level_t level, const char *msg) {
}
static int lpm_trace(lua_State* L) {
- print_trace = lua_toboolean(L, 1) ? 1 : 0;
+ print_trace_ssl = lua_toboolean(L, 1) ? 1 : 0;
return 0;
}
+
+static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) {
+ char vsnbuffer[1024];
+ char mbed_buffer[128];
+ mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer));
+ va_list va;
+ va_start(va, str);
+ vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va);
+ va_end(va);
+ return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer);
+}
+
+
static int lpm_certs(lua_State* L) {
const char* type = luaL_checkstring(L, 1);
+ static mbedtls_x509_crt x509_certificate;
+ static mbedtls_entropy_context entropy_context;
+ static mbedtls_ctr_drbg_context drbg_context;
+ static int has_setup_ssl;
+
int status;
if (has_setup_ssl) {
mbedtls_ssl_config_free(&ssl_config);
@@ -572,11 +581,11 @@ static int lpm_certs(lua_State* L) {
mbedtls_x509_crt_init(&x509_certificate);
mbedtls_entropy_init(&entropy_context);
mbedtls_ctr_drbg_init(&drbg_context);
+ mbedtls_ssl_config_init(&ssl_config);
+
if ((status = mbedtls_ctr_drbg_seed(&drbg_context, mbedtls_entropy_func, &entropy_context, NULL, 0)) != 0)
return luaL_mbedtls_error(L, status, "failed to setup mbedtls_x509");
- mbedtls_ssl_config_init(&ssl_config);
- status = mbedtls_ssl_config_defaults(&ssl_config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
- if (status)
+ else if (status = mbedtls_ssl_config_defaults(&ssl_config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT) != 0)
return luaL_mbedtls_error(L, status, "can't set ssl_config defaults");
mbedtls_ssl_conf_max_version(&ssl_config, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
mbedtls_ssl_conf_min_version(&ssl_config, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
@@ -584,7 +593,7 @@ static int lpm_certs(lua_State* L) {
mbedtls_ssl_conf_rng(&ssl_config, mbedtls_ctr_drbg_random, &drbg_context);
mbedtls_ssl_conf_read_timeout(&ssl_config, 5000);
#if defined(MBEDTLS_DEBUG_C)
- if (print_trace) {
+ if (print_trace_ssl) {
mbedtls_debug_set_threshold(5);
mbedtls_ssl_conf_dbg(&ssl_config, lpm_tls_debug, NULL);
git_init();
@@ -595,7 +604,7 @@ static int lpm_certs(lua_State* L) {
if (strcmp(type, "noverify") == 0) {
no_verify_ssl = 1;
mbedtls_ssl_conf_authmode(&ssl_config, MBEDTLS_SSL_VERIFY_OPTIONAL);
- if (print_trace) {
+ if (print_trace_ssl) {
fprintf(stderr, "[ssl] SSL verify set to optional.\n");
fflush(stderr);
}
@@ -609,12 +618,12 @@ static int lpm_certs(lua_State* L) {
status = mbedtls_x509_crt_parse_path(&x509_certificate, path);
if (status < 0)
return luaL_mbedtls_error(L, status, "mbedtls_x509_crt_parse_path failed to parse all CA certificates in %s", path);
- if (status > 0 && print_trace) {
+ if (status > 0 && print_trace_ssl) {
fprintf(stderr, "[ssl] mbedtls_x509_crt_parse_path on %s failed to parse %d certificates, but still succeeded.\n", path, status);
fflush(stderr);
}
mbedtls_ssl_conf_ca_chain(&ssl_config, &x509_certificate, NULL);
- if (print_trace) {
+ if (print_trace_ssl) {
fprintf(stderr, "[ssl] SSL directory set to %s.\n", git_cert_path);
fflush(stderr);
}
@@ -989,38 +998,117 @@ static int lpm_extract(lua_State* L) {
return 0;
}
+typedef enum lpm_socket_type_e {
+ SOCKET_FD,
+ SOCKET_MBEDTLS,
+ SOCKET_OPENSSL
+} lpm_socket_type_e;
+
+typedef enum lpm_socket_error_e {
+ ERROR_NONE = 0,
+ ERROR_VARIED = -1,
+ ERROR_CLOSED = -2
+} lpm_socket_error_e;
+
+typedef struct lpm_socket {
+ lpm_socket_context_type_e type;
+ union {
+ int fd;
+ struct {
+ mbedtls_ssl_context ssl;
+ mbedtls_net_context net;
+ }
+ };
+} lpm_socket_t;
+
-static int lpm_socket_write(int fd, const char* buf, int len, mbedtls_ssl_context* ctx) {
- if (ctx)
- return mbedtls_ssl_write(ctx, buf, len);
- return write(fd, buf, len);
+static lpm_socket_error_e lpm_socket_check_error(lpm_socket_t* ctx, int result, char* err, size_t err_len, const char* pattern, ...) {
+ if (result >= 0)
+ return result;
+ va_list va;
+ va_start(va, pattern);
+ int len = vsnprintf(err, err_len, pattern, type);
+ switch (ctx->type) {
+ case SOCKET_FD: snprintf(&err[len], err_len - len, ": %s", strerror(error));
+ case SOCKET_MBEDTLS:
+ len += snprintf(&err[len], err_len - len, ": ");
+ mbedtls_strerror(result, &err[len], err_len - len);
+ }
+ va_end(va);
+ return -1;
}
-static int lpm_socket_read(int fd, char* buf, int len, mbedtls_ssl_context* ctx) {
- if (ctx)
- return mbedtls_ssl_read(ctx, buf, len);
- return read(fd, buf, len);
+
+static int lpm_socket_write(lpm_socket_t* ctx, const char* buf, int len) {
+ switch (ctx->type) {
+ case SOCKET_FD: return write(ctx->fd, buf, len);
+ case SOCKET_MBEDTLS: return mbedtls_ssl_write(&ctx->ssl, buf, len);
+ }
}
-static int strncicmp(const char* a, const char* b, int n) {
- for (int i = 0; i < n; ++i) {
- if (a[i] == 0 && b[i] != 0) return -1;
- if (a[i] != 0 && b[i] == 0) return 1;
- int lowera = tolower(a[i]), lowerb = tolower(b[i]);
- if (lowera == lowerb) continue;
- if (lowera < lowerb) return -1;
- return 1;
+
+static int lpm_socket_read(lpm_socket_t* ctx, char* buf, int len) {
+ switch (ctx->type) {
+ case SOCKET_FD: return read(ctx->fd, buf, len);
+ case SOCKET_MBEDTLS: return mbedtls_ssl_read(&ctx->ssl, buf, len);
}
- return 0;
}
-static const char* strnstr_local(const char* haystack, const char* needle, int n) {
- int len = strlen(needle);
- for (int i = 0; i <= n - len; ++i) {
- if (strncmp(&haystack[i], needle, len) == 0)
- return &haystack[i];
+
+static int lpm_socket_close(lpm_socket_t* socket) {
+ switch (ctx->type) {
+ case SOCKET_FD: return close(ctx->fd);
+ case SOCKET_MBEDTLS:
+ mbedtls_ssl_free(&ctx->ssl);
+ mbedtls_net_free(&ctx->net);
+ return 0;
+ break;
}
- return NULL;
+}
+
+
+static int lpm_socket_connect(lpm_socket_t* sock, const char* protocol, const char* hostname, unsigned short port, char* err, size_t err_len) {
+ if (strcmp(protocol, "https") == 0) {
+ sock->type = SOCKET_MBEDTLS;
+ mbedtls_net_context net_context;
+ mbedtls_ssl_context ssl_context;
+ char port_string[32];
+ snprintf(port_string, sizeof(port_string), "%d", port);
+ int status;
+ // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30
+ mbedtls_ssl_init(&sock->ssl);
+ mbedtls_net_init(&sock->net);
+ if (
+ !lpm_socket_check_error(sock, mbedtls_ssl_setup(&ssl_context, &ssl_config), err, err_len, "can't setup ssl for %s", hostname) &&
+ !lpm_socket_check_error(sock, mbedtls_net_set_block(&net_context), err, err_len, "can't set blocking for %s", hostname) &&
+ !lpm_socket_check_error(sock, mbedtls_ssl_set_bio(&sock->ssl, &socket->net, mbedtls_net_send, NULL, mbedtls_net_recv_timeout), err, err_len, "can't set bio %s", hostname) &&
+ !lpm_socket_check_error(sock, mbedtls_net_connect(&net_context, hostname, port_string, MBEDTLS_NET_PROTO_TCP), err, err_len, "can't connect to hostname %s", hostname) &&
+ !lpm_socket_check_error(sock, mbedtls_ssl_set_hostname(&ssl_context, hostname), err, err_len, "can't set hostname %s", hostname) &&
+ !lpm_socket_check_error(sock, mbedtls_ssl_handshake(&ssl_context), err, err_len, "can't handshake %s", hostname) &&
+ (no_verify_ssl || !lpm_socket_check_error(sock, mbedtls_ssl_get_verify_result(&ssl_context), err, err_len), "can't verify result %s", hostname)
+ )
+ return 0;
+ } else {
+ sock->type = SOCKET_FD;
+ struct hostent *host = gethostbyname(hostname);
+ struct sockaddr_in dest_addr = {0};
+ if (!host)
+ return snprintf(sock->err, sizeof(sock->err), "can't resolve hostname %s", hostname);
+ sock->fd = socket(AF_INET, SOCK_STREAM, 0);
+ #ifdef _WIN32
+ DWORD timeout = 5 * 1000;
+ #else
+ struct timeval tv = { 5, 0 };
+ #endif
+ setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout));
+ dest_addr.sin_family = AF_INET;
+ dest_addr.sin_port = htons(port);
+ dest_addr.sin_addr.s_addr = *(long*)(host->h_addr);
+ if (!lpm_socket_check_error(sock, "can't connect to host %s [%d]", connect(sock->fd, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)), err, err_len, inet_ntoa(dest_addr.sin_addr), port))
+ return 0;
+ }
+ lpm_socket_close(s);
+ return -1;
}
static const char* get_header(const char* buffer, const char* header, int* len) {
@@ -1041,85 +1129,32 @@ static const char* get_header(const char* buffer, const char* header, int* len)
return NULL;
}
-static int imin(int a, int b) { return a < b ? a : b; }
-static int imax(int a, int b) { return a > b ? a : b; }
static int lpm_get(lua_State* L) {
- long response_code;
char err[1024] = {0};
+
const char* protocol = luaL_checkstring(L, 1);
const char* hostname = luaL_checkstring(L, 2);
+ int port = luaL_checkinteger(L, 3);
+ const char* rest = luaL_checkstring(L, 4);
- int s = -2;
- mbedtls_net_context net_context;
- mbedtls_ssl_context ssl_context;
- mbedtls_ssl_context* ssl_ctx = NULL;
- mbedtls_net_context* net_ctx = NULL;
- FILE* file = NULL;
- if (strcmp(protocol, "https") == 0) {
- int status;
- const char* port = lua_tostring(L, 3);
- // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30
- ssl_ctx = &ssl_context;
- mbedtls_ssl_init(&ssl_context);
- if ((status = mbedtls_ssl_setup(&ssl_context, &ssl_config)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't set up ssl for %s: %d", hostname, status); goto cleanup;
- }
- net_ctx = &net_context;
- mbedtls_net_init(&net_context);
- mbedtls_net_set_block(&net_context);
- mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout);
- if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't connect to hostname %s", hostname); goto cleanup;
- } else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't set hostname %s", hostname); goto cleanup;
- } else if ((status = mbedtls_ssl_handshake(&ssl_context)) != 0) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't handshake with %s", hostname); goto cleanup;
- } else if (((status = mbedtls_ssl_get_verify_result(&ssl_context)) != 0) && !no_verify_ssl) {
- mbedtls_snprintf(1, err, sizeof(err), status, "can't verify result for %s", hostname); goto cleanup;
- }
- } else {
- int port = luaL_checkinteger(L, 3);
- struct hostent *host = gethostbyname(hostname);
- struct sockaddr_in dest_addr = {0};
- if (!host)
- return luaL_error(L, "can't resolve hostname %s", hostname);
- s = socket(AF_INET, SOCK_STREAM, 0);
- #ifdef _WIN32
- DWORD timeout = 5 * 1000;
- setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout);
- #else
- struct timeval tv;
- tv.tv_sec = 5;
- tv.tv_usec = 0;
- setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);
- #endif
- dest_addr.sin_family = AF_INET;
- dest_addr.sin_port = htons(port);
- dest_addr.sin_addr.s_addr = *(long*)(host->h_addr);
- const char* ip = inet_ntoa(dest_addr.sin_addr);
- if (connect(s, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)) == -1 ) {
- close(s);
- return luaL_error(L, "can't connect to host %s [%s] on port %d", hostname, ip, port);
- }
- }
+ lpm_socket_t sock = {0};
+ if (!lpm_socket_open(&socket, protocol, hostname, port))
+ goto error;
- const char* rest = luaL_checkstring(L, 4);
+ FILE* file = NULL;
char buffer[HTTPS_RESPONSE_HEADER_BUFFER_LENGTH];
int buffer_length = snprintf(buffer, sizeof(buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", rest, hostname);
- buffer_length = lpm_socket_write(s, buffer, buffer_length, ssl_ctx);
- if (buffer_length < 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't write to socket %s", hostname); goto cleanup;
- }
+ if ((buffer_length = lpm_socket_check_error(&sock, lpm_socket_write(sock, buffer, buffer_length), err, sizeof(err), "can't write to socket")) < 0)
+ goto cleanup;
+
int bytes_read = 0;
const char* header_end = NULL;
-
-
while (!header_end && bytes_read < sizeof(buffer) - 1) {
- buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx);
- if (buffer_length < 0) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't read from socket %s", hostname); goto cleanup;
- } else if (buffer_length > 0) {
+ buffer_length = lpm_socket_check_error(&sock, lpm_socket_read(sock, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx), err, sizeof(err), "can't read from socket");
+ if (buffer_length < 0)
+ goto cleanup;
+ if (buffer_length > 0) {
bytes_read += buffer_length;
buffer[bytes_read] = 0;
header_end = strstr(buffer, "\r\n\r\n");
@@ -1145,7 +1180,8 @@ static int lpm_get(lua_State* L) {
snprintf(err, sizeof(err), "received invalid %d-response from %s%s: %d", code, hostname, rest, code);
goto cleanup;
} else {
- snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); goto cleanup;
+ snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code);
+ goto cleanup;
}
}
const char* transfer_encoding = get_header(buffer, "transfer-encoding", NULL);
@@ -1191,11 +1227,11 @@ static int lpm_get(lua_State* L) {
snprintf(err, sizeof(err), "can't find chunk length for %s%s", hostname, rest);
goto cleanup;
} else {
- int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length, ssl_ctx);
- if (length <= 0 || (ssl_ctx && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) {
- mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest);
+ int length = lpm_socket_check_error(sock, "error reading from socket", lpm_socket_read(&buffer[buffer_length], sizeof(buffer) - buffer_length), err, sizeof(err));
+ if (length == ERROR_CLOSED)
+ sprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest);
+ if (length <= 0)
goto cleanup;
- }
buffer_length += length;
}
}
@@ -1237,9 +1273,9 @@ static int lpm_get(lua_State* L) {
}
if (chunk_length > 0) {
int length = lpm_socket_read(s, &buffer[buffer_length], imin(sizeof(buffer) - buffer_length, chunk_length - chunk_written + (chunked ? 2 : 0)), ssl_ctx);
- if ((!ssl_ctx && length == 0) || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY))
+ if (length == ERROR_CLOSED)
goto finish;
- if (length <= 0) {
+ if (lpm_socket_check_error(sock, "error retrieving full chunk" length <= 0) {
mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full chunk for %s%s", hostname, rest);
goto cleanup;
}
@@ -1247,13 +1283,10 @@ static int lpm_get(lua_State* L) {
}
}
finish:
- if (file) {
- fclose(file);
- file = NULL;
+ if (file)
lua_pushnil(L);
- } else {
+ else
luaL_pushresult(&B);
- }
if (content_length != -1 && total_downloaded != content_length) {
snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest);
goto cleanup;
@@ -1265,14 +1298,10 @@ static int lpm_get(lua_State* L) {
}
lua_newtable(L);
cleanup:
- if (ssl_ctx)
- mbedtls_ssl_free(ssl_ctx);
- if (net_ctx)
- mbedtls_net_free(net_ctx);
if (file)
fclose(file);
- if (s != -2)
- close(s);
+ lpm_socket_close(s);
+ error:
if (err[0])
return luaL_error(L, "%s", err);
return 2;
@@ -1415,7 +1444,7 @@ static const luaL_Reg system_lib[] = {
#endif
#endif
#ifndef ARCH_PLATFORM
- #if _WIN32
+ #ifdef _WIN32
#define ARCH_PLATFORM "windows"
#elif __ANDROID__
#define ARCH_PLATFORM "android"