From 1d38779d7cf816b1e3152ac337f9f8fe9210dd5b Mon Sep 17 00:00:00 2001 From: Adam Harrison Date: Wed, 18 Jan 2023 21:52:08 -0500 Subject: Updated error handling for TLS connections, properly handled peer closing with no content_length. --- src/lpm.c | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) (limited to 'src/lpm.c') diff --git a/src/lpm.c b/src/lpm.c index 300dd24..9af5970 100644 --- a/src/lpm.c +++ b/src/lpm.c @@ -425,10 +425,10 @@ static int lpm_fetch(lua_State* L) { return 0; } -static int mbedtls_snprintf(char* buffer, int len, int status, const char* str, ...) { - char mbed_buffer[128]; +static int mbedtls_snprintf(int mbedtls, char* buffer, int len, int status, const char* str, ...) { + char mbed_buffer[256]; mbedtls_strerror(status, mbed_buffer, sizeof(mbed_buffer)); - int error_len = strlen(mbed_buffer); + int error_len = mbedtls ? strlen(mbed_buffer) : strlen(strerror(status)); va_list va; int offset = 0; va_start(va, str); @@ -437,7 +437,7 @@ static int mbedtls_snprintf(char* buffer, int len, int status, const char* str, if (offset < len - 2) { strcat(buffer, ": "); if (offset < len - error_len - 2) - strcat(buffer, mbed_buffer); + strcat(buffer, mbedtls ? mbed_buffer : strerror(status)); } return strlen(buffer); } @@ -786,20 +786,20 @@ static int lpm_get(lua_State* L) { ssl_ctx = &ssl_context; mbedtls_ssl_init(&ssl_context); if ((status = mbedtls_ssl_setup(&ssl_context, &ssl_config)) != 0) { - mbedtls_snprintf(err, sizeof(err), status, "can't set up ssl for %s: %d", hostname, status); goto cleanup; + mbedtls_snprintf(1, err, sizeof(err), status, "can't set up ssl for %s: %d", hostname, status); goto cleanup; } net_ctx = &net_context; mbedtls_net_init(&net_context); mbedtls_net_set_block(&net_context); mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout); if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) { - mbedtls_snprintf(err, sizeof(err), status, "can't connect to hostname %s", hostname); goto cleanup; + mbedtls_snprintf(1, err, sizeof(err), status, "can't connect to hostname %s", hostname); goto cleanup; } else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) { - mbedtls_snprintf(err, sizeof(err), status, "can't set hostname %s", hostname); goto cleanup; + mbedtls_snprintf(1, err, sizeof(err), status, "can't set hostname %s", hostname); goto cleanup; } else if ((status = mbedtls_ssl_handshake(&ssl_context)) != 0) { - mbedtls_snprintf(err, sizeof(err), status, "can't handshake with %s", hostname); goto cleanup; + mbedtls_snprintf(1, err, sizeof(err), status, "can't handshake with %s", hostname); goto cleanup; } else if (((status = mbedtls_ssl_get_verify_result(&ssl_context)) != 0) && !no_verify_ssl) { - mbedtls_snprintf(err, sizeof(err), status, "can't verify result for %s", hostname); goto cleanup; + mbedtls_snprintf(1, err, sizeof(err), status, "can't verify result for %s", hostname); goto cleanup; } } else { int port = luaL_checkinteger(L, 3); @@ -832,14 +832,14 @@ static int lpm_get(lua_State* L) { int buffer_length = snprintf(buffer, sizeof(buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", rest, hostname); buffer_length = lpm_socket_write(s, buffer, buffer_length, ssl_ctx); if (buffer_length < 0) { - snprintf(err, sizeof(err), "can't write to socket %s: %s", hostname, strerror(errno)); goto cleanup; + mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't write to socket %s", hostname); goto cleanup; } int bytes_read = 0; const char* header_end = NULL; while (!header_end && bytes_read < sizeof(buffer)) { buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx); if (buffer_length < 0) { - snprintf(err, sizeof(err), "can't read from socket %s: %s", hostname,strerror(errno)); goto cleanup; + mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? buffer_length : errno, "can't read from socket %s", hostname); goto cleanup; } bytes_read += buffer_length; buffer[bytes_read] = 0; @@ -882,9 +882,9 @@ static int lpm_get(lua_State* L) { fwrite(header_end, sizeof(char), body_length, file); while (content_length == -1 || remaining > 0) { int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx); - if (length == 0) break; + if (length == 0 || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) break; if (length < 0) { - snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s (%d)", hostname, rest, strerror(errno), length); goto cleanup; + mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); goto cleanup; } if (callback_function) { lua_pushvalue(L, callback_function); @@ -904,9 +904,9 @@ static int lpm_get(lua_State* L) { luaL_addlstring(&B, header_end, body_length); while (content_length == -1 || remaining > 0) { int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx); - if (length == 0) break; + if (length == 0 || (ssl_ctx && content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) break; if (length < 0) { - snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s (%d)", hostname, rest, strerror(errno), length); goto cleanup; + mbedtls_snprintf(ssl_ctx ? 1 : 0, err, sizeof(err), ssl_ctx ? length : errno, "error retrieving full response for %s%s", hostname, rest); goto cleanup; } if (callback_function) { lua_pushvalue(L, callback_function); -- cgit v1.2.3