aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Harrison <adamdharrison@gmail.com>2022-11-30 00:10:35 -0500
committerAdam Harrison <adamdharrison@gmail.com>2022-11-30 00:10:35 -0500
commit5788acb842cb828471120fd512c0c43a90048a6b (patch)
tree3087c06e92d0e031c72d5e39705e1bd0dd6e06a6
parente038c8cc6818dab3f8ccd4d78ae8797d591cc13f (diff)
downloadlite-xl-plugin-manager-5788acb842cb828471120fd512c0c43a90048a6b.tar.gz
lite-xl-plugin-manager-5788acb842cb828471120fd512c0c43a90048a6b.zip
Added in redirect support.
-rw-r--r--src/lpm.c80
-rw-r--r--src/lpm.lua10
2 files changed, 68 insertions, 22 deletions
diff --git a/src/lpm.c b/src/lpm.c
index 3e0021e..ebd2426 100644
--- a/src/lpm.c
+++ b/src/lpm.c
@@ -393,6 +393,7 @@ static int lpm_certs(lua_State* L) {
mbedtls_ssl_conf_min_version(&ssl_config, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
mbedtls_ssl_conf_authmode(&ssl_config, MBEDTLS_SSL_VERIFY_REQUIRED);
mbedtls_ssl_conf_rng(&ssl_config, mbedtls_ctr_drbg_random, &drbg_context);
+ mbedtls_ssl_conf_read_timeout(&ssl_config, 5000);
has_setup_ssl = 1;
if (strcmp(type, "dir") == 0) {
git_libgit2_opts(GIT_OPT_SET_SSL_CERT_LOCATIONS, NULL, path);
@@ -583,6 +584,24 @@ static int strncicmp(const char* a, const char* b, int n) {
return 0;
}
+static const char* get_header(const char* buffer, const char* header, int* len) {
+ const char* line_end = strstr(buffer, "\r\n");
+ const char* header_end = strstr(buffer, "\r\n\r\n");
+ int header_len = strlen(header);
+ while (line_end && line_end < header_end) {
+ if (strncicmp(line_end + 2, header, header_len) == 0) {
+ const char* offset = line_end + header_len + 3;
+ while (*offset == ' ') { ++offset; }
+ const char* end = strstr(offset, "\r\n");
+ if (len)
+ *len = end - offset;
+ return offset;
+ }
+ line_end = strstr(line_end + 2, "\r\n");
+ }
+ return NULL;
+}
+
static int lpm_get(lua_State* L) {
long response_code;
char err[1024] = {0};
@@ -607,7 +626,7 @@ static int lpm_get(lua_State* L) {
}
mbedtls_net_init(&net_context);
mbedtls_net_set_block(&net_context);
- mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, mbedtls_net_recv, NULL);
+ mbedtls_ssl_set_bio(&ssl_context, &net_context, mbedtls_net_send, NULL, mbedtls_net_recv_timeout);
if ((status = mbedtls_net_connect(&net_context, hostname, port, MBEDTLS_NET_PROTO_TCP)) != 0) {
snprintf(err, sizeof(err), "can't connect to hostname %s: %d", hostname, status); goto cleanup;
} else if ((status = mbedtls_ssl_set_hostname(&ssl_context, hostname)) != 0) {
@@ -624,6 +643,15 @@ static int lpm_get(lua_State* L) {
if (!host)
return luaL_error(L, "can't resolve hostname %s", hostname);
s = socket(AF_INET, SOCK_STREAM, 0);
+ #ifdef _WIN32
+ DWORD timeout = 5 * 1000;
+ setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout);
+ #else
+ struct timeval tv;
+ tv.tv_sec = 5;
+ tv.tv_usec = 0;
+ setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);
+ #endif
dest_addr.sin_family = AF_INET;
dest_addr.sin_port = htons(port);
dest_addr.sin_addr.s_addr = *(long*)(host->h_addr);
@@ -641,31 +669,43 @@ static int lpm_get(lua_State* L) {
if (buffer_length < 0) {
snprintf(err, sizeof(err), "can't write to socket %s: %s", hostname, strerror(errno)); goto cleanup;
}
- buffer_length = lpm_socket_read(s, buffer, sizeof(buffer) - 1, ssl_ctx);
- buffer[4095] = 0;
- if (buffer_length < 0) {
- snprintf(err, sizeof(err), "can't read from socket %s: %s", hostname,strerror(errno)); goto cleanup;
+ int bytes_read = 0;
+ const char* header_end = NULL;
+ while (!header_end && bytes_read < sizeof(buffer)) {
+ buffer_length = lpm_socket_read(s, &buffer[bytes_read], sizeof(buffer) - bytes_read - 1, ssl_ctx);
+ if (buffer_length < 0) {
+ snprintf(err, sizeof(err), "can't read from socket %s: %s", hostname,strerror(errno)); goto cleanup;
+ }
+ bytes_read += buffer_length;
+ buffer[bytes_read] = 0;
+ header_end = strstr(buffer, "\r\n\r\n");
}
- const char* header_end = strstr(buffer, "\r\n\r\n");
if (!header_end) {
- snprintf(err, sizeof(err), "can't parse response headers for %s", hostname); goto cleanup;
+ snprintf(err, sizeof(err), "can't parse response headers for %s%s", hostname, rest); goto cleanup;
}
header_end += 4;
const char* protocol_end = strstr(buffer, " ");
int code = atoi(protocol_end + 1);
if (code != 200) {
- snprintf(err, sizeof(err), "received non 200-response from %s: %d", hostname, code); goto cleanup;
- }
- const char* line_end = strstr(buffer, "\r\n");
- int content_length = -1;
- while (line_end && line_end < header_end) {
- if (strncicmp(line_end + 2, "content-length:", 15) == 0) {
- const char* offset = line_end + 17;
- while (*offset == ' ') { ++offset; }
- content_length = atoi(offset);
+ if (code >= 301 && code <= 303) {
+ int len;
+ const char* location = get_header(buffer, "location", &len);
+ if (location) {
+ lua_pushnil(L);
+ lua_newtable(L);
+ lua_pushlstring(L, location, len);
+ lua_setfield(L, -2, "location");
+ } else
+ snprintf(err, sizeof(err), "received invalid %d-response from %s%s: %d", code, hostname, rest, code);
+ goto cleanup;
+ } else {
+ snprintf(err, sizeof(err), "received non 200-response from %s%s: %d", hostname, rest, code); goto cleanup;
}
- line_end = strstr(line_end + 2, "\r\n");
}
+ const char* content_length_value = get_header(buffer, "content-length", NULL);
+ int content_length = -1;
+ if (content_length_value)
+ content_length = atoi(content_length_value);
const char* path = luaL_optstring(L, 5, NULL);
int body_length = buffer_length - (header_end - buffer);
@@ -677,7 +717,7 @@ static int lpm_get(lua_State* L) {
int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx);
if (length == 0) break;
if (length < 0) {
- snprintf(err, sizeof(err), "error retrieving full response for %s: %s", hostname, strerror(errno)); goto cleanup;
+ snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s", hostname, rest, strerror(errno)); goto cleanup;
}
fwrite(buffer, sizeof(char), length, file);
remaining -= length;
@@ -692,7 +732,7 @@ static int lpm_get(lua_State* L) {
int length = lpm_socket_read(s, buffer, sizeof(buffer), ssl_ctx);
if (length == 0) break;
if (length < 0) {
- snprintf(err, sizeof(err), "error retrieving full response for %s: %s", hostname, strerror(errno)); goto cleanup;
+ snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s", hostname, rest, strerror(errno)); goto cleanup;
}
luaL_addlstring(&B, buffer, length);
remaining -= length;
@@ -700,7 +740,7 @@ static int lpm_get(lua_State* L) {
luaL_pushresult(&B);
}
if (content_length != -1 && remaining != 0) {
- snprintf(err, sizeof(err), "error retrieving full response for %s", hostname); goto cleanup;
+ snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); goto cleanup;
}
lua_newtable(L);
cleanup:
diff --git a/src/lpm.lua b/src/lpm.lua
index b38a231..034bfdc 100644
--- a/src/lpm.lua
+++ b/src/lpm.lua
@@ -458,16 +458,22 @@ local function prompt(message)
return not response:find("%S") or response:find("^%s*[yY]%s*$")
end
+
function common.get(source, target, checksum)
if not source then error("requires url") end
local _, _, protocol, hostname, port, rest = source:find("^(https?)://([^:/?]+):?(%d*)(.*)$")
if not protocol then error("malfomed url " .. source) end
if not port or port == "" then port = protocol == "https" and 443 or 80 end
- if not checksum then return system.get(protocol, hostname, port, rest, target) end
+ if not checksum then
+ local res, headers = system.get(protocol, hostname, port, rest, target)
+ if headers.location then return common.get(headers.location, target, checksum) end
+ return res
+ end
if not system.stat(CACHEDIR .. PATHSEP .. "files") then common.mkdirp(CACHEDIR .. PATHSEP .. "files") end
local cache_path = CACHEDIR .. PATHSEP .. "files" .. PATHSEP .. checksum
if not system.stat(cache_path) then
- system.get(source, cache_path)
+ local res, headers = system.get(source, cache_path)
+ if headers.location then return common.get(headers.location, target, checksum) end
if checksum ~= "SKIP" and system.hash(cache_path, "file") ~= checksum then fatal_warning("checksum doesn't match for " .. source) end
end
common.copy(cache_path, target)