Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions lua/nvim-treesitter-textobjects/_range.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
--- Copy of vim.treesitter._range
--- TODO: replace with `vim.Range` when we drop support for 0.11

---@diagnostic disable: duplicate-doc-field
---@diagnostic disable: duplicate-doc-alias

local api = vim.api

local M = {}

---@class Range2
---@inlinedoc
---@field [1] integer start row
---@field [2] integer end row

---@class Range4
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer end row
---@field [4] integer end column

---@class Range6
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer start bytes
---@field [4] integer end row
---@field [5] integer end column
---@field [6] integer end bytes

---@alias Range Range2|Range4|Range6

---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
if a_row == b_row then
if a_col > b_col then
return 1
elseif a_col < b_col then
return -1
else
return 0
end
elseif a_row > b_row then
return 1
end

return -1
end

M.cmp_pos = {
lt = function(...)
return cmp_pos(...) == -1
end,
le = function(...)
return cmp_pos(...) ~= 1
end,
gt = function(...)
return cmp_pos(...) == 1
end,
ge = function(...)
return cmp_pos(...) ~= -1
end,
eq = function(...)
return cmp_pos(...) == 0
end,
ne = function(...)
return cmp_pos(...) ~= 0
end,
}

setmetatable(M.cmp_pos, { __call = cmp_pos })

---Check if a variable is a valid range object
---@param r any
---@return boolean
function M.validate(r)
if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
return false
end

for _, e in
ipairs(r --[[@as any[] ]])
do
if type(e) ~= 'number' then
return false
end
end

return true
end

---@param r1 Range
---@param r2 Range
---@return boolean
function M.intercepts(r1, r2)
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)

-- r1 is above r2
if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
return false
end

-- r1 is below r2
if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
return false
end

return true
end

---@param r1 Range6
---@param r2 Range6
---@return Range6?
function M.intersection(r1, r2)
if not M.intercepts(r1, r2) then
return nil
end
local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
end

---@param r Range
---@return integer, integer, integer, integer
function M.unpack4(r)
if #r == 2 then
return r[1], 0, r[2], 0
end
local off_1 = #r == 6 and 1 or 0
return r[1], r[2], r[3 + off_1], r[4 + off_1]
end

---@param r Range6
---@return integer, integer, integer, integer, integer, integer
function M.unpack6(r)
return r[1], r[2], r[3], r[4], r[5], r[6]
end

---@param r1 Range
---@param r2 Range
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)

-- start doesn't fit
if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
return false
end

-- end doesn't fit
if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
return false
end

return true
end

--- @param source integer|string
--- @param index integer
--- @return integer
local function get_offset(source, index)
if index == 0 then
return 0
end

if type(source) == 'number' then
return api.nvim_buf_get_offset(source, index)
end

local byte = 0
local next_offset = source:gmatch('()\n')
local line = 1
while line <= index do
byte = next_offset() --[[@as integer]]
line = line + 1
end

return byte
end

---@param source integer|string
---@param range Range
---@return Range6
function M.add_bytes(source, range)
if type(range) == 'table' and #range == 6 then
return range --[[@as Range6]]
end

local start_row, start_col, end_row, end_col = M.unpack4(range)
-- TODO(vigoux): proper byte computation here, and account for EOL ?
local start_byte = get_offset(source, start_row) + start_col
local end_byte = get_offset(source, end_row) + end_col

return { start_row, start_col, start_byte, end_row, end_col, end_byte }
end

return M
11 changes: 6 additions & 5 deletions lua/nvim-treesitter-textobjects/move.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ local api = vim.api
local shared = require('nvim-treesitter-textobjects.shared')
local repeatable_move = require('nvim-treesitter-textobjects.repeatable_move')
local global_config = require('nvim-treesitter-textobjects.config')
local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range')

---@param range Range4?
---@param range Range?
---@param goto_end boolean
---@param avoid_set_jump boolean
local function goto_node(range, goto_end, avoid_set_jump)
Expand All @@ -16,7 +17,7 @@ local function goto_node(range, goto_end, avoid_set_jump)
vim.cmd("normal! m'")
end
---@type integer, integer, integer, integer
local start_row, start_col, end_row, end_col = unpack(range)
local start_row, start_col, end_row, end_col = ts_range.unpack4(range)

-- Enter visual mode if we are in operator pending mode
-- If we don't do this, it will miss the last character.
Expand Down Expand Up @@ -89,13 +90,13 @@ local function move(opts, query_strings, query_group)
end

---@param start_ boolean
---@param range Range6
---@param range Range
---@return boolean
local function filter_function(start_, range)
local row, col = unpack(api.nvim_win_get_cursor(winid)) --[[@as integer, integer]]
row = row - 1 -- nvim_win_get_cursor is (1,0)-indexed
---@type integer, integer, integer, integer, integer, integer
local start_row, start_col, _, end_row, end_col, _ = unpack(range)
local start_row, start_col, end_row, end_col = ts_range.unpack4(range)

if not start_ then
if end_col == 0 then
Expand Down Expand Up @@ -146,7 +147,7 @@ local function move(opts, query_strings, query_group)
end
end
end
goto_node(best_range and shared.torange4(best_range), not best_start, not config.set_jumps)
goto_node(best_range and best_range, not best_start, not config.set_jumps)
end
end

Expand Down
21 changes: 11 additions & 10 deletions lua/nvim-treesitter-textobjects/select.lua
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
local api = vim.api
local global_config = require('nvim-treesitter-textobjects.config')
local shared = require('nvim-treesitter-textobjects.shared')
local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range')

---@param range Range4
---@param range Range
---@param selection_mode TSTextObjects.SelectionMode
local function update_selection(range, selection_mode)
---@type integer, integer, integer, integer
local start_row, start_col, end_row, end_col = unpack(range)
local start_row, start_col, end_row, end_col = ts_range.unpack4(range)
selection_mode = selection_mode or 'v'

-- enter visual mode if normal or operator-pending (no) mode
Expand Down Expand Up @@ -105,11 +106,11 @@ local function previous_position(bufnr, row, col)
end

---@param bufnr integer
---@param range Range4
---@param range Range
---@param selection_mode string
---@return Range4?
local function include_surrounding_whitespace(bufnr, range, selection_mode)
local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer
local start_row, start_col, end_row, end_col = ts_range.unpack4(range) ---@type integer, integer, integer, integer
local extended = false
local position = { end_row, end_col - 1 }
local next = next_position(bufnr, unpack(position))
Expand Down Expand Up @@ -166,19 +167,19 @@ function M.select_textobject(query_string, query_group)
{ lookahead = lookahead, lookbehind = lookbehind }
)
if range6 then
local range4 = shared.torange4(range6)
local selection_mode = M.detect_selection_mode(query_string)
if
function_or_value_to_value(surrounding_whitespace, {
query_string = query_string,
selection_mode = selection_mode,
})
then
---@diagnostic disable-next-line: cast-local-type
range4 = include_surrounding_whitespace(bufnr, range4, selection_mode)
end
if range4 then
update_selection(range4, selection_mode)
local range4 = include_surrounding_whitespace(bufnr, range6, selection_mode)
if range4 then
update_selection(range4, selection_mode)
end
else
update_selection(range6, selection_mode)
end
end
end
Expand Down
48 changes: 8 additions & 40 deletions lua/nvim-treesitter-textobjects/shared.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
local ts = vim.treesitter
local add_bytes = require('vim.treesitter._range').add_bytes
local ts_range = ts._range or require('nvim-treesitter-textobjects._range')

-- lookup table for parserless queries
local lang_to_parser = { ecma = 'javascript', jsx = 'javascript' }
Expand Down Expand Up @@ -75,7 +75,7 @@ local get_query_matches = memoize(function(bufnr, query_group, root, root_lang)
if query_name ~= nil then
local path = vim.split(query_name, '%.')
if metadata[id] and metadata[id].range then
insert_to_path(prepared_match, path, add_bytes(bufnr, metadata[id].range))
insert_to_path(prepared_match, path, ts_range.add_bytes(bufnr, metadata[id].range))
else
local srow, scol, sbyte, erow, ecol, ebyte = nodes[1]:range(true)
if #nodes > 1 then
Expand Down Expand Up @@ -212,44 +212,12 @@ function M.find_best_range(bufnr, capture_string, query_group, filter_predicate,
end

-- TODO: replace with `vim.Range:has(vim.Pos)` when we drop support for nvim 0.11
---@param range Range4
---@param row integer
---@param range Range
---@param line integer
---@param col integer
---@return boolean
local function is_in_range(range, row, col)
local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer
end_col = end_col - 1

local is_in_rows = start_row <= row and end_row >= row
local is_after_start_col_if_needed = true
if start_row == row then
is_after_start_col_if_needed = col >= start_col
end
local is_before_end_col_if_needed = true
if end_row == row then
is_before_end_col_if_needed = col <= end_col
end
return is_in_rows and is_after_start_col_if_needed and is_before_end_col_if_needed
end

-- TODO: replace with `vim.Range:has(vim.Range)` when we drop support for 0.11
---@param outer Range4
---@param inner Range4
---@return boolean
local function contains(outer, inner)
local start_row_o, start_col_o, end_row_o, end_col_o = unpack(outer) ---@type integer, integer, integer, integer
local start_row_i, start_col_i, end_row_i, end_col_i = unpack(inner) ---@type integer, integer, integer, integer

return start_row_o <= start_row_i
and start_col_o <= start_col_i
and end_row_o >= end_row_i
and end_col_o >= end_col_i
end

---@param range Range6
---@return Range4
function M.torange4(range)
return { range[1], range[2], range[4], range[5] }
local function is_in_range(range, line, col)
return ts_range.contains(range, { line, col, line, col + 1 })
end

--- Get the best `TSTextObjects.Range` at a given point
Expand All @@ -274,7 +242,7 @@ local function best_range_at_point(ranges, row, col, opts)
local lookbehind_earliest_start ---@type integer

for _, range in pairs(ranges) do
if range and is_in_range(M.torange4(range), row, col) then
if range and is_in_range(range, row, col) then
local length = range[6] - range[3]
if not range_length or length < range_length then
smallest_range = range
Expand Down Expand Up @@ -385,7 +353,7 @@ function M.textobject_at_point(query_string, query_group, bufnr, pos, opts)

local ranges_within_outer = {}
for _, range in ipairs(ranges) do
if contains(M.torange4(range_outer), M.torange4(range)) then
if ts_range.contains(range_outer, range) then
table.insert(ranges_within_outer, range)
end
end
Expand Down
Loading