Skip to content

Commit ec435c0

Browse files
committed
More tests
1 parent 30d860a commit ec435c0

File tree

12 files changed

+197
-371
lines changed

12 files changed

+197
-371
lines changed

src/cedarscript_editor/tree_sitter_identifier_finder.py

Lines changed: 92 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
2+
from dataclasses import dataclass
23
from functools import cached_property
3-
from typing import Callable, TypeAlias, Sequence, NamedTuple, Iterable
4+
from typing import Sequence, Iterable
45

5-
from cedarscript_ast_parser import Marker, MarkerType, Segment
6+
from cedarscript_ast_parser import Marker, MarkerType, Segment, RelativeMarker
67
from grep_ast import filename_to_lang
78
from text_manipulation.indentation_kit import get_line_indent_count
8-
from text_manipulation.range_spec import IdentifierBoundaries, RangeSpec
9+
from text_manipulation.range_spec import IdentifierBoundaries, RangeSpec, ParentInfo, ParentRestriction
910
from tree_sitter_languages import get_language, get_parser
1011

1112
from .tree_sitter_identifier_queries import LANG_TO_TREE_SITTER_QUERY
@@ -18,65 +19,94 @@
1819

1920
_log = logging.getLogger(__name__)
2021

21-
"""Type alias for functions that find identifiers in source code.
22-
Takes a Marker/Segment and optional RangeSpec, returns identifier boundaries or range."""
23-
IdentifierFinder: TypeAlias = Callable[[Marker | Segment, RangeSpec | None], IdentifierBoundaries | RangeSpec | None]
22+
class IdentifierFinder:
23+
"""Finds identifiers in source code based on markers and parent restrictions.
2424
25-
26-
def find_identifier(source_info: tuple[str, str | Sequence[str]], search_rage: RangeSpec = RangeSpec.EMPTY) -> IdentifierFinder:
27-
"""Factory function that creates an identifier finder for the given source.
28-
29-
Args:
30-
source_info: Tuple of (file_path, source_content)
31-
search_rage: Optional range to limit the search scope
32-
33-
Returns:
34-
IdentifierFinder function configured for the given source
35-
"""
36-
file_path = source_info[0]
37-
source = source_info[1]
38-
if not isinstance(source, str):
39-
source = '\n'.join(source)
40-
return _select_finder(file_path, source, search_rage)
41-
42-
43-
def _select_finder(file_path: str, source: str, search_range: RangeSpec = RangeSpec.EMPTY) -> IdentifierFinder:
44-
"""Selects and configures an appropriate identifier finder for the given file.
45-
46-
Args:
25+
Attributes:
26+
lines: List of source code lines
4727
file_path: Path to the source file
48-
source: Source code content
49-
search_range: Optional range to limit the search scope
50-
51-
Returns:
52-
IdentifierFinder function configured for the file type
28+
source: Complete source code as a single string
29+
language: Tree-sitter language instance
30+
tree: Parsed tree-sitter tree
31+
query_info: Language-specific query information
5332
"""
54-
langstr = filename_to_lang(file_path)
55-
match langstr:
56-
case None:
57-
language = None
58-
query_info = None
59-
_log.info(f"[select_finder] NO LANGUAGE for `{file_path}`")
60-
case _:
61-
query_info = LANG_TO_TREE_SITTER_QUERY[langstr]
62-
language = get_language(langstr)
63-
_log.info(f"[select_finder] Selected {language}")
64-
tree = get_parser(langstr).parse(bytes(source, "utf-8"))
65-
66-
source = source.splitlines()
67-
68-
def find_by_marker(mos: Marker | Segment, search_range: RangeSpec | None = None) -> IdentifierBoundaries | RangeSpec | None:
69-
match mos:
7033

34+
def __init__(self, fname: str, source: str | Sequence[str], parent_restriction: ParentRestriction = None):
35+
self.parent_restriction = parent_restriction
36+
match source:
37+
case str() as s:
38+
self.lines = s.splitlines()
39+
case _ as lines:
40+
self.lines = lines
41+
source = '\n'.join(lines)
42+
langstr = filename_to_lang(fname)
43+
if langstr is None:
44+
self.language = None
45+
self.query_info = None
46+
_log.info(f"[IdentifierFinder] NO LANGUAGE for `{fname}`")
47+
return
48+
self.query_info: dict[str, dict[str, str]] = LANG_TO_TREE_SITTER_QUERY[langstr]
49+
self.language = get_language(langstr)
50+
_log.info(f"[IdentifierFinder] Selected {self.language}")
51+
self.tree = get_parser(langstr).parse(bytes(source, "utf-8"))
52+
53+
def __call__(self, mos: Marker | Segment, parent_restriction: ParentRestriction = None) -> IdentifierBoundaries | RangeSpec | None:
54+
parent_restriction = parent_restriction or self.parent_restriction
55+
match mos:
7156
case Marker(MarkerType.LINE) | Segment():
7257
# TODO pass IdentifierFinder to enable identifiers as start and/or end of a segment
73-
return mos.to_search_range(source, search_range).set_line_count(1) # returns RangeSpec
58+
return mos.to_search_range(self.lines, parent_restriction).set_line_count(1) # returns RangeSpec
7459

7560
case Marker() as marker:
7661
# Returns IdentifierBoundaries
77-
return _find_identifier(language, source, tree, query_info, marker)
78-
79-
return find_by_marker
62+
return self._find_identifier(marker, parent_restriction)
63+
64+
def _find_identifier(self,
65+
marker: Marker,
66+
parent_restriction: ParentRestriction
67+
) -> IdentifierBoundaries | RangeSpec | None:
68+
"""Finds an identifier in the source code using tree-sitter queries.
69+
70+
Args:
71+
language: Tree-sitter language
72+
source: List of source code lines
73+
tree: Parsed tree-sitter tree
74+
query_scm: Dictionary of queries for different identifier types
75+
marker: Type, name and offset of the identifier to find
76+
77+
Returns:
78+
IdentifierBoundaries with identifier IdentifierBoundaries with identifier start, body start, and end lines of the identifier
79+
or None if not found
80+
"""
81+
try:
82+
candidates = self.language.query(self.query_info[marker.type].format(name=marker.value)).captures(self.tree.root_node)
83+
candidates: list[IdentifierBoundaries] = [ib for ib in capture2identifier_boundaries(
84+
candidates,
85+
self.lines
86+
) if ib.match_parent(parent_restriction)]
87+
except Exception as e:
88+
raise ValueError(f"Unable to capture nodes for {marker}: {e}") from e
89+
90+
candidate_count = len(candidates)
91+
if not candidate_count:
92+
return None
93+
if candidate_count > 1 and marker.offset is None:
94+
raise ValueError(
95+
f"The {marker.type} identifier named `{marker.value}` is ambiguous (found {candidate_count} matches). "
96+
f"Choose an `OFFSET` between 0 and {candidate_count - 1} to determine how many to skip. "
97+
f"Example to reference the *last* `{marker.value}`: `OFFSET {candidate_count - 1}`"
98+
)
99+
if marker.offset and marker.offset >= candidate_count:
100+
raise ValueError(
101+
f"There are only {candidate_count} {marker.type} identifiers named `{marker.value}`, "
102+
f"but 'OFFSET' was set to {marker.offset} (you can skip at most {candidate_count - 1} of those)"
103+
)
104+
candidates.sort(key=lambda x: x.whole.start)
105+
result: IdentifierBoundaries = _get_by_offset(candidates, marker.offset or 0)
106+
match marker:
107+
case RelativeMarker(qualifier=relative_position_type):
108+
return result.location_to_search_range(relative_position_type)
109+
return result
80110

81111

82112
def _get_by_offset(obj: Sequence, offset: int):
@@ -85,7 +115,8 @@ def _get_by_offset(obj: Sequence, offset: int):
85115
return None
86116

87117

88-
class CaptureInfo(NamedTuple):
118+
@dataclass(frozen=True)
119+
class CaptureInfo:
89120
"""Container for information about a captured node from tree-sitter parsing.
90121
91122
Attributes:
@@ -120,10 +151,10 @@ def identifier(self):
120151
return self.node.text.decode("utf-8")
121152

122153
@cached_property
123-
def parents(self) -> list[tuple[str, str]]:
154+
def parents(self) -> list[ParentInfo]:
124155
"""Returns a list of (node_type, node_name) tuples representing the hierarchy.
125156
The list is ordered from immediate parent to root."""
126-
parents = []
157+
parents: list[ParentInfo] = []
127158
current = self.node.parent
128159

129160
while current:
@@ -135,7 +166,7 @@ def parents(self) -> list[tuple[str, str]]:
135166
if child.type == 'identifier' or child.type == 'name':
136167
name = child.text.decode('utf-8')
137168
break
138-
parents.append((current.type, name))
169+
parents.append(ParentInfo(name, current.type))
139170
current = current.parent
140171

141172
return parents
@@ -157,7 +188,10 @@ def associate_identifier_parts(captures: Iterable[CaptureInfo], lines: Sequence[
157188
capture_type = capture.capture_type.split('.')[-1]
158189
range_spec = capture.to_range_spec(lines)
159190
if capture_type == 'definition':
160-
identifier_map[range_spec.start] = IdentifierBoundaries(range_spec)
191+
identifier_map[range_spec.start] = IdentifierBoundaries(
192+
whole=range_spec,
193+
parents=capture.parents
194+
)
161195

162196
else:
163197
parent = find_parent_definition(capture.node)
@@ -190,48 +224,6 @@ def find_parent_definition(node):
190224
return None
191225

192226

193-
def _find_identifier(language, source: Sequence[str], tree, query_scm: dict[str, dict[str, str]], marker: Marker) -> IdentifierBoundaries | None:
194-
"""Finds an identifier in the source code using tree-sitter queries.
195-
196-
Args:
197-
language: Tree-sitter language
198-
source: List of source code lines
199-
tree: Parsed tree-sitter tree
200-
query_scm: Dictionary of queries for different identifier types
201-
marker: Type, name and offset of the identifier to find
202-
203-
Returns:
204-
IdentifierBoundaries with identifier IdentifierBoundaries with identifier start, body start, and end lines of the identifier
205-
or None if not found
206-
"""
207-
try:
208-
candidates = language.query(query_scm[marker.type].format(name=marker.value)).captures(tree.root_node)
209-
candidates: list[IdentifierBoundaries] = capture2identifier_boundaries(
210-
candidates,
211-
source
212-
)
213-
except Exception as e:
214-
raise ValueError(f"Unable to capture nodes for {marker}: {e}") from e
215-
216-
candidate_count = len(candidates)
217-
if not candidate_count:
218-
return None
219-
if candidate_count > 1 and marker.offset is None:
220-
raise ValueError(
221-
f"The {marker.type} identifier named `{marker.value}` is ambiguous (found {candidate_count} matches). "
222-
f"Choose an `OFFSET` between 0 and {candidate_count - 1} to determine how many to skip. "
223-
f"Example to reference the *last* `{marker.value}`: `OFFSET {candidate_count - 1}`"
224-
)
225-
if marker.offset and marker.offset >= candidate_count:
226-
raise ValueError(
227-
f"There are only {candidate_count} {marker.type} identifiers named `{marker.value}`, "
228-
f"but 'OFFSET' was set to {marker.offset} (you can skip at most {candidate_count - 1} of those)"
229-
)
230-
candidates.sort(key=lambda x: x.whole.start)
231-
result: IdentifierBoundaries = _get_by_offset(candidates, marker.offset or 0)
232-
return result
233-
234-
235227
def capture2identifier_boundaries(captures, lines: Sequence[str]) -> list[IdentifierBoundaries]:
236228
"""Converts raw tree-sitter captures to IdentifierBoundaries objects.
237229

tests/corpus/chat.make-top-level-from-method.1/chat.xml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,27 @@ Here's the CEDARScript script:
1212

1313
```CEDARScript
1414
-- 1. Move the `_candidate()` method from the `A` class, placing it at the top level, just before the line where its class starts
15-
UPDATE CLASS "A" -- As method `calc1` exists in more than 1 class, we specify which one we want to reference
16-
FROM FILE "main.py"
15+
UPDATE CLASS "A"
16+
FROM FILE "main.py"
1717
MOVE FUNCTION "_candidate"
1818
INSERT BEFORE CLASS "A"
19-
RELATIVE INDENTATION 0;
19+
RELATIVE INDENTATION 0;
2020

2121
-- 2. Update the copied function to remove references to `self`
2222
UPDATE FUNCTION "_candidate"
23-
FROM FILE "main.py"
23+
FROM FILE "main.py"
2424
REPLACE LINE "def _candidate(self, a, b, c):"
2525
WITH CONTENT '''
2626
@0:def _candidate(a, b, c):
2727
''';
2828

2929
-- 3. Update ALL call sites of the method `_candidate()` to call the new top-level function with the same name
3030
UPDATE FUNCTION "_check"
31-
FROM FILE "main.py"
31+
FROM FILE "main.py"
3232
REPLACE LINE 'self._candidate(' -- There's only 1 call site and this line marker is unambiguous, so it's more concise to just replace a LINE instead of a SEGMENT
3333
WITH CONTENT '''
3434
@0:_candidate(
3535
''';
3636
```
3737
</assistant>
38-
</train>
38+
</train>

0 commit comments

Comments
 (0)