From d236f64d08d33c77ba317964f86e0547092a8493 Mon Sep 17 00:00:00 2001 From: Adam Harrison Date: Wed, 6 Mar 2024 17:40:40 -0500 Subject: Fixed a minor bug with downloading when downloads get aborted, and also ensured that download messages on small terminals don't go nuts. --- src/lpm.c | 26 ++++++++++++++++++++++++-- src/lpm.lua | 13 +++++++++---- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/lpm.c b/src/lpm.c index 17054e4..78e3b3e 100644 --- a/src/lpm.c +++ b/src/lpm.c @@ -6,6 +6,7 @@ #else #include #include + #include #include #include #include @@ -165,6 +166,25 @@ static int lpm_tcflush(lua_State* L) { return 0; } +static int lpm_tcwidth(lua_State* L) { + int stream = luaL_checkinteger(L, 1); + #ifndef _WIN32 + if (isatty(stream)) { + struct winsize ws={0}; + ioctl(stream, TIOCGWINSZ, &ws); + lua_pushinteger(L, ws.ws_col); + return 1; + } + #else + CONSOLE_SCREEN_BUFFER_INFO csbi; + int columns, rows; + if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi)) { + lua_pushinteger(L, csbi.srWindow.Right - csbi.srWindow.Left + 1); + return 1; + } + #endif + return 0; +} static int lpm_symlink(lua_State* L) { #ifndef _WIN32 @@ -872,10 +892,11 @@ static int lpm_extract(lua_State* L) { while (remaining > 0) { int read_size = remaining < sizeof(buffer) ? remaining : sizeof(buffer); - if (mtar_read_data(&tar, buffer, read_size) != MTAR_ESUCCESS) { + int err = mtar_read_data(&tar, buffer, read_size); + if (err != MTAR_ESUCCESS) { fclose(file); mtar_close(&tar); - return luaL_error(L, "can't write file %s: %s", target, strerror(errno)); + return luaL_error(L, "can't read file %s: %s", target, mtar_strerror(err)); } fwrite(buffer, sizeof(char), read_size, file); @@ -1290,6 +1311,7 @@ static const luaL_Reg system_lib[] = { { "rmdir", lpm_rmdir }, // Removes a directory. { "hash", lpm_hash }, // Returns a hex sha256 hash. { "tcflush", lpm_tcflush }, // Flushes an terminal stream. + { "tcwidth", lpm_tcwidth }, // Gets the terminal width in columns. { "symlink", lpm_symlink }, // Creates a symlink. { "chmod", lpm_chmod }, // Chmod's a file. { "init", lpm_init }, // Initializes a git repository with the specified remote. diff --git a/src/lpm.lua b/src/lpm.lua index c990407..7b291cf 100644 --- a/src/lpm.lua +++ b/src/lpm.lua @@ -579,9 +579,10 @@ function common.get(source, options) local cache_path = cache_dir .. PATHSEP .. "files" .. PATHSEP .. (checksum ~= "SKIP" and checksum or system.hash(source)) local res if not system.stat(cache_path) then - res, headers = system.get(protocol, hostname, port, rest, cache_path, callback) + res, headers = system.get(protocol, hostname, port, rest, cache_path .. ".part", callback) if headers.location then return common.get(headers.location, common.merge(options, { depth = (depth or 0) + 1 })) end - if checksum ~= "SKIP" and system.hash(cache_path, "file") ~= checksum then fatal_warning("checksum doesn't match for " .. source) end + if checksum ~= "SKIP" and system.hash(cache_path .. ".part", "file") ~= checksum then fatal_warning("checksum doesn't match for " .. source) end + common.rename(cache_path .. ".part", cache_path) end if target then common.copy(cache_path, target) else res = io.open(cache_path, "rb"):read("*all") end return res @@ -2406,9 +2407,13 @@ not commonly used publically. return end if not start_time or not last_read or total_read < last_read then start_time = system.time() end - local status_line = string.format("%s [%s/s][%03d%%]: %s", format_bytes(total_read), format_bytes(total_read / (system.time() - start_time)), math.floor((received_objects and (received_objects/total_objects_or_content_length) or (total_read/total_objects_or_content_length) or 0)*100), progress_bar_label) + local status_line = string.format("%s [%s/s][%03d%%]: ", format_bytes(total_read), format_bytes(total_read / (system.time() - start_time)), math.floor((received_objects and (received_objects/total_objects_or_content_length) or (total_read/total_objects_or_content_length) or 0)*100)) + local terminal_width = system.tcwidth(1) + if not terminal_width then terminal_width = #status_line + #progress_bar_label end + local characters_remaining = terminal_width - #status_line + local message = progress_bar_label:sub(1, characters_remaining) io.stdout:write("\r") - io.stdout:write(status_line) + io.stdout:write(status_line .. message) io.stdout:flush() last_read = total_read end -- cgit v1.2.3