aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/lpm.c43
-rw-r--r--src/lpm.lua61
2 files changed, 88 insertions, 16 deletions
diff --git a/src/lpm.c b/src/lpm.c
index 21012ca..0309496 100644
--- a/src/lpm.c
+++ b/src/lpm.c
@@ -363,6 +363,22 @@ static int lpm_git_transport_certificate_check_cb(struct git_cert *cert, int val
return 0; // If no_verify_ssl is enabled, basically always return 0 when this is set as callback.
}
+static int lpm_git_transfer_progress_cb(const git_transfer_progress *stats, void *payload) {
+ lua_State* L = payload;
+ lua_pushvalue(L, 2);
+ lua_pushinteger(L, stats->received_bytes);
+ lua_pushinteger(L, stats->total_objects);
+ lua_pushinteger(L, stats->indexed_objects);
+ lua_pushinteger(L, stats->received_objects);
+ lua_pushinteger(L, stats->local_objects);
+ lua_pushinteger(L, stats->total_deltas);
+ lua_pushinteger(L, stats->indexed_deltas);
+ lua_call(L, 7, 1);
+ int value = lua_tointeger(L, -1);
+ lua_pop(L, 1);
+ return value;
+}
+
static int lpm_fetch(lua_State* L) {
git_repository* repository = luaL_checkgitrepo(L, 1);
git_remote* remote;
@@ -372,8 +388,11 @@ static int lpm_fetch(lua_State* L) {
}
git_fetch_options fetch_opts = GIT_FETCH_OPTIONS_INIT;
fetch_opts.download_tags = GIT_REMOTE_DOWNLOAD_TAGS_ALL;
+ fetch_opts.callbacks.payload = L;
if (no_verify_ssl)
fetch_opts.callbacks.certificate_check = lpm_git_transport_certificate_check_cb;
+ if (lua_type(L, 2) == LUA_TFUNCTION)
+ fetch_opts.callbacks.transfer_progress = lpm_git_transfer_progress_cb;
if (git_remote_fetch(remote, NULL, &fetch_opts, NULL)) {
git_remote_free(remote);
git_repository_free(repository);
@@ -381,6 +400,11 @@ static int lpm_fetch(lua_State* L) {
}
git_remote_free(remote);
git_repository_free(repository);
+ if (lua_type(L, 2) == LUA_TFUNCTION) {
+ lua_pushvalue(L, 2);
+ lua_pushboolean(L, 1);
+ lua_call(L, 1, 0);
+ }
return 0;
}
@@ -778,8 +802,10 @@ static int lpm_get(lua_State* L) {
if (content_length_value)
content_length = atoi(content_length_value);
const char* path = luaL_optstring(L, 5, NULL);
+ int callback_function = lua_type(L, 6) == LUA_TFUNCTION ? 6 : 0;
int body_length = buffer_length - (header_end - buffer);
+ int total_downloaded = body_length;
int remaining = content_length - body_length;
if (path) {
FILE* file = fopen(path, "wb");
@@ -790,8 +816,14 @@ static int lpm_get(lua_State* L) {
if (length < 0) {
snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s (%d)", hostname, rest, strerror(errno), length); goto cleanup;
}
+ if (callback_function) {
+ lua_pushvalue(L, callback_function);
+ lua_pushinteger(L, total_downloaded);
+ lua_call(L, 1, 0);
+ }
fwrite(buffer, sizeof(char), length, file);
remaining -= length;
+ total_downloaded += length;
}
fclose(file);
lua_pushnil(L);
@@ -805,14 +837,25 @@ static int lpm_get(lua_State* L) {
if (length < 0) {
snprintf(err, sizeof(err), "error retrieving full response for %s%s: %s (%d)", hostname, rest, strerror(errno), length); goto cleanup;
}
+ if (callback_function) {
+ lua_pushvalue(L, callback_function);
+ lua_pushinteger(L, total_downloaded);
+ lua_call(L, 1, 0);
+ }
luaL_addlstring(&B, buffer, length);
remaining -= length;
+ total_downloaded += length;
}
luaL_pushresult(&B);
}
if (content_length != -1 && remaining != 0) {
snprintf(err, sizeof(err), "error retrieving full response for %s%s", hostname, rest); goto cleanup;
}
+ if (callback_function) {
+ lua_pushvalue(L, callback_function);
+ lua_pushboolean(L, 1);
+ lua_call(L, 1, 0);
+ }
lua_newtable(L);
cleanup:
if (ssl_ctx)
diff --git a/src/lpm.lua b/src/lpm.lua
index 03e0b76..70299eb 100644
--- a/src/lpm.lua
+++ b/src/lpm.lua
@@ -451,7 +451,7 @@ function common.chdir(dir, callback)
if not status then error(err) end
end
-local HOME, USERDIR, CACHEDIR, JSON, VERBOSE, MOD_VERSION, QUIET, FORCE, AUTO_PULL_REMOTES, ARCH, ASSUME_YES, NO_INSTALL_OPTIONAL, TMPDIR, DATADIR, BINARY, POST, repositories, lite_xls, system_bottle
+local HOME, USERDIR, CACHEDIR, JSON, VERBOSE, MOD_VERSION, QUIET, FORCE, AUTO_PULL_REMOTES, ARCH, ASSUME_YES, NO_INSTALL_OPTIONAL, TMPDIR, DATADIR, BINARY, POST, repositories, lite_xls, system_bottle, progress_bar_label, write_progress_bar
local Plugin, Repository, LiteXL, Bottle = {}, {}, {}, {}
@@ -467,6 +467,13 @@ end
local function fatal_warning(message)
if not FORCE then error(message .. "; use --force to override") else log_warning(message) end
end
+local function log_progress_action(message)
+ if write_progress_bar then
+ progress_bar_label = message
+ else
+ log_action(message)
+ end
+end
local function prompt(message)
io.stderr:write(message .. " [Y/n]: ")
if ASSUME_YES then io.stderr:write("Y\n") return true end
@@ -475,22 +482,23 @@ local function prompt(message)
end
-function common.get(source, target, checksum, depth)
+function common.get(source, target, checksum, callback, depth)
if not source then error("requires url") end
if (depth or 0) > 10 then error("too many redirects") end
local _, _, protocol, hostname, port, rest = source:find("^(https?)://([^:/?]+):?(%d*)(.*)$")
+ log_progress_action("Downloading " .. source .. "...")
if not protocol then error("malfomed url " .. source) end
if not port or port == "" then port = protocol == "https" and 443 or 80 end
if not checksum then
- local res, headers = system.get(protocol, hostname, port, rest, target)
- if headers.location then return common.get(headers.location, target, checksum, (depth or 0) + 1) end
+ local res, headers = system.get(protocol, hostname, port, rest, target, callback)
+ if headers.location then return common.get(headers.location, target, checksum, callback, (depth or 0) + 1) end
return res
end
if not system.stat(CACHEDIR .. PATHSEP .. "files") then common.mkdirp(CACHEDIR .. PATHSEP .. "files") end
local cache_path = CACHEDIR .. PATHSEP .. "files" .. PATHSEP .. checksum
if not system.stat(cache_path) then
- local res, headers = system.get(protocol, hostname, port, rest, cache_path)
- if headers.location then return common.get(headers.location, target, checksum, (depth or 0) + 1) end
+ local res, headers = system.get(protocol, hostname, port, rest, cache_path, callback)
+ if headers.location then return common.get(headers.location, target, checksum, callback, (depth or 0) + 1) end
if checksum ~= "SKIP" and system.hash(cache_path, "file") ~= checksum then fatal_warning("checksum doesn't match for " .. source) end
end
common.copy(cache_path, target)
@@ -662,7 +670,7 @@ function Plugin:install(bottle, installing)
if self.url then
log_action("Downloading file " .. self.url .. "...")
local path = temporary_install_path .. (self.organization == 'complex' and self.path and system.stat(self.local_path).type ~= "dir" and (PATHSEP .. "init.lua") or "")
- common.get(self.url, path, self.checksum)
+ common.get(self.url, path, self.checksum, write_progress_bar)
log_action("Downloaded file " .. self.url .. " to " .. path)
if system.hash(path, "file") ~= self.checksum then fatal_warning("checksum doesn't match for " .. path) end
elseif self.remote then
@@ -686,7 +694,7 @@ function Plugin:install(bottle, installing)
local path = install_path .. PATHSEP .. (file.path or common.basename(file.url))
local temporary_path = temporary_install_path .. PATHSEP .. (file.path or common.basename(file.url))
log_action("Downloading file " .. file.url .. "...")
- common.get(file.url, temporary_path, file.checksum)
+ common.get(file.url, temporary_path, file.checksum, write_progress_bar)
log_action("Downloaded file " .. file.url .. " to " .. path)
if file.arch then system.chmod(temporary_path, 448) end -- chmod any ARCH tagged file to rwx-------
end
@@ -829,7 +837,7 @@ function Repository:generate_manifest()
if path:find("^http") then
if path:find("%.lua") then
plugin_map[name].url = path
- local file = common.get(path)
+ local file = common.get(path, nil, nil, write_progress_bar)
plugin_map[name].checksum = system.hash(file)
else
plugin_map[name].remote = path
@@ -869,9 +877,9 @@ function Repository:add(pull_remotes)
if not self.branch and not self.commit then
local path = self.local_path .. PATHSEP .. "master"
common.mkdirp(path)
- log_action("Retrieving " .. self.remote .. ":master/main...")
+ log_progress_action("Fetching " .. self.remote .. ":master/main...")
system.init(path, self.remote)
- system.fetch(path)
+ system.fetch(path, write_progress_bar)
if not pcall(system.reset, path, "refs/remotes/origin/master", "hard") then
if pcall(system.reset, path, "refs/remotes/origin/main", "hard") then
common.rename(path, self.local_path .. PATHSEP .. "main")
@@ -882,15 +890,13 @@ function Repository:add(pull_remotes)
else
self.branch = "master"
end
- log_action("Retrieved " .. self.remote .. ":master/main.")
else
local path = self.local_path .. PATHSEP .. (self.commit or self.branch)
common.mkdirp(path)
- log_action("Retrieving " .. self.remote .. ":" .. (self.commit or self.branch) .. "...")
+ log_progress_action("Fetching " .. self.remote .. ":" .. (self.commit or self.branch) .. "...")
system.init(path, self.remote)
- system.fetch(path)
+ system.fetch(path, write_progress_bar)
common.reset(path, self.commit or self.branch, "hard")
- log_action("Retrieved " .. self:url() .. "...")
self.manifest = nil
end
local manifest, remotes = self:parse_manifest()
@@ -980,7 +986,7 @@ function LiteXL:install()
local archive = basename:find("%.zip$") or basename:find("%.tar%.gz$")
local path = self.local_path .. PATHSEP .. (archive and basename or "lite-xl")
log_action("Downloading file " .. file.url .. "...")
- common.get(file.url, path, file.checksum)
+ common.get(file.url, path, file.checksum, write_progress_bar)
log_action("Downloaded file " .. file.url .. " to " .. path)
if file.checksum ~= "SKIP" and system.hash(path, "file") ~= file.checksum then fatal_warning("checksum doesn't match for " .. path) end
if archive then
@@ -1656,6 +1662,29 @@ in any circumstance unless explicitly supplied.
TMPDIR = common.normalize_path(ARGS["tmpdir"]) or CACHEDIR .. PATHSEP .. "tmp"
if ARGS["trace"] then system.trace(true) end
+ if not QUIET then
+ local start_time, last_read
+ local function format_bytes(bytes)
+ if bytes < 1024 then return string.format("%6d B", math.floor(bytes)) end
+ if bytes < 1*1024*1024 then return string.format("%6.1f kB", bytes / 1024) end
+ if bytes < 1*1024*1024*1024 then return string.format("%6.1f MB", bytes / (1024*1024)) end
+ return string.format("%6.2f GB", bytes / (1024*1024*1024))
+ end
+ write_progress_bar = function(total_read, total_objects, indexed_obejcts, received_objects, local_objects, local_deltas, indexed_deltas)
+ if type(total_read) == "boolean" then
+ io.stdout:write("\n")
+ io.stdout:flush()
+ return
+ end
+ if not start_time or total_read < last_read then start_time = os.time() end
+ local status_line = string.format("%s [%s/s]: %s", format_bytes(total_read), format_bytes(total_read / math.max(os.time() - start_time, 1)), progress_bar_label)
+ io.stdout:write(string.rep("\b", #status_line))
+ io.stdout:write(status_line)
+ io.stdout:flush()
+ last_read = total_read
+ end
+ end
+
repositories = {}
if ARGS[2] == "purge" then return lpm_purge() end
local ssl_certs = ARGS["ssl-certs"] or os.getenv("SSL_CERT_DIR") or os.getenv("SSL_CERT_FILE")