Files
nvim/lua/util.lua
T
2026-04-19 00:40:24 +02:00

469 lines
12 KiB
Lua

local log = require("log")
local M = {
    -- `sysname` field of `vim.uv.os_uname()`, captured once at load time.
    os_name = vim.uv.os_uname().sysname,
}
---@alias OutputStream "stdout" | "stderr" | "in_place"

---@class ow.FormatOptions
---@field buf? integer Buffer to apply formatting to (default: current buffer)
---@field cmd string[] Command to run. The following keywords get replaced by the specified values:
--- * %file% - path to the file for `buf` (the temporary input file when `output` is `in_place`)
--- * %filename% - name of the file for `buf`
--- * %row_start% - first row of selection (1-indexed)
--- * %row_end% - last row of selection (1-indexed, inclusive)
--- * %col_start% - first byte column of selection (1-indexed)
--- * %col_end% - last byte column of selection (1-indexed, inclusive)
--- * %byte_start% - 0-based byte offset of the first selected byte
--- * %byte_end% - byte offset just past the last selected byte
--- Outside visual mode the "selection" is the whole buffer.
---@field output? OutputStream What stream to use as the result. May be one of `stdout` (default), `stderr` or `in_place`.
---@field auto_indent? boolean Perform auto indent on formatted range
---@field only_selection? boolean In visual mode, only send the selected text to `stdin`
---@field ignore_ret? boolean Ignore non-zero return codes
---@field ignore_stderr? boolean Ignore stderr output when not using stderr for output
---@field env? table<string, string> Map of environment variables

-- Visual modes (as returned by `vim.fn.mode()`) that select a region.
-- "\22" is CTRL-V (blockwise visual); it is handled like charwise "v".
local VISUAL_MODES = { v = true, V = true, ["\22"] = true }

-- Keywords substituted into `ow.FormatOptions.cmd`, in substitution order.
local FORMAT_KEYWORDS = {
    "file",
    "filename",
    "row_start",
    "row_end",
    "col_start",
    "col_end",
    "byte_start",
    "byte_end",
}

--- Return line `row` (1-indexed) of `buf`, or "" when it does not exist.
---@param buf integer
---@param row integer
---@return string
local function get_buf_line(buf, row)
    return vim.api.nvim_buf_get_lines(buf, row - 1, row, false)[1] or ""
end

--- Compute the region to format as 1-indexed, inclusive rows and byte
--- columns. Without a visual `mode` the region is the whole buffer.
---@param buf integer
---@param mode? string Visual mode character from `vim.fn.mode()`
---@return integer row_start
---@return integer col_start
---@return integer row_end
---@return integer col_end
local function get_format_region(buf, mode)
    if not mode then
        local last_row = vim.api.nvim_buf_line_count(buf)
        return 1, 1, last_row, #get_buf_line(buf, last_row)
    end
    local row_start, col_start = unpack(vim.fn.getpos("v"), 2, 3)
    local row_end, col_end = unpack(vim.fn.getpos("."), 2, 3)
    -- The cursor may sit before the other end of the selection.
    if
        row_start > row_end
        or (row_start == row_end and col_start > col_end)
    then
        row_start, row_end, col_start, col_end =
            row_end, row_start, col_end, col_start
    end
    local last_line = get_buf_line(buf, row_end)
    if mode == "V" then
        return row_start, 1, row_end, #last_line
    end
    -- getpos() reports the first byte of the last selected character; move
    -- to its last byte so a trailing multibyte character is not cut in half.
    if col_end >= 1 and col_end <= #last_line then
        col_end = col_end + vim.str_utf_end(last_line, col_end)
    end
    -- Keep the columns inside their lines (on an empty line getpos() still
    -- reports column 1) so nvim_buf_get_text() accepts them.
    col_end = math.min(col_end, #last_line)
    col_start = math.min(col_start, #get_buf_line(buf, row_start) + 1)
    return row_start, col_start, row_end, col_end
end

--- Return a copy of `cmd` with every `%keyword%` replaced by `values[keyword]`.
--- `cmd` itself is left untouched so the same options can be reused.
---@param cmd string[]
---@param values table<string, string|integer>
---@return string[]
local function expand_format_cmd(cmd, values)
    local expanded = {}
    for i, arg in ipairs(cmd) do
        for _, keyword in ipairs(FORMAT_KEYWORDS) do
            local value = tostring(values[keyword])
            -- A function replacement is inserted verbatim; a string
            -- replacement would treat `%` in e.g. a file path as a capture
            -- reference and raise an error.
            arg = arg:gsub("%%" .. keyword .. "%%", function()
                return value
            end)
        end
        expanded[i] = arg
    end
    return expanded
end

--- Read the whole content of `path`, or return nil if it cannot be opened.
---@param path string
---@return string?
local function read_file(path)
    local f = io.open(path, "r")
    if not f then
        return nil
    end
    local content = f:read("*a")
    f:close()
    return content
end

--- Convert line-based diff hunks into LSP text edits.
---@param hunks integer[][] Hunks from `vim.text.diff` with `result_type = "indices"`
---@param new_lines string[] Lines of the formatted text
---@param offset integer 0-indexed buffer row that line 1 of the diffed text maps to
---@return lsp.TextEdit[]
local function hunks_to_text_edits(hunks, new_lines, offset)
    local text_edits = {}
    for _, hunk in ipairs(hunks) do
        local old_start, old_count, new_start, new_count = unpack(hunk)
        local start_line = offset + old_start - 1
        local end_line = start_line + old_count
        if old_count == 0 then
            -- Insertion: old_start means "after line N" (where N is
            -- 1-indexed), which equals the 0-indexed position old_start.
            start_line = start_line + 1
            end_line = end_line + 1
        end
        local new_text = ""
        if new_count > 0 then
            new_text = table.concat(
                new_lines,
                "\n",
                new_start,
                new_start + new_count - 1
            ) .. "\n"
        end
        text_edits[#text_edits + 1] = {
            range = {
                start = { line = start_line, character = 0 },
                ["end"] = { line = end_line, character = 0 },
            },
            newText = new_text,
        }
    end
    return text_edits
end

--- Format a buffer, or the visual selection in it, with an external command
--- and apply the result as line-based text edits, restoring the window view
--- afterwards.
---
--- With `only_selection` in visual mode only the selected text is sent to the
--- command; its output replaces the selection while the unselected text on
--- the first and last selected lines is kept.
---@param opts ow.FormatOptions
function M.format(opts)
    local current_buf = vim.api.nvim_get_current_buf()
    local buf = opts.buf
    if buf == nil or buf == 0 then
        buf = current_buf
    end
    opts = {
        buf = buf,
        cmd = opts.cmd,
        output = opts.output or "stdout",
        auto_indent = opts.auto_indent,
        only_selection = opts.only_selection,
        ignore_ret = opts.ignore_ret,
        ignore_stderr = opts.ignore_stderr,
        env = opts.env,
    }
    local file = vim.api.nvim_buf_get_name(buf)
    local filename = vim.fn.fnamemodify(file, ":t")
    local mode = vim.fn.mode()
    -- The visual selection belongs to the current window, so it only
    -- describes a region of `buf` when `buf` is the current buffer.
    local is_visual = buf == current_buf and VISUAL_MODES[mode] == true

    -- All 1-indexed, inclusive
    local row_start, col_start, row_end, col_end =
        get_format_region(buf, is_visual and mode or nil)
    local byte_start = vim.api.nvim_buf_get_offset(buf, row_start - 1)
        + col_start
        - 1
    local byte_end = vim.api.nvim_buf_get_offset(buf, row_end - 1) + col_end

    local selection_only = is_visual and opts.only_selection
    local input
    if selection_only then
        input = vim.api.nvim_buf_get_text(
            buf,
            row_start - 1,
            col_start - 1,
            row_end - 1,
            col_end,
            {}
        )
    else
        input = vim.api.nvim_buf_get_lines(buf, 0, -1, false)
    end

    local tmp
    if opts.output == "in_place" then
        tmp = os.tmpname()
        vim.fn.writefile(input, tmp, "s")
        file = tmp
    end
    local cmd = expand_format_cmd(opts.cmd, {
        file = file,
        filename = filename,
        row_start = row_start,
        row_end = row_end,
        col_start = col_start,
        col_end = col_end,
        byte_start = byte_start,
        byte_end = byte_end,
    })

    local ok, resp = pcall(function()
        return vim.system(cmd, {
            stdin = input,
            env = opts.env,
        }):wait()
    end)
    local tmp_out
    if tmp then
        if ok then
            tmp_out = read_file(tmp)
        end
        os.remove(tmp)
    end
    if not ok then
        -- Spawning failed (e.g. executable not found); re-raise now that the
        -- temporary file has been cleaned up.
        error(resp, 0)
    end
    if tmp and tmp_out == nil then
        log.error("Failed to read formatter output from %s", tmp)
        return
    end

    local stdout = resp.stdout or ""
    local stderr = resp.stderr or ""
    if
        (not opts.ignore_ret and resp.code ~= 0)
        or (opts.output ~= "stderr" and not opts.ignore_stderr and stderr ~= "")
    then
        local msg = ""
        if stderr ~= "" then
            msg = ":\n" .. stderr
        end
        log.error("Failed to format (%d)%s", resp.code, msg)
        return
    end

    local output = ""
    if opts.output == "stdout" then
        output = stdout
    elseif opts.output == "stderr" then
        output = stderr
    elseif opts.output == "in_place" then
        output = tmp_out or ""
    end
    output = output:gsub("%s+$", "")
    output = output:gsub("\r\n", "\n")

    -- Diff against the lines that were actually sent, and map hunk line 1 to
    -- the first row of that text in the buffer.
    local old_lines = input
    local offset = 0
    if selection_only then
        -- Diff whole lines so the line-based edits keep the unselected text
        -- before and after a charwise selection.
        old_lines =
            vim.api.nvim_buf_get_lines(buf, row_start - 1, row_end, false)
        offset = row_start - 1
        output = old_lines[1]:sub(1, col_start - 1)
            .. output
            .. old_lines[#old_lines]:sub(col_end + 1)
    end
    local new_lines = vim.split(output, "\n", { plain = true })

    local hunks = vim.text.diff(
        table.concat(old_lines, "\n"),
        table.concat(new_lines, "\n"),
        { result_type = "indices", algorithm = "histogram" }
    )
    if not hunks or #hunks == 0 then
        return
    end
    ---@diagnostic disable-next-line: param-type-mismatch
    local text_edits = hunks_to_text_edits(hunks, new_lines, offset)

    local view = vim.fn.winsaveview()
    vim.lsp.util.apply_text_edits(text_edits, buf, "utf-16")
    if opts.auto_indent then
        -- Run `==` inside `buf` even when it is not the current buffer.
        vim.api.nvim_buf_call(buf, function()
            vim.api.nvim_cmd({
                cmd = "normal",
                args = { "==" },
                bang = true,
                range = {
                    row_start,
                    math.min(row_end, vim.api.nvim_buf_line_count(buf)),
                },
            }, { output = false })
        end)
    end
    vim.fn.winrestview(view)
end
--- Check if `val` is a table whose keys all have type `kt` and whose values
--- all have type `vt`. A `nil` type accepts keys/values of any type, and an
--- empty table is always accepted.
---@param val any
---@param kt? type Required Lua type of every key
---@param vt? type Required Lua type of every value
---@return boolean
function M.is_map(val, kt, vt)
    if type(val) ~= "table" then
        return false
    end
    for k, v in pairs(val) do
        if (kt and type(k) ~= kt) or (vt and type(v) ~= vt) then
            return false
        end
    end
    return true
end
--- Return true when `val` is a list according to `vim.islist` and, if `t`
--- is given, every element has Lua type `t`.
---@param val any
---@param t? type
---@return boolean
function M.is_list(val, t)
    if not vim.islist(val) then
        return false
    end
    -- vim.islist guarantees the keys are exactly 1..n, so ipairs sees every
    -- element and no separate key-type check is needed.
    if not t then
        return true
    end
    for _, item in ipairs(val) do
        if type(item) ~= t then
            return false
        end
    end
    return true
end
--- Return true when `val` is nil; otherwise defer to `M.is_list(val, t)`.
---@param val? any
---@param t? type
---@return boolean
function M.is_list_or_nil(val, t)
    return val == nil or M.is_list(val, t)
end
---@class ow.Util.Debouncer
---@field package _fn fun(...)
---@field package _delay integer
---@field package _timer uv.uv_timer_t
---@field package _gen integer
---@field package _fired_gen integer
---@field package _args? table
---@field package _cb_main fun()
---@field package _cb_uv fun()
local Debouncer = {}
Debouncer.__index = Debouncer

--- Create a trailing-edge debouncer. Each call (re)arms a one-shot timer;
--- when it expires, `fn` runs on the main loop with the arguments of the
--- most recent call.
---@param fn fun(...)
---@param delay integer Timer timeout passed to `uv_timer_t:start`
---@return ow.Util.Debouncer
function Debouncer.new(fn, delay)
    local self = setmetatable({}, Debouncer)
    self._fn = fn
    self._delay = delay
    self._timer = assert(vim.uv.new_timer())
    -- `_gen` counts arms; `_fired_gen` records the arm the timer fired for.
    self._gen = 0
    self._fired_gen = 0
    self._args = nil

    -- Scheduled onto the main loop after the timer fires. In between, the
    -- debouncer may have been re-armed (generation moved on) or
    -- cancelled/flushed (arguments cleared); skip the call in both cases.
    self._cb_main = vim.schedule_wrap(function()
        local superseded = self._fired_gen ~= self._gen
        if superseded or self._args == nil then
            return
        end
        self:_run_pending()
    end)

    -- libuv timer callback: note which generation fired, then hop to the
    -- main loop.
    self._cb_uv = function()
        self._fired_gen = self._gen
        self._cb_main()
    end

    return self
end

--- Take the pending arguments and call the wrapped function with them.
--- The arguments are cleared before the call, so `_fn` may re-arm the
--- debouncer without its new arguments being discarded.
---@package
function Debouncer:_run_pending()
    local pending_args = self._args
    self._args = nil
    self._fn(vim.F.unpack_len(pending_args))
end

--- Store the arguments and restart the timer, replacing any earlier
--- pending call.
function Debouncer:__call(...)
    self._args = vim.F.pack_len(...)
    self._gen = self._gen + 1
    self._timer:start(self._delay, 0, self._cb_uv)
end

--- Drop the pending call, if any, without running it.
function Debouncer:cancel()
    self._timer:stop()
    self._args = nil
end

--- Run the pending call right now, synchronously, if there is one.
function Debouncer:flush()
    if not self:pending() then
        return
    end
    self._timer:stop()
    self:_run_pending()
end

--- Whether a call is waiting for the timer.
---@return boolean
function Debouncer:pending()
    return self._args ~= nil
end

--- Stop and close the timer handle and drop any pending call.
--- Calling the debouncer again afterwards is not supported, since the
--- timer handle has been closed.
function Debouncer:close()
    local timer = self._timer
    timer:stop()
    if not timer:is_closing() then
        timer:close()
    end
    self._args = nil
end
--- Wrap `fn` in a trailing-edge debouncer. The result is callable like `fn`
--- and also offers `cancel`, `flush`, `pending` and `close`.
---@generic F: fun(...)
---@param fn F
---@param delay integer
---@return F | ow.Util.Debouncer
function M.debounce(fn, delay)
    local debouncer = Debouncer.new(fn, delay)
    return debouncer
end
---@class ow.Util.KeyedDebouncer<T>
---@field package _fn fun(key: T, ...)
---@field package _delay integer
---@field package _slots table<T, ow.Util.Debouncer>
local KeyedDebouncer = {}
KeyedDebouncer.__index = KeyedDebouncer

--- Create a debouncer that keeps an independent `Debouncer` per key, so
--- calls for different keys do not supersede each other.
---@generic T
---@param fn fun(key: T, ...)
---@param delay integer
---@return ow.Util.KeyedDebouncer<T>
function KeyedDebouncer.new(fn, delay)
    local self = { _fn = fn, _delay = delay, _slots = {} }
    return setmetatable(self, KeyedDebouncer)
end

--- Debounce a call for `key`; `fn(key, ...)` runs once the key's timer
--- expires. The per-key debouncer is created lazily on first use.
---@generic T
---@param self ow.Util.KeyedDebouncer<T>
---@param key T
function KeyedDebouncer:__call(key, ...)
    local slots = self._slots
    local slot = slots[key]
    if slot == nil then
        slot = Debouncer.new(function(...)
            self._fn(key, ...)
        end, self._delay)
        slots[key] = slot
    end
    slot(...)
end

--- Drop the pending call for `key` and release its debouncer.
---@generic T
---@param self ow.Util.KeyedDebouncer<T>
---@param key T
function KeyedDebouncer:cancel(key)
    local slot = self._slots[key]
    if slot == nil then
        return
    end
    slot:close()
    self._slots[key] = nil
end

--- Run the pending call for `key` immediately, if there is one.
---@generic T
---@param self ow.Util.KeyedDebouncer<T>
---@param key T
function KeyedDebouncer:flush(key)
    local slot = self._slots[key]
    if slot == nil then
        return
    end
    slot:flush()
end

--- Whether a call for `key` is waiting for its timer.
---@generic T
---@param self ow.Util.KeyedDebouncer<T>
---@param key T
---@return boolean
function KeyedDebouncer:pending(key)
    local slot = self._slots[key]
    if slot == nil then
        return false
    end
    return slot:pending()
end

--- Close every per-key debouncer and forget all keys.
function KeyedDebouncer:close()
    for key in pairs(self._slots) do
        self._slots[key]:close()
    end
    self._slots = {}
end
--- Per-key variant of `M.debounce`: calls are debounced separately for each
--- distinct first argument (`key`), and `fn` receives that key first.
---@diagnostic disable-next-line: undefined-doc-name
---@generic T, F: fun(key: T, ...)
---@param fn F
---@param delay integer
---@return F | ow.Util.KeyedDebouncer<T>
function M.keyed_debounce(fn, delay)
    local keyed = KeyedDebouncer.new(fn, delay)
    return keyed
end
--- Follow the links of highlight group `name` and return the definition (as
--- returned by `nvim_get_hl`) of the last group in the chain.
--- Link cycles (e.g. A -> B -> A) are cut at the first group that would be
--- revisited, instead of looping forever; the returned table then still has
--- a `link` field.
---@param name string
---@return table
function M.get_hl_source(name)
    local hl = vim.api.nvim_get_hl(0, { name = name })
    local seen = { [name] = true }
    while hl.link and not seen[hl.link] do
        seen[hl.link] = true
        hl = vim.api.nvim_get_hl(0, { name = hl.link })
    end
    return hl
end
-- Export the module table.
return M