From 2e5b922b88fffa0ca48278388bf323da746bd9c3 Mon Sep 17 00:00:00 2001 From: Adam Harrison Date: Wed, 13 Mar 2024 15:18:21 -0400 Subject: Started genercizing sockets. --- src/lpm.c | 329 ++++++++++++++++++++++++++++++++++---------------------------- 1 file changed, 179 insertions(+), 150 deletions(-) diff --git a/src/lpm.c b/src/lpm.c index fef74cc..fae2040 100644 --- a/src/lpm.c +++ b/src/lpm.c @@ -51,6 +51,29 @@ #define HTTPS_RESPONSE_HEADER_BUFFER_LENGTH 8192 +static int imin(int a, int b) { return a < b ? a : b; } +static int imax(int a, int b) { return a > b ? a : b; } +static int strncicmp(const char* a, const char* b, int n) { + for (int i = 0; i < n; ++i) { + if (a[i] == 0 && b[i] != 0) return -1; + if (a[i] != 0 && b[i] == 0) return 1; + int lowera = tolower(a[i]), lowerb = tolower(b[i]); + if (lowera == lowerb) continue; + if (lowera < lowerb) return -1; + return 1; + } + return 0; +} + +static const char* strnstr_local(const char* haystack, const char* needle, int n) { + int len = strlen(needle); + for (int i = 0; i <= n - len; ++i) { + if (strncmp(&haystack[i], needle, len) == 0) + return &haystack[i]; + } + return NULL; +} + #if _WIN32 static LPCWSTR lua_toutf16(lua_State* L, const char* str) { if (str && str[0] == 0) @@ -444,12 +467,8 @@ static int lpm_init(lua_State* L) { return 0; } -static int no_verify_ssl, has_setup_ssl, print_trace; -static mbedtls_x509_crt x509_certificate; -static mbedtls_entropy_context entropy_context; -static mbedtls_ctr_drbg_context drbg_context; +static int no_verify_ssl, print_trace_ssl; static mbedtls_ssl_config ssl_config; -static mbedtls_ssl_context ssl_context; static int lpm_git_transport_certificate_check_cb(struct git_cert *cert, int valid, const char *host, void *payload) { return 0; // If no_verify_ssl is enabled, basically always return 0 when this is set as callback. @@ -517,34 +536,6 @@ static int lpm_fetch(lua_State* L) { return 1; } -static int mbedtls_snprintf(int mbedtls, char* buffer, int len, int status, const char* str, ...) { - char mbed_buffer[256]; - mbedtls_strerror(status, mbed_buffer, sizeof(mbed_buffer)); - int error_len = mbedtls ? strlen(mbed_buffer) : strlen(strerror(status)); - va_list va; - int offset = 0; - va_start(va, str); - offset = vsnprintf(buffer, len, str, va); - va_end(va); - if (offset < len - 2) { - strcat(buffer, ": "); - if (offset < len - error_len - 2) - strcat(buffer, mbedtls ? mbed_buffer : strerror(status)); - } - return strlen(buffer); -} - -static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) { - char vsnbuffer[1024]; - char mbed_buffer[128]; - mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer)); - va_list va; - va_start(va, str); - vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va); - va_end(va); - return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer); -} - static void lpm_tls_debug(void *ctx, int level, const char *file, int line, const char *str) { fprintf(stderr, "%s:%04d: |%d| %s", file, line, level, str); fflush(stderr); @@ -556,12 +547,30 @@ static void lpm_libgit2_debug(git_trace_level_t level, const char *msg) { } static int lpm_trace(lua_State* L) { - print_trace = lua_toboolean(L, 1) ? 1 : 0; + print_trace_ssl = lua_toboolean(L, 1) ? 1 : 0; return 0; } + +static int luaL_mbedtls_error(lua_State* L, int code, const char* str, ...) { + char vsnbuffer[1024]; + char mbed_buffer[128]; + mbedtls_strerror(code, mbed_buffer, sizeof(mbed_buffer)); + va_list va; + va_start(va, str); + vsnprintf(vsnbuffer, sizeof(vsnbuffer), str, va); + va_end(va); + return luaL_error(L, "%s: %s", vsnbuffer, mbed_buffer); +} + + static int lpm_certs(lua_State* L) { const char* type = luaL_checkstring(L, 1); + static mbedtls_x509_crt x509_certificate; + static mbedtls_entropy_context entropy_context; + static mbedtls_ctr_drbg_context drbg_context; + static int has_setup_ssl; + int status; if (has_setup_ssl) { mbedtls_ssl_config_free(&ssl_config); @@ -572,11 +581,11 @@ static int lpm_certs(lua_State* L) { mbedtls_x509_crt_init(&x509_certificate); mbedtls_entropy_init(&entropy_context); mbedtls_ctr_drbg_init(&drbg_context); + mbedtls_ssl_config_init(&ssl_config); + if ((status = mbedtls_ctr_drbg_seed(&drbg_context, mbedtls_entropy_func, &entropy_context, NULL, 0)) != 0) return luaL_mbedtls_error(L, status, "failed to setup mbedtls_x509"); - mbedtls_ssl_config_init(&ssl_config); - status = mbedtls_ssl_config_defaults(&ssl_config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); - if (status) + else if (status = mbedtls_ssl_config_defaults(&ssl_config, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT) != 0) return luaL_mbedtls_error(L, status, "can't set ssl_config defaults"); mbedtls_ssl_conf_max_version(&ssl_config, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3); mbedtls_ssl_conf_min_version(&ssl_config, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3); @@ -584,7 +593,7 @@ static int lpm_certs(lua_State* L) { mbedtls_ssl_conf_rng(&ssl_config, mbedtls_ctr_drbg_random, &drbg_context); mbedtls_ssl_conf_read_timeout(&ssl_config, 5000); #if defined(MBEDTLS_DEBUG_C) - if (print_trace) { + if (print_trace_ssl) { mbedtls_debug_set_threshold(5); mbedtls_ssl_conf_dbg(&ssl_config, lpm_tls_debug, NULL); git_init(); @@ -595,7 +604,7 @@ static int lpm_certs(lua_State* L) { if (strcmp(type, "noverify") == 0) { no_verify_ssl = 1; mbedtls_ssl_conf_authmode(&ssl_config, MBEDTLS_SSL_VERIFY_OPTIONAL); - if (print_trace) { + if (print_trace_ssl) { fprintf(stderr, "[ssl] SSL verify set to optional.\n"); fflush(stderr); } @@ -609,12 +618,12 @@ static int lpm_certs(lua_State* L) { status = mbedtls_x509_crt_parse_path(&x509_certificate, path); if (status < 0) return luaL_mbedtls_error(L, status, "mbedtls_x509_crt_parse_path failed to parse all CA certificates in %s", path); - if (status > 0 && print_trace) { + if (status > 0 && print_trace_ssl) { fprintf(stderr, "[ssl] mbedtls_x509_crt_parse_path on %s failed to parse %d certificates, but still succeeded.\n", path, status); fflush(stderr); } mbedtls_ssl_conf_ca_chain(&ssl_config, &x509_certificate, NULL); - if (print_trace) { + if (print_trace_ssl) { fprintf(stderr, "[ssl] SSL directory set to %s.\n", git_cert_path); fflush(stderr); } @@ -989,38 +998,117 @@ static int lpm_extract(lua_State* L) { return 0; } +typedef enum lpm_socket_type_e { + SOCKET_FD, + SOCKET_MBEDTLS, + SOCKET_OPENSSL +} lpm_socket_type_e; + +typedef enum lpm_socket_error_e { + ERROR_NONE = 0, + ERROR_VARIED = -1, + ERROR_CLOSED = -2 +} lpm_socket_error_e; + +typedef struct lpm_socket { + lpm_socket_context_type_e type; + union { + int fd; + struct { + mbedtls_ssl_context ssl; + mbedtls_net_context net; + } + }; +} lpm_socket_t; + -static int lpm_socket_write(int fd, const char* buf, int len, mbedtls_ssl_context* ctx) { - if (ctx) - return mbedtls_ssl_write(ctx, buf, len); - return write(fd, buf, len); +static lpm_socket_error_e lpm_socket_check_error(lpm_socket_t* ctx, int result, char* err, size_t err_len, const char* pattern, ...) { + if (result >= 0) + return result; + va_list va; + va_start(va, pattern); + int len = vsnprintf(err, err_len, pattern, type); + switch (ctx->type) { + case SOCKET_FD: snprintf(&err[len], err_len - len, ": %s", strerror(error)); + case SOCKET_MBEDTLS: + len += snprintf(&err[len], err_len - len, ": "); + mbedtls_strerror(result, &err[len], err_len - len); + } + va_end(va); + return -1; } -static int lpm_socket_read(int fd, char* buf, int len, mbedtls_ssl_context* ctx) { - if (ctx) - return mbedtls_ssl_read(ctx, buf, len); - return read(fd, buf, len); + +static int lpm_socket_write(lpm_socket_t* ctx, const char* buf, int len) { + switch (ctx->type) { + case SOCKET_FD: return write(ctx->fd, buf, len); + case SOCKET_MBEDTLS: return mbedtls_ssl_write(&ctx->ssl, buf, len); + } } -static int strncicmp(const char* a, const char* b, int n) { - for (int i = 0; i < n; ++i) { - if (a[i] == 0 && b[i] != 0) return -1; - if (a[i] != 0 && b[i] == 0) return 1; - int lowera = tolower(a[i]), lowerb = tolower(b[i]); - if (lowera == lowerb) continue; - if (lowera < lowerb) return -1; - return 1; + +static int lpm_socket_read(lpm_socket_t* ctx, char* buf, int len) { + switch (ctx->type) { + case SOCKET_FD: return read(ctx->fd, buf, len); + case SOCKET_MBEDTLS: return mbedtls_ssl_read(&ctx->ssl, buf, len); } - return 0; } -static const char* strnstr_local(const char* haystack, const char* needle, int n) { - int len = strlen(needle); - for (int i = 0; i <= n - len; ++i) { - if (strncmp(&haystack[i], needle, len) == 0) - return &haystack[i]; + +static int lpm_socket_close(lpm_socket_t* socket) { + switch (ctx->type) { + case SOCKET_FD: return close(ctx->fd); + case SOCKET_MBEDTLS: + mbedtls_ssl_free(&ctx->ssl); + mbedtls_net_free(&ctx->net); + return 0; + break; } - return NULL; +} + + +static int lpm_socket_connect(lpm_socket_t* sock, const char* protocol, const char* hostname, unsigned short port, char* err, size_t err_len) { + if (strcmp(protocol, "https") == 0) { + sock->type = SOCKET_MBEDTLS; + mbedtls_net_context net_context; + mbedtls_ssl_context ssl_context; + char port_string[32]; + snprintf(port_string, sizeof(port_string), "%d", port); + int status; + // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30 + mbedtls_ssl_init(&sock->ssl); + mbedtls_net_init(&sock->net); + if ( + !lpm_socket_check_error(sock, mbedtls_ssl_setup(&ssl_context, &ssl_config), err, err_len, "can't setup ssl for %s", hostname) && + !lpm_socket_check_error(sock, mbedtls_net_set_block(&net_context), err, err_len, "can't set blocking for %s", hostname) && + !lpm_socket_check_error(sock, mbedtls_ssl_set_bio(&sock->ssl, &socket->net, mbedtls_net_send, NULL, mbedtls_net_recv_timeout), err, err_len, "can't set bio %s", hostname) && + !lpm_socket_check_error(sock, mbedtls_net_connect(&net_context, hostname, port_string, MBEDTLS_NET_PROTO_TCP), err, err_len, "can't connect to hostname %s", hostname) && + !lpm_socket_check_error(sock, mbedtls_ssl_set_hostname(&ssl_context, hostname), err, err_len, "can't set hostname %s", hostname) && + !lpm_socket_check_error(sock, mbedtls_ssl_handshake(&ssl_context), err, err_len, "can't handshake %s", hostname) && + (no_verify_ssl || !lpm_socket_check_error(sock, mbedtls_ssl_get_verify_result(&ssl_context), err, err_len), "can't verify result %s", hostname) + ) + return 0; + } else { + sock->type = SOCKET_FD; + struct hostent *host = gethostbyname(hostname); + struct sockaddr_in dest_addr = {0}; + if (!host) + return snprintf(sock->err, sizeof(sock->err), "can't resolve hostname %s", hostname); + sock->fd = socket(AF_INET, SOCK_STREAM, 0); + #ifdef _WIN32 + DWORD timeout = 5 * 1000; + #else + struct timeval tv = { 5, 0 }; + #endif + setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)); + dest_addr.sin_family = AF_INET; + dest_addr.sin_port = htons(port); + dest_addr.sin_addr.s_addr = *(long*)(host->h_addr); + if (!lpm_socket_check_error(sock, "can't connect to host %s [%d]", connect(sock->fd, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)), err, err_len, inet_ntoa(dest_addr.sin_addr), port)) + return 0; + } + lpm_socket_close(s); + return -1; } static const char* get_header(const char* buffer, const char* header, int* len) { @@ -1041,85 +1129,32 @@ static const char* get_header(const char* buffer, const char* header, int* len) return NULL; } -static int imin(int a, int b) { return a < b ? a : b; } -static int imax(int a, int b) { return a > b ? a : b; } static int lpm_get(lua_State* L) { - long response_code; char err[1024] = {0}; + const char* protocol = luaL_checkstring(L, 1); const char* hostname = luaL_checkstring(L, 2); + int port = luaL_checkinteger(L, 3); + const char* rest = luaL_checkstring(L, 4); - int s = -2; - mbedtls_net_context net_context; - mbedtls_ssl_context ssl_context; - mbedtls_ssl_context* ssl_ctx = NULL; - mbedtls_net_context* net_ctx = NULL; - FILE* file = NULL; - if (strcmp(protocol, "https") == 0) { - int status; - const char* port = lua_tostring(L, 3); - // https://gist.github.com/Barakat/675c041fd94435b270a25b5881987a30 - ssl_ctx = &ssl_context; - mbedtls_ssl_init(&ssl_context); - if ((status = mbedtls_ssl_setup(&ssl_context, &ssl_config)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't set up ssl for %s: %d", hostname, status); goto cleanup; - } - net_ctx = &net_context; - mbedtls_net_init(&net_context); - mbedtls_net_set_block(&net_context); - mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout); - if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't connect to hostname %s", hostname); goto cleanup; - } else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't set hostname %s", hostname); goto cleanup; - } else if ((status = mbedtls_ssl_handshake(&ssl_context)) != 0) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't handshake with %s", hostname); goto cleanup; - } else if (((status = mbedtls_ssl_get_verify_result(&ssl_context)) != 0) && !no_verify_ssl) { - mbedtls_snprintf(1, err, sizeof(err), status, "can't verify result for %s", hostname); goto cleanup; - } - } else { - int port = luaL_checkinteger(L, 3); - struct hostent *host = gethostbyname(hostname); - struct sockaddr_in dest_addr = {0}; - if (!host) - return luaL_error(L, "can't resolve hostname %s", hostname); - s = socket(AF_INET, SOCK_STREAM, 0); - #ifdef _WIN32 - DWORD timeout = 5 * 1000; - setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout); - #else - struct timeval tv; - tv.tv_sec = 5; - tv.tv_usec = 0; - setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); - #endif - dest_addr.sin_family = AF_INET; - dest_addr.sin_port = htons(port); - dest_addr.sin_addr.s_addr = *(long*)(host->h_addr); - const char* ip = inet_ntoa(dest_addr.sin_addr); - if (connect(s, (struct sockaddr *) &dest_addr, sizeof(struct sockaddr)) == -1 ) { - close(s); - return luaL_error(L, "can't connect to host %s [%s] on port %d", hostname, ip, port); - } - } + lpm_socket_t sock = {0}; + if (!lpm_socket_open(&socket, protocol, hostname, port)) + goto error; - const char* rest = luaL_checkstring(L, 4); + FILE* file = NULL; char buffer[HTTPS_RESPONSE_HEADER_BUFFER_LENGTH]; int buffer_length = snprintf(buffer, sizeof(buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", rest, hostname); - buffer_length = lpm_socket_write(s, buffer, buffer_length, ssl_ctx); - if (buffer_length < 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't write to socket %s", hostname); goto cleanup; - } + if ((buffer_length = lpm_socket_check_error(&sock, lpm_socket_write(sock, buffer, buffer_length), err, sizeof(err), "can't write to socket")) < 0) + goto cleanup; + int bytes_read = 0; const char* header_end = NULL; - - while (!header_end && bytes_read < sizeof(buffer) - 1) { - buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx); - if (buffer_length < 0) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't read from socket %s", hostname); goto cleanup; - } else if (buffer_length > 0) { + buffer_length = lpm_socket_check_error(&sock, lpm_socket_read(sock, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx), err, sizeof(err), "can't read from socket"); + if (buffer_length < 0) + goto cleanup; + if (buffer_length > 0) { bytes_read += buffer_length; buffer[bytes_read] = 0; header_end = strstr(buffer, "\r\n\r\n"); @@ -1145,7 +1180,8 @@ static int lpm_get(lua_State* L) { snprintf(err, sizeof(err), "received invalid %d-response from %s%s: %d", code, hostname, rest, code); goto cleanup; } else { - snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); goto cleanup; + snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); + goto cleanup; } } const char* transfer_encoding = get_header(buffer, "transfer-encoding", NULL); @@ -1191,11 +1227,11 @@ static int lpm_get(lua_State* L) { snprintf(err, sizeof(err), "can't find chunk length for %s%s", hostname, rest); goto cleanup; } else { - int length = lpm_socket_read(s, &buffer[buffer_length], sizeof(buffer) - buffer_length, ssl_ctx); - if (length <= 0 || (ssl_ctx && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) { - mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); + int length = lpm_socket_check_error(sock, "error reading from socket", lpm_socket_read(&buffer[buffer_length], sizeof(buffer) - buffer_length), err, sizeof(err)); + if (length == ERROR_CLOSED) + sprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); + if (length <= 0) goto cleanup; - } buffer_length += length; } } @@ -1237,9 +1273,9 @@ static int lpm_get(lua_State* L) { } if (chunk_length > 0) { int length = lpm_socket_read(s, &buffer[buffer_length], imin(sizeof(buffer) - buffer_length, chunk_length - chunk_written + (chunked ? 2 : 0)), ssl_ctx); - if ((!ssl_ctx && length == 0) || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) + if (length == ERROR_CLOSED) goto finish; - if (length <= 0) { + if (lpm_socket_check_error(sock, "error retrieving full chunk" length <= 0) { mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full chunk for %s%s", hostname, rest); goto cleanup; } @@ -1247,13 +1283,10 @@ static int lpm_get(lua_State* L) { } } finish: - if (file) { - fclose(file); - file = NULL; + if (file) lua_pushnil(L); - } else { + else luaL_pushresult(&B); - } if (content_length != -1 && total_downloaded != content_length) { snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); goto cleanup; @@ -1265,14 +1298,10 @@ static int lpm_get(lua_State* L) { } lua_newtable(L); cleanup: - if (ssl_ctx) - mbedtls_ssl_free(ssl_ctx); - if (net_ctx) - mbedtls_net_free(net_ctx); if (file) fclose(file); - if (s != -2) - close(s); + lpm_socket_close(s); + error: if (err[0]) return luaL_error(L, "%s", err); return 2; @@ -1415,7 +1444,7 @@ static const luaL_Reg system_lib[] = { #endif #endif #ifndef ARCH_PLATFORM - #if _WIN32 + #ifdef _WIN32 #define ARCH_PLATFORM "windows" #elif __ANDROID__ #define ARCH_PLATFORM "android" -- cgit v1.2.3