Skip to content

Commit 445ec69

Browse files
committed
Initial fix for 'find_by_marker'
1 parent 4c9e9f5 commit 445ec69

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

src/cedarscript_editor/cedarscript_editor.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def _update_command(self, cmd: UpdateCommand):
149149

150150
identifier_finder = find_identifier(source_info)
151151

152+
# TODO test...
153+
152154
match action:
153155
case MoveClause():
154156
# (Check parse_update_command)
@@ -164,11 +166,19 @@ def _update_command(self, cmd: UpdateCommand):
164166
# Set range_spec to cover the identifier
165167
search_range = restrict_search_range(action, target, identifier_finder)
166168

167-
marker, search_range = find_marker_or_segment(action, lines, search_range)
168-
169-
search_range = restrict_search_range_for_marker(
170-
marker, action, lines, search_range, identifier_finder
171-
)
169+
# UPDATE FUNCTION "_check_raw_id_fields_item"
170+
# FROM FILE "refactor-benchmark/checks_BaseModelAdminChecks__check_raw_id_fields_item/checks.py"
171+
# REPLACE LINE "def _check_raw_id_fields_item(self, obj, field_name, label):"
172+
# WITH CONTENT '''
173+
# @0:def _check_raw_id_fields_item(obj, field_name, label):
174+
# ''';
175+
# target = IdentifierFromFile(file_path='refactor-benchmark/checks_BaseModelAdminChecks__check_raw_id_fields_item/checks.py', identifier_type=<MarkerType.FUNCTION: 'function'>, name='_check_raw_id_fields_item', where_clause=None, offset=None)
176+
# action = ReplaceClause(region=Marker(type=<MarkerType.LINE: line>, value=def _check_raw_id_fields_item(self, obj, field_name, label):, offset=None))
177+
if search_range.line_count:
178+
marker, search_range = find_marker_or_segment(action, lines, search_range)
179+
search_range = restrict_search_range_for_marker(
180+
marker, action, lines, search_range, identifier_finder
181+
)
172182

173183
match content:
174184
case str() | [str(), *_] | (str(), *_):
@@ -309,7 +319,6 @@ def find_marker_or_segment(
309319

310320

311321
def restrict_search_range(action, target, identifier_finder: IdentifierFinder) -> RangeSpec:
312-
search_range = RangeSpec.EMPTY
313322
match target:
314323
case IdentifierFromFile() as identifier_from_file:
315324
identifier_marker = identifier_from_file.as_marker
@@ -318,12 +327,19 @@ def restrict_search_range(action, target, identifier_finder: IdentifierFinder) -
318327
raise ValueError(f"'{identifier_marker}' not found")
319328
match action:
320329
case RegionClause(region=region):
321-
match region: # BodyOrWhole | Marker | Segment
322-
case BodyOrWhole():
323-
search_range = identifier_boundaries.location_to_search_range(region)
324-
case _:
325-
search_range = identifier_boundaries.location_to_search_range(BodyOrWhole.WHOLE)
326-
return search_range
330+
match region:
331+
case BodyOrWhole() | RelativePositionType():
332+
return identifier_boundaries.location_to_search_range(region)
333+
case Marker() as inner_marker:
334+
match identifier_finder(inner_marker):
335+
case IdentifierBoundaries() as inner_boundaries:
336+
return inner_boundaries.whole
337+
case RangeSpec() as inner_range_spec:
338+
return inner_range_spec
339+
case _ as invalid: # Marker (LINE) or Segment
340+
# TODO
341+
raise ValueError(f'Not implemented: {invalid}')
342+
return RangeSpec.EMPTY
327343

328344

329345
def restrict_search_range_for_marker(

src/cedarscript_editor/tree_sitter_identifier_finder.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import Callable, TypeAlias, Sequence, NamedTuple, Iterable
33

4-
from cedarscript_ast_parser import Marker, MarkerType
4+
from cedarscript_ast_parser import Marker, MarkerType, Segment
55
from grep_ast import filename_to_lang
66
from text_manipulation.indentation_kit import get_line_indent_count
77
from text_manipulation.range_spec import IdentifierBoundaries, RangeSpec
@@ -11,28 +11,37 @@
1111

1212
_log = logging.getLogger(__name__)
1313

14-
IdentifierFinder: TypeAlias = Callable[[Marker], IdentifierBoundaries | None]
14+
IdentifierFinder: TypeAlias = Callable[[Marker], IdentifierBoundaries | RangeSpec | None]
1515

1616

17-
def find_identifier(source_info: tuple[str, str | Sequence[str]]) -> IdentifierFinder:
17+
def find_identifier(source_info: tuple[str, str | Sequence[str]], search_rage: RangeSpec = RangeSpec.EMPTY) -> IdentifierFinder:
1818
file_path = source_info[0]
1919
source = source_info[1]
2020
if not isinstance(source, str):
2121
source = '\n'.join(source)
22-
return _select_finder(file_path, source)
22+
return _select_finder(file_path, source, search_rage)
2323

2424

25-
def _select_finder(file_path: str, source: str) -> IdentifierFinder:
25+
def _select_finder(file_path: str, source: str, search_range: RangeSpec = RangeSpec.EMPTY) -> IdentifierFinder:
2626
langstr = filename_to_lang(file_path)
2727
language = get_language(langstr)
2828
parser = get_parser(langstr)
2929
_log.info(f"[select_finder] Selected {language}")
3030

3131
tree = parser.parse(bytes(source, "utf-8"))
32+
source = source.splitlines()
3233
query_info = LANG_TO_TREE_SITTER_QUERY[langstr]
3334

34-
def find_by_marker(marker: Marker) -> IdentifierBoundaries | None:
35-
return _find_identifier(language, source, tree, query_info, marker)
35+
def find_by_marker(mos: Marker | Segment) -> IdentifierBoundaries | RangeSpec | None:
36+
match mos:
37+
38+
case Marker(MarkerType.LINE) | Segment():
39+
# TODO pass IdentifierFinder to enable identifiers as start and/or end of a segment
40+
return mos.to_search_range(source, search_range).set_line_count(1) # returns RangeSpec
41+
42+
case Marker() as marker:
43+
# Returns IdentifierBoundaries
44+
return _find_identifier(language, source, tree, query_info, marker)
3645

3746
return find_by_marker
3847

@@ -66,7 +75,7 @@ def identifier(self):
6675
return self.node.text.decode("utf-8")
6776

6877

69-
def associate_identifier_parts(captures: Iterable[CaptureInfo], lines: list[str]) -> list[IdentifierBoundaries]:
78+
def associate_identifier_parts(captures: Iterable[CaptureInfo], lines: Sequence[str]) -> list[IdentifierBoundaries]:
7079
identifier_map: dict[int, IdentifierBoundaries] = {}
7180

7281
for capture in captures:
@@ -105,7 +114,7 @@ def find_parent_definition(node):
105114
return None
106115

107116

108-
def _find_identifier(language, source: str, tree, query_scm: dict[str, dict[str, str]], marker: Marker) \
117+
def _find_identifier(language, source: Sequence[str], tree, query_scm: dict[str, dict[str, str]], marker: Marker) \
109118
-> IdentifierBoundaries | None:
110119
"""
111120
Find the starting line index of a specified function in the given lines.
@@ -121,7 +130,7 @@ def _find_identifier(language, source: str, tree, query_scm: dict[str, dict[str,
121130
candidates = language.query(query_scm[marker.type].format(name=marker.value)).captures(tree.root_node)
122131
candidates: list[IdentifierBoundaries] = capture2identifier_boundaries(
123132
candidates,
124-
source.splitlines()
133+
source
125134
)
126135
except Exception as e:
127136
raise ValueError(f"Unable to capture nodes for {marker}: {e}") from e
@@ -145,7 +154,7 @@ def _find_identifier(language, source: str, tree, query_scm: dict[str, dict[str,
145154
return result
146155

147156

148-
def capture2identifier_boundaries(captures, lines: list[str]) -> list[IdentifierBoundaries]:
157+
def capture2identifier_boundaries(captures, lines: Sequence[str]) -> list[IdentifierBoundaries]:
149158
captures = [CaptureInfo(c[1], c[0]) for c in captures if not c[1].startswith('_')]
150159
unique_captures = {}
151160
for capture in captures:

src/text_manipulation/range_spec.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class RangeSpec(NamedTuple):
4040
indent: int = 0
4141

4242
def __str__(self):
43-
"""Return a string representation of the RangeSpec."""
4443
return (f'{self.start}:{self.end}' if self.as_index is None else f'%{self.as_index}') + f'@{self.indent}'
4544

4645
def __lt__(self, other):

0 commit comments

Comments
 (0)