diff options
author | Adam Harrison <adamdharrison@gmail.com> | 2022-11-30 00:10:35 -0500 |
---|---|---|
committer | Adam Harrison <adamdharrison@gmail.com> | 2022-11-30 00:10:35 -0500 |
commit | 5788acb842cb828471120fd512c0c43a90048a6b (patch) | |
tree | 3087c06e92d0e031c72d5e39705e1bd0dd6e06a6 | |
parent | e038c8cc6818dab3f8ccd4d78ae8797d591cc13f (diff) | |
download | lite-xl-plugin-manager-5788acb842cb828471120fd512c0c43a90048a6b.tar.gz lite-xl-plugin-manager-5788acb842cb828471120fd512c0c43a90048a6b.zip |
Added in redirect support.
-rw-r--r-- | src/lpm.c | 80 | ||||
-rw-r--r-- | src/lpm.lua | 10 |
2 files changed, 68 insertions, 22 deletions
@@ -393,6 +393,7 @@ static int lpm_certs(lua_State* L) { mbedtls_ssl_conf_min_version(&ssl_config, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3); mbedtls_ssl_conf_authmode(&ssl_config, MBEDTLS_SSL_VERIFY_REQUIRED); mbedtls_ssl_conf_rng(&ssl_config, mbedtls_ctr_drbg_random, &drbg_context); + mbedtls_ssl_conf_read_timeout(&ssl_config, 5000); has_setup_ssl = 1; if (strcmp(type, "dir") == 0) { git_libgit2_opts(GIT_OPT_SET_SSL_CERT_LOCATIONS, NULL, path); @@ -583,6 +584,24 @@ static int strncicmp(const char* a, const char* b, int n) { return 0; } +static const char* get_header(const char* buffer, const char* header, int* len) { + const char* line_end = strstr(buffer, "\r\n"); + const char* header_end = strstr(buffer, "\r\n\r\n"); + int header_len = strlen(header); + while (line_end && line_end < header_end) { + if (strncicmp(line_end + 2, header, header_len) == 0) { + const char* offset = line_end + header_len + 3; + while (*offset == ' ') { ++offset; } + const char* end = strstr(offset, "\r\n"); + if (len) + *len = end - offset; + return offset; + } + line_end = strstr(line_end + 2, "\r\n"); + } + return NULL; +} + static int lpm_get(lua_State* L) { long response_code; char err[1024] = {0}; @@ -607,7 +626,7 @@ static int lpm_get(lua_State* L) { } mbedtls_net_init(&net_context); mbedtls_net_set_block(&net_context); - mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout); if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) { snprintf(err, sizeof(err), "can't connect to hostname %s: %d", hostname, status); goto cleanup; } else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) { @@ -624,6 +643,15 @@ static int lpm_get(lua_State* L) { if (!host) return luaL_error(L, "can't resolve hostname %s", hostname); s = socket(AF_INET, SOCK_STREAM, 0); + #ifdef _WIN32 + DWORD timeout = 5 * 1000; + setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout); + #else + struct timeval tv; + tv.tv_sec = 5; + tv.tv_usec = 0; + setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); + #endif dest_addr.sin_family = AF_INET; dest_addr.sin_port = htons(port); dest_addr.sin_addr.s_addr = *(long*)(host->h_addr); @@ -641,31 +669,43 @@ static int lpm_get(lua_State* L) { if (buffer_length < 0) { snprintf(err, sizeof(err), "can't write to socket %s: %s", hostname, strerror(errno)); goto cleanup; } - buffer_length = lpm_socket_read(s, buffer, sizeof(buffer) - 1, ssl_ctx); - buffer[4095] = 0; - if (buffer_length < 0) { - snprintf(err, sizeof(err), "can't read from socket %s: %s", hostname,strerror(errno)); goto cleanup; + int bytes_read = 0; + const char* header_end = NULL; + while (!header_end && bytes_read < sizeof(buffer)) { + buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx); + if (buffer_length < 0) { + snprintf(err, sizeof(err), "can't read from socket %s: %s", hostname,strerror(errno)); goto cleanup; + } + bytes_read += buffer_length; + buffer[bytes_read] = 0; + header_end = strstr(buffer, "\r\n\r\n"); } - const char* header_end = strstr(buffer, "\r\n\r\n"); if (!header_end) { - snprintf(err, sizeof(err), "can't parse response headers for %s", hostname); goto cleanup; + snprintf(err, sizeof(err), "can't parse response headers for %s%s", hostname, rest); goto cleanup; } header_end += 4; const char* protocol_end = strstr(buffer, " "); int code = atoi(protocol_end + 1); if (code != 200) { - snprintf(err, sizeof(err), "received non 200-response from %s: %d", hostname, code); goto cleanup; - } - const char* line_end = strstr(buffer, "\r\n"); - int content_length = -1; - while (line_end && line_end < header_end) { - if (strncicmp(line_end + 2, "content-length:", 15) == 0) { - const char* offset = line_end + 17; - while (*offset == ' ') { ++offset; } - content_length = atoi(offset); + if (code >= 301 && code <= 303) { + int len; + const char* location = get_header(buffer, "location", &len); + if (location) { + lua_pushnil(L); + lua_newtable(L); + lua_pushlstring(L, location, len); + lua_setfield(L, -2, "location"); + } else + snprintf(err, sizeof(err), "received invalid %d-response from %s%s: %d", code, hostname, rest, code); + goto cleanup; + } else { + snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); goto cleanup; } - line_end = strstr(line_end + 2, "\r\n"); } + const char* content_length_value = get_header(buffer, "content-length", NULL); + int content_length = -1; + if (content_length_value) + content_length = atoi(content_length_value); const char* path = luaL_optstring(L, 5, NULL); int body_length = buffer_length - (header_end - buffer); @@ -677,7 +717,7 @@ static int lpm_get(lua_State* L) { int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx); if (length == 0) break; if (length < 0) { - snprintf(err, sizeof(err), "error retrieving full response for %s: %s", hostname, strerror(errno)); goto cleanup; + snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s", hostname, rest, strerror(errno)); goto cleanup; } fwrite(buffer, sizeof(char), length, file); remaining -= length; @@ -692,7 +732,7 @@ static int lpm_get(lua_State* L) { int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx); if (length == 0) break; if (length < 0) { - snprintf(err, sizeof(err), "error retrieving full response for %s: %s", hostname, strerror(errno)); goto cleanup; + snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s", hostname, rest, strerror(errno)); goto cleanup; } luaL_addlstring(&B, buffer, length); remaining -= length; @@ -700,7 +740,7 @@ static int lpm_get(lua_State* L) { luaL_pushresult(&B); } if (content_length != -1 && remaining != 0) { - snprintf(err, sizeof(err), "error retrieving full response for %s", hostname); goto cleanup; + snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); goto cleanup; } lua_newtable(L); cleanup: diff --git a/src/lpm.lua b/src/lpm.lua index b38a231..034bfdc 100644 --- a/src/lpm.lua +++ b/src/lpm.lua @@ -458,16 +458,22 @@ local function prompt(message) return not response:find("%S") or response:find("^%s*[yY]%s*$") end + function common.get(source, target, checksum) if not source then error("requires url") end local _, _, protocol, hostname, port, rest = source:find("^(https?)://([^:/?]+):?(%d*)(.*)$") if not protocol then error("malfomed url " .. source) end if not port or port == "" then port = protocol == "https" and 443 or 80 end - if not checksum then return system.get(protocol, hostname, port, rest, target) end + if not checksum then + local res, headers = system.get(protocol, hostname, port, rest, target) + if headers.location then return common.get(headers.location, target, checksum) end + return res + end if not system.stat(CACHEDIR .. PATHSEP .. "files") then common.mkdirp(CACHEDIR .. PATHSEP .. "files") end local cache_path = CACHEDIR .. PATHSEP .. "files" .. PATHSEP .. checksum if not system.stat(cache_path) then - system.get(source, cache_path) + local res, headers = system.get(source, cache_path) + if headers.location then return common.get(headers.location, target, checksum) end if checksum ~= "SKIP" and system.hash(cache_path, "file") ~= checksum then fatal_warning("checksum doesn't match for " .. source) end end common.copy(cache_path, target) |