From 2abd1d653d54aff85cdc0d1034175df65e51485f Mon Sep 17 00:00:00 2001 From: Oscar Wallberg Date: Thu, 7 May 2026 15:28:26 +0200 Subject: [PATCH] feat(git/cmd): improved completion for :G --- lua/git/cmd.lua | 342 ++++++++++++++++++++++++++++++++++++++---- lua/git/repo.lua | 41 ++++- test/git/cmd_test.lua | 32 ++++ 3 files changed, 381 insertions(+), 34 deletions(-) diff --git a/lua/git/cmd.lua b/lua/git/cmd.lua index 0bab41d..f8b9555 100644 --- a/lua/git/cmd.lua +++ b/lua/git/cmd.lua @@ -284,6 +284,158 @@ function M.run(args) end end +---@param items string[] +---@param lead string +---@return string[] +local function prefix_filter(items, lead) + return vim.tbl_filter(function(it) + return vim.startswith(it, lead) + end, items) +end + +---@param prefix string +---@param dir string +---@param name_lead string +---@param entries string[] +---@return string[] +local function path_segments(prefix, dir, name_lead, entries) + local matches = {} + local seen = {} + for _, full_path in ipairs(entries) do + local rel = dir == "" and full_path or full_path:sub(#dir + 1) + local slash = rel:find("/", 1, true) + local segment = slash and rel:sub(1, slash) or rel + if not seen[segment] and segment:sub(1, #name_lead) == name_lead then + seen[segment] = true + table.insert(matches, prefix .. dir .. segment) + end + end + return matches +end + +---@param r ow.Git.Repo +---@param dir string +---@return string[] +local function list_files(r, dir) + local cmd = { "git", "ls-files" } + if dir ~= "" then + table.insert(cmd, dir) + end + local out = util.exec(cmd, { cwd = r.worktree, silent = true }) + return out and util.split_lines(out) or {} +end + +---@param r ow.Git.Repo +---@return string[] +local function list_remotes(r) + local out = util.exec( + { "git", "remote" }, + { cwd = r.worktree, silent = true } + ) + return out and util.split_lines(out) or {} +end + +---@type table +local SUBSUB_FALLBACK = { + submodule = { + "add", + "status", + "init", + "deinit", + "update", + "summary", + "foreach", + "sync", + "absorbgitdirs", + }, +} + +---@type table +local cached_completions = {} + +---@param sub string +---@return string[] +local function fetch_completions(sub) + if cached_completions[sub] then + return cached_completions[sub] + end + local out = util.exec( + { "git", sub, "--git-completion-helper-all" }, + { silent = true } + ) or util.exec( + { "git", sub, "--git-completion-helper" }, + { silent = true } + ) + local items = {} + if out then + for tok in out:gmatch("%S+") do + table.insert(items, tok) + end + end + cached_completions[sub] = items + return items +end + +---@param sub string +---@return string[] +local function fetch_subsubcommands(sub) + local subs = {} + for _, it in ipairs(fetch_completions(sub)) do + if it:sub(1, 1) ~= "-" and it ~= "--" then + table.insert(subs, it) + end + end + if #subs == 0 and SUBSUB_FALLBACK[sub] then + return SUBSUB_FALLBACK[sub] + end + return subs +end + +---@param sub string +---@return string[] +local function fetch_flags(sub) + local flags = {} + for _, it in ipairs(fetch_completions(sub)) do + if it:sub(1, 1) == "-" and it ~= "--" then + table.insert(flags, it) + end + end + return flags +end + +---@param r ow.Git.Repo +---@param lead string +---@return string[] +local function complete_tracked_paths(r, lead) + local dir, name_lead = lead:match("^(.*/)([^/]*)$") + dir = dir or "" + name_lead = name_lead or lead + return path_segments("", dir, name_lead, list_files(r, dir)) +end + +---@param r ow.Git.Repo +---@param lead string +---@return string[] +local function complete_unstaged_paths(r, lead) + local matches = {} + for path, entry_list in pairs(r.status.entries) do + if path:sub(1, #lead) == lead then + for _, e in ipairs(entry_list) do + if + e.kind == "unstaged" + or e.kind == "untracked" + or e.kind == "unmerged" + then + table.insert(matches, path) + break + end + end + end + end + table.sort(matches) + return matches +end + ---@param arg_lead string ---@return string[] function M.complete_rev(arg_lead) @@ -317,13 +469,10 @@ function M.complete_rev(arg_lead) local colon = arg_lead:find(":", 1, true) if not colon then - local matches = {} - for _, ref in ipairs(r:list_refs()) do - if ref:sub(1, #arg_lead) == arg_lead then - table.insert(matches, ref) - end - end - return matches + local refs = r:list_refs() + vim.list_extend(refs, r:list_pseudo_refs()) + vim.list_extend(refs, r:list_stash_refs()) + return prefix_filter(refs, arg_lead) end local rev = arg_lead:sub(1, colon - 1) @@ -358,44 +507,173 @@ function M.complete_rev(arg_lead) return matches end - local cmd = { "git", "ls-files" } - if dir ~= "" then - table.insert(cmd, dir) - end - local out = util.exec(cmd, { cwd = r.worktree, silent = true }) - if not out then + return path_segments(":", dir, name_lead, list_files(r, dir)) +end + +---@alias ow.Git.Cmd.Handler fun(r: ow.Git.Repo, lead: string, sub: string, idx: integer): string[] +---@alias ow.Git.Cmd.Slot ow.Git.Cmd.Handler | ow.Git.Cmd.Handler[] + +---@param r ow.Git.Repo +---@param lead string +---@return string[] +local function complete_remote(r, lead) + return prefix_filter(list_remotes(r), lead) +end + +---@param r ow.Git.Repo +---@param lead string +---@return string[] +local function complete_ref(r, lead) + return prefix_filter(r:list_refs(), lead) +end + +---@param r ow.Git.Repo +---@param lead string +---@return string[] +local function complete_pseudo_ref(r, lead) + return prefix_filter(r:list_pseudo_refs(), lead) +end + +---@param r ow.Git.Repo +---@param lead string +---@return string[] +local function complete_stash_ref(r, lead) + return prefix_filter(r:list_stash_refs(), lead) +end + +---@param _ ow.Git.Repo +---@param lead string +---@return string[] +local function complete_rev(_, lead) + return M.complete_rev(lead) +end + +---@param _ ow.Git.Repo +---@param lead string +---@param sub string +---@param idx integer +---@return string[] +local function complete_subsubcmd(_, lead, sub, idx) + if idx ~= 1 then return {} end - local matches = {} - local seen = {} - for _, full_path in ipairs(util.split_lines(out)) do - local rel = dir == "" and full_path or full_path:sub(#dir + 1) - local slash = rel:find("/", 1, true) - local segment = slash and rel:sub(1, slash) or rel - if not seen[segment] and segment:sub(1, #name_lead) == name_lead then - seen[segment] = true - table.insert(matches, ":" .. dir .. segment) + return prefix_filter(fetch_subsubcommands(sub), lead) +end + +local ALL_REFS = { complete_ref, complete_pseudo_ref, complete_stash_ref } +local REV_OR_PATH = { complete_rev, complete_tracked_paths } + +---@type table +local POSITIONAL_HANDLER = { + push = { complete_remote, ALL_REFS }, + pull = { complete_remote, ALL_REFS }, + fetch = { complete_remote, ALL_REFS }, + checkout = { REV_OR_PATH }, + reset = { REV_OR_PATH }, + restore = { complete_tracked_paths }, + add = { complete_unstaged_paths }, + rm = { complete_tracked_paths }, + mv = { complete_tracked_paths }, + blame = { complete_tracked_paths }, + branch = { complete_ref }, + switch = { complete_ref }, + merge = { ALL_REFS }, + rebase = { ALL_REFS }, + ["cherry-pick"] = { ALL_REFS }, + revert = { ALL_REFS }, + tag = { ALL_REFS }, + log = { REV_OR_PATH }, + diff = { REV_OR_PATH }, + show = { complete_rev }, + ["cat-file"] = { complete_rev }, + stash = { complete_subsubcmd }, + remote = { complete_subsubcmd }, + worktree = { complete_subsubcmd }, + bisect = { complete_subsubcmd }, + submodule = { complete_subsubcmd }, +} + +---@class ow.Git.Cmd.CompleteState +---@field prior string[] -- positional and flag tokens before the current arg_lead +---@field after_separator boolean -- whether `--` appeared in prior + +---@param cmd_line string +---@return ow.Git.Cmd.CompleteState +local function parse_complete_state(cmd_line) + local rest = cmd_line:gsub("^%s*%S+%s*", "", 1) + local trailing_space = rest == "" or rest:sub(-1):match("%s") ~= nil + local tokens = vim.split(vim.trim(rest), "%s+", { trimempty = true }) + local prior = trailing_space and tokens + or vim.list_slice(tokens, 1, #tokens - 1) + local after_separator = false + for _, t in ipairs(prior) do + if t == "--" then + after_separator = true + break end end - return matches + return { prior = prior, after_separator = after_separator } +end + +---@param prior string[] -- includes the subcommand at index 1 +---@return integer +local function positional_index(prior) + local pos = 0 + for i = 2, #prior do + if prior[i]:sub(1, 1) ~= "-" then + pos = pos + 1 + end + end + return pos + 1 end ---@param arg_lead string ---@param cmd_line string ---@return string[] function M.complete(arg_lead, cmd_line, _) - local rest = cmd_line:gsub("^%s*%S+%s*", "", 1) - local words = vim.split(rest, "%s+", { trimempty = false }) - if #words > 1 then + local state = parse_complete_state(cmd_line) + local prior = state.prior + + if #prior == 0 then + return prefix_filter(git_cmds(), arg_lead) + end + + local sub = prior[1] --[[@as string]] + + if arg_lead:sub(1, 1) == "-" then + return prefix_filter(fetch_flags(sub), arg_lead) + end + + local r = repo.resolve() + if not r then return {} end - local matches = {} - for _, c in ipairs(git_cmds()) do - if c:sub(1, #arg_lead) == arg_lead then - table.insert(matches, c) - end + + if state.after_separator then + return complete_tracked_paths(r, arg_lead) end - return matches + + local handlers = POSITIONAL_HANDLER[sub] + if not handlers then + return complete_tracked_paths(r, arg_lead) + end + + local idx = positional_index(prior) + local slot = handlers[idx] or handlers[#handlers] + if not slot then + return {} + end + if type(slot) == "function" then + return slot(r, arg_lead, sub, idx) + end + local result = {} + for _, fn in ipairs(slot) do + vim.list_extend(result, fn(r, arg_lead, sub, idx)) + end + return result end +M._parse_complete_state = parse_complete_state +M._positional_index = positional_index + return M diff --git a/lua/git/repo.lua b/lua/git/repo.lua index 5080b57..e0a4a1d 100644 --- a/lua/git/repo.lua +++ b/lua/git/repo.lua @@ -179,8 +179,45 @@ function Repo:list_refs() if not out then return {} end - local refs = util.split_lines(out) - table.insert(refs, 1, "HEAD") + return util.split_lines(out) +end + +local PSEUDO_REFS = { + "HEAD", + "FETCH_HEAD", + "ORIG_HEAD", + "MERGE_HEAD", + "REBASE_HEAD", + "CHERRY_PICK_HEAD", + "REVERT_HEAD", +} + +---@return string[] +function Repo:list_pseudo_refs() + local refs = {} + for _, name in ipairs(PSEUDO_REFS) do + if name == "HEAD" or vim.uv.fs_stat(self.gitdir .. "/" .. name) then + table.insert(refs, name) + end + end + return refs +end + +---@return string[] +function Repo:list_stash_refs() + if not vim.uv.fs_stat(self.gitdir .. "/refs/stash") then + return {} + end + local refs = { "stash" } + local out = util.exec( + { "git", "stash", "list", "--pretty=format:%gd" }, + { cwd = self.worktree, silent = true } + ) + if out then + for _, entry in ipairs(util.split_lines(out)) do + table.insert(refs, entry) + end + end return refs end diff --git a/test/git/cmd_test.lua b/test/git/cmd_test.lua index 069389e..5c5cbd7 100644 --- a/test/git/cmd_test.lua +++ b/test/git/cmd_test.lua @@ -97,3 +97,35 @@ end) t.test("parse_args expands leading ~/ to home", function() t.eq(cmd.parse_args("add ~/foo"), { "add", vim.fn.expand("~/foo") }) end) + +t.test("parse_complete_state with trailing space", function() + local s = cmd._parse_complete_state("G push origin ") + t.eq(s.prior, { "push", "origin" }) + t.falsy(s.after_separator) +end) + +t.test("parse_complete_state mid-token", function() + local s = cmd._parse_complete_state("G push or") + t.eq(s.prior, { "push" }) + t.falsy(s.after_separator) +end) + +t.test("parse_complete_state empty after command", function() + local s = cmd._parse_complete_state("G ") + t.eq(s.prior, {}) + t.falsy(s.after_separator) +end) + +t.test("parse_complete_state detects -- separator", function() + local s = cmd._parse_complete_state("G log -- foo") + t.eq(s.prior, { "log", "--" }) + t.truthy(s.after_separator) +end) + +t.test("positional_index ignores flags", function() + t.eq(cmd._positional_index({ "push" }), 1) + t.eq(cmd._positional_index({ "push", "origin" }), 2) + t.eq(cmd._positional_index({ "push", "--force" }), 1) + t.eq(cmd._positional_index({ "push", "--force", "origin" }), 2) + t.eq(cmd._positional_index({ "checkout", "-b", "feature" }), 2) +end)