diff --git a/lua/nvim-treesitter-textobjects/_range.lua b/lua/nvim-treesitter-textobjects/_range.lua new file mode 100644 index 00000000..b7024dc3 --- /dev/null +++ b/lua/nvim-treesitter-textobjects/_range.lua @@ -0,0 +1,207 @@ +--- Copy of vim.treesitter._range +--- TODO: replace with `vim.Range` when we drop support for 0.11 + +---@diagnostic disable: duplicate-doc-field +---@diagnostic disable: duplicate-doc-alias + +local api = vim.api + +local M = {} + +---@class Range2 +---@inlinedoc +---@field [1] integer start row +---@field [2] integer end row + +---@class Range4 +---@inlinedoc +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer end row +---@field [4] integer end column + +---@class Range6 +---@inlinedoc +---@field [1] integer start row +---@field [2] integer start column +---@field [3] integer start bytes +---@field [4] integer end row +---@field [5] integer end column +---@field [6] integer end bytes + +---@alias Range Range2|Range4|Range6 + +---@param a_row integer +---@param a_col integer +---@param b_row integer +---@param b_col integer +---@return integer +--- 1: a > b +--- 0: a == b +--- -1: a < b +local function cmp_pos(a_row, a_col, b_row, b_col) + if a_row == b_row then + if a_col > b_col then + return 1 + elseif a_col < b_col then + return -1 + else + return 0 + end + elseif a_row > b_row then + return 1 + end + + return -1 +end + +M.cmp_pos = { + lt = function(...) + return cmp_pos(...) == -1 + end, + le = function(...) + return cmp_pos(...) ~= 1 + end, + gt = function(...) + return cmp_pos(...) == 1 + end, + ge = function(...) + return cmp_pos(...) ~= -1 + end, + eq = function(...) + return cmp_pos(...) == 0 + end, + ne = function(...) + return cmp_pos(...) ~= 0 + end, +} + +setmetatable(M.cmp_pos, { __call = cmp_pos }) + +---Check if a variable is a valid range object +---@param r any +---@return boolean +function M.validate(r) + if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then + return false + end + + for _, e in + ipairs(r --[[@as any[] ]]) + do + if type(e) ~= 'number' then + return false + end + end + + return true +end + +---@param r1 Range +---@param r2 Range +---@return boolean +function M.intercepts(r1, r2) + local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) + local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2) + + -- r1 is above r2 + if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then + return false + end + + -- r1 is below r2 + if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then + return false + end + + return true +end + +---@param r1 Range6 +---@param r2 Range6 +---@return Range6? +function M.intersection(r1, r2) + if not M.intercepts(r1, r2) then + return nil + end + local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1 + local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1 + return { rs[1], rs[2], rs[3], re[4], re[5], re[6] } +end + +---@param r Range +---@return integer, integer, integer, integer +function M.unpack4(r) + if #r == 2 then + return r[1], 0, r[2], 0 + end + local off_1 = #r == 6 and 1 or 0 + return r[1], r[2], r[3 + off_1], r[4 + off_1] +end + +---@param r Range6 +---@return integer, integer, integer, integer, integer, integer +function M.unpack6(r) + return r[1], r[2], r[3], r[4], r[5], r[6] +end + +---@param r1 Range +---@param r2 Range +---@return boolean whether r1 contains r2 +function M.contains(r1, r2) + local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1) + local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2) + + -- start doesn't fit + if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then + return false + end + + -- end doesn't fit + if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then + return false + end + + return true +end + +--- @param source integer|string +--- @param index integer +--- @return integer +local function get_offset(source, index) + if index == 0 then + return 0 + end + + if type(source) == 'number' then + return api.nvim_buf_get_offset(source, index) + end + + local byte = 0 + local next_offset = source:gmatch('()\n') + local line = 1 + while line <= index do + byte = next_offset() --[[@as integer]] + line = line + 1 + end + + return byte +end + +---@param source integer|string +---@param range Range +---@return Range6 +function M.add_bytes(source, range) + if type(range) == 'table' and #range == 6 then + return range --[[@as Range6]] + end + + local start_row, start_col, end_row, end_col = M.unpack4(range) + -- TODO(vigoux): proper byte computation here, and account for EOL ? + local start_byte = get_offset(source, start_row) + start_col + local end_byte = get_offset(source, end_row) + end_col + + return { start_row, start_col, start_byte, end_row, end_col, end_byte } +end + +return M diff --git a/lua/nvim-treesitter-textobjects/move.lua b/lua/nvim-treesitter-textobjects/move.lua index 9721547b..98b6b919 100644 --- a/lua/nvim-treesitter-textobjects/move.lua +++ b/lua/nvim-treesitter-textobjects/move.lua @@ -3,8 +3,9 @@ local api = vim.api local shared = require('nvim-treesitter-textobjects.shared') local repeatable_move = require('nvim-treesitter-textobjects.repeatable_move') local global_config = require('nvim-treesitter-textobjects.config') +local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range') ----@param range Range4? +---@param range Range? ---@param goto_end boolean ---@param avoid_set_jump boolean local function goto_node(range, goto_end, avoid_set_jump) @@ -16,7 +17,7 @@ local function goto_node(range, goto_end, avoid_set_jump) vim.cmd("normal! m'") end ---@type integer, integer, integer, integer - local start_row, start_col, end_row, end_col = unpack(range) + local start_row, start_col, end_row, end_col = ts_range.unpack4(range) -- Enter visual mode if we are in operator pending mode -- If we don't do this, it will miss the last character. @@ -89,13 +90,13 @@ local function move(opts, query_strings, query_group) end ---@param start_ boolean - ---@param range Range6 + ---@param range Range ---@return boolean local function filter_function(start_, range) local row, col = unpack(api.nvim_win_get_cursor(winid)) --[[@as integer, integer]] row = row - 1 -- nvim_win_get_cursor is (1,0)-indexed ---@type integer, integer, integer, integer, integer, integer - local start_row, start_col, _, end_row, end_col, _ = unpack(range) + local start_row, start_col, end_row, end_col = ts_range.unpack4(range) if not start_ then if end_col == 0 then @@ -146,7 +147,7 @@ local function move(opts, query_strings, query_group) end end end - goto_node(best_range and shared.torange4(best_range), not best_start, not config.set_jumps) + goto_node(best_range and best_range, not best_start, not config.set_jumps) end end diff --git a/lua/nvim-treesitter-textobjects/select.lua b/lua/nvim-treesitter-textobjects/select.lua index ea8b4ae9..d465260e 100644 --- a/lua/nvim-treesitter-textobjects/select.lua +++ b/lua/nvim-treesitter-textobjects/select.lua @@ -1,12 +1,13 @@ local api = vim.api local global_config = require('nvim-treesitter-textobjects.config') local shared = require('nvim-treesitter-textobjects.shared') +local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range') ----@param range Range4 +---@param range Range ---@param selection_mode TSTextObjects.SelectionMode local function update_selection(range, selection_mode) ---@type integer, integer, integer, integer - local start_row, start_col, end_row, end_col = unpack(range) + local start_row, start_col, end_row, end_col = ts_range.unpack4(range) selection_mode = selection_mode or 'v' -- enter visual mode if normal or operator-pending (no) mode @@ -105,11 +106,11 @@ local function previous_position(bufnr, row, col) end ---@param bufnr integer ----@param range Range4 +---@param range Range ---@param selection_mode string ---@return Range4? local function include_surrounding_whitespace(bufnr, range, selection_mode) - local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer + local start_row, start_col, end_row, end_col = ts_range.unpack4(range) ---@type integer, integer, integer, integer local extended = false local position = { end_row, end_col - 1 } local next = next_position(bufnr, unpack(position)) @@ -166,7 +167,6 @@ function M.select_textobject(query_string, query_group) { lookahead = lookahead, lookbehind = lookbehind } ) if range6 then - local range4 = shared.torange4(range6) local selection_mode = M.detect_selection_mode(query_string) if function_or_value_to_value(surrounding_whitespace, { @@ -174,11 +174,12 @@ function M.select_textobject(query_string, query_group) selection_mode = selection_mode, }) then - ---@diagnostic disable-next-line: cast-local-type - range4 = include_surrounding_whitespace(bufnr, range4, selection_mode) - end - if range4 then - update_selection(range4, selection_mode) + local range4 = include_surrounding_whitespace(bufnr, range6, selection_mode) + if range4 then + update_selection(range4, selection_mode) + end + else + update_selection(range6, selection_mode) end end end diff --git a/lua/nvim-treesitter-textobjects/shared.lua b/lua/nvim-treesitter-textobjects/shared.lua index c5947adf..aac66bb3 100644 --- a/lua/nvim-treesitter-textobjects/shared.lua +++ b/lua/nvim-treesitter-textobjects/shared.lua @@ -1,5 +1,5 @@ local ts = vim.treesitter -local add_bytes = require('vim.treesitter._range').add_bytes +local ts_range = ts._range or require('nvim-treesitter-textobjects._range') -- lookup table for parserless queries local lang_to_parser = { ecma = 'javascript', jsx = 'javascript' } @@ -75,7 +75,7 @@ local get_query_matches = memoize(function(bufnr, query_group, root, root_lang) if query_name ~= nil then local path = vim.split(query_name, '%.') if metadata[id] and metadata[id].range then - insert_to_path(prepared_match, path, add_bytes(bufnr, metadata[id].range)) + insert_to_path(prepared_match, path, ts_range.add_bytes(bufnr, metadata[id].range)) else local srow, scol, sbyte, erow, ecol, ebyte = nodes[1]:range(true) if #nodes > 1 then @@ -212,44 +212,12 @@ function M.find_best_range(bufnr, capture_string, query_group, filter_predicate, end -- TODO: replace with `vim.Range:has(vim.Pos)` when we drop support for nvim 0.11 ----@param range Range4 ----@param row integer +---@param range Range +---@param line integer ---@param col integer ---@return boolean -local function is_in_range(range, row, col) - local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer - end_col = end_col - 1 - - local is_in_rows = start_row <= row and end_row >= row - local is_after_start_col_if_needed = true - if start_row == row then - is_after_start_col_if_needed = col >= start_col - end - local is_before_end_col_if_needed = true - if end_row == row then - is_before_end_col_if_needed = col <= end_col - end - return is_in_rows and is_after_start_col_if_needed and is_before_end_col_if_needed -end - --- TODO: replace with `vim.Range:has(vim.Range)` when we drop support for 0.11 ----@param outer Range4 ----@param inner Range4 ----@return boolean -local function contains(outer, inner) - local start_row_o, start_col_o, end_row_o, end_col_o = unpack(outer) ---@type integer, integer, integer, integer - local start_row_i, start_col_i, end_row_i, end_col_i = unpack(inner) ---@type integer, integer, integer, integer - - return start_row_o <= start_row_i - and start_col_o <= start_col_i - and end_row_o >= end_row_i - and end_col_o >= end_col_i -end - ----@param range Range6 ----@return Range4 -function M.torange4(range) - return { range[1], range[2], range[4], range[5] } +local function is_in_range(range, line, col) + return ts_range.contains(range, { line, col, line, col + 1 }) end --- Get the best `TSTextObjects.Range` at a given point @@ -274,7 +242,7 @@ local function best_range_at_point(ranges, row, col, opts) local lookbehind_earliest_start ---@type integer for _, range in pairs(ranges) do - if range and is_in_range(M.torange4(range), row, col) then + if range and is_in_range(range, row, col) then local length = range[6] - range[3] if not range_length or length < range_length then smallest_range = range @@ -385,7 +353,7 @@ function M.textobject_at_point(query_string, query_group, bufnr, pos, opts) local ranges_within_outer = {} for _, range in ipairs(ranges) do - if contains(M.torange4(range_outer), M.torange4(range)) then + if ts_range.contains(range_outer, range) then table.insert(ranges_within_outer, range) end end diff --git a/lua/nvim-treesitter-textobjects/swap.lua b/lua/nvim-treesitter-textobjects/swap.lua index d514ac8b..c3c4103a 100644 --- a/lua/nvim-treesitter-textobjects/swap.lua +++ b/lua/nvim-treesitter-textobjects/swap.lua @@ -1,16 +1,17 @@ local api = vim.api local shared = require('nvim-treesitter-textobjects.shared') +local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range') ---@class TSTextObjects.LspLocation ---@field line integer ---@field character integer ----@param range Range4 +---@param range Range ---@return lsp.Range local function to_lsp_range(range) ---@type integer, integer, integer, integer - local start_row, start_col, end_row, end_col = unpack(range) + local start_row, start_col, end_row, end_col = ts_range.unpack4(range) return { start = { line = start_row, @@ -23,8 +24,8 @@ local function to_lsp_range(range) } end ----@param range1 Range4 ----@param range2 Range4 +---@param range1 Range +---@param range2 Range ---@param bufnr integer ---@param cursor_to_second any local function swap_nodes(range1, range2, bufnr, cursor_to_second) @@ -32,8 +33,11 @@ local function swap_nodes(range1, range2, bufnr, cursor_to_second) return end - local text1 = api.nvim_buf_get_text(bufnr, range1[1], range1[2], range1[3], range1[4], {}) - local text2 = api.nvim_buf_get_text(bufnr, range2[1], range2[2], range2[3], range2[4], {}) + local start_row1, start_col1, end_row1, end_col1 = ts_range.unpack4(range1) + local start_row2, start_col2, end_row2, end_col2 = ts_range.unpack4(range2) + + local text1 = api.nvim_buf_get_text(bufnr, start_row1, start_col1, end_row1, end_col1, {}) + local text2 = api.nvim_buf_get_text(bufnr, start_row2, start_col2, end_row2, end_col2, {}) local lsp_range1 = to_lsp_range(range1) local lsp_range2 = to_lsp_range(range2) @@ -86,11 +90,11 @@ local function swap_nodes(range1, range2, bufnr, cursor_to_second) end end ----@param range1 Range6 ----@param range2 Range6 +---@param range1 Range +---@param range2 Range local function range_eq(range1, range2) - local srow1, scol1, _, erow1, ecol1, _ = unpack(range1) ---@type integer, integer, integer, integer, integer, integer - local srow2, scol2, _, erow2, ecol2, _ = unpack(range2) ---@type integer, integer, integer, integer, integer, integer + local srow1, scol1, erow1, ecol1 = ts_range.unpack4(range1) ---@type integer, integer, integer, integer, integer, integer + local srow2, scol2, erow2, ecol2 = ts_range.unpack4(range2) ---@type integer, integer, integer, integer, integer, integer return srow1 == srow2 and scol1 == scol2 and erow1 == erow2 and ecol1 == ecol2 end @@ -188,12 +192,7 @@ local function swap_textobject(query_strings, query_group, direction) and next_textobject(textobject_range, query_string, query_group, bufnr) or previous_textobject(textobject_range, query_string, query_group, bufnr) if adjacent then - swap_nodes( - shared.torange4(textobject_range), - shared.torange4(adjacent), - bufnr, - 'yes, set cursor!' - ) + swap_nodes(textobject_range, adjacent, bufnr, 'yes, set cursor!') end end end diff --git a/tests/select/lua/lookback.lua b/tests/select/lua/lookback.lua new file mode 100644 index 00000000..13502acf --- /dev/null +++ b/tests/select/lua/lookback.lua @@ -0,0 +1,8 @@ +--- selecting @function.inner should look back when it's within @function.outer. +local function a() + print('foo') +end -- call here to test + +local function b() + print('bar') +end diff --git a/tests/select/lua_spec.lua b/tests/select/lua_spec.lua new file mode 100644 index 00000000..2f05856c --- /dev/null +++ b/tests/select/lua_spec.lua @@ -0,0 +1,16 @@ +local Runner = require('tests.select.common').Runner + +local run = Runner:new(it, 'tests/select/lua', { + tabstop = 2, + shiftwidth = 2, + softtabstop = 0, + expandtab = true, +}) + +describe('Look back if within @function.outer range (Lua):', function() + run:compare_cmds('lookback.lua', { row = 4, col = 0, cmds = { 'dim', 'k^D' } }) +end) + +describe('Look forward if outside @function.outer range (Lua):', function() + run:compare_cmds('lookback.lua', { row = 5, col = 0, cmds = { 'dim', '2j^D' } }) +end)