Skip to content

Commit b101f6d

Browse files
committed
Fix significant bugs
1 parent ec435c0 commit b101f6d

File tree

3 files changed

+105
-51
lines changed

3 files changed

+105
-51
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
]
2323
keywords = ["cedarscript", "code-editing", "refactoring", "code-analysis", "sql-like", "ai-assisted-development"]
2424
dependencies = [
25-
"cedarscript-ast-parser>=0.2.10",
25+
"cedarscript-ast-parser==0.2.11",
2626
"grep-ast==0.3.3",
2727
"tree-sitter-languages==1.10.2",
2828
]

src/cedarscript_editor/cedarscript_editor.py

Lines changed: 55 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
SelectCommand, CreateCommand, IdentifierFromFile, Segment, Marker, MoveClause, DeleteClause, \
77
InsertClause, ReplaceClause, EditingAction, BodyOrWhole, RegionClause, MarkerType
88
from cedarscript_ast_parser.cedarscript_ast_parser import MarkerCompatible, RelativeMarker, \
9-
RelativePositionType
9+
RelativePositionType, Region, SingleFileClause
1010
from text_manipulation import (
1111
IndentationInfo, IdentifierBoundaries, RangeSpec, read_file, write_file, bow_to_search_range
1212
)
1313

14-
from .tree_sitter_identifier_finder import IdentifierFinder, find_identifier
14+
from .tree_sitter_identifier_finder import IdentifierFinder
1515

1616

1717
class CEDARScriptEditorException(Exception):
@@ -105,26 +105,21 @@ def _update_command(self, cmd: UpdateCommand):
105105
src = read_file(file_path)
106106
lines = src.splitlines()
107107

108-
source_info: tuple[str | bytes, str | Sequence[str]] = (file_path, src)
109-
110-
identifier_finder = find_identifier(source_info)
111-
112-
# TODO test...
108+
identifier_finder = IdentifierFinder(file_path, src, RangeSpec.EMPTY)
113109

110+
search_range = RangeSpec.EMPTY
114111
match action:
115112
case MoveClause():
116-
# (Check parse_update_command)
117-
# when action=MoveClause example (MOVE roll TO AFTER score):
118-
# action.deleteclause.region=WHOLE
119-
# action.as_marker = action.insertclause.as_marker
120-
# action.insertclause.insert_position=FUNCTION(score)
121-
# target.as_marker = FUNCTION(roll) (the one to delete)
122-
search_range = RangeSpec.EMPTY
123-
move_src_range = restrict_search_range(action, target, identifier_finder, lines)
113+
# READ + DELETE region : action.region (PARENT RESTRICTION: target.as_marker)
114+
move_src_range = restrict_search_range(action.region, target, identifier_finder, lines)
115+
# WRITE region: action.insert_position
116+
search_range = restrict_search_range(action.insert_position, None, identifier_finder, lines)
124117
case _:
125118
move_src_range = None
126119
# Set range_spec to cover the identifier
127-
search_range = restrict_search_range(action, target, identifier_finder, lines)
120+
match action:
121+
case RegionClause(region=region):
122+
search_range = restrict_search_range(action.region, target, identifier_finder, lines)
128123

129124
# UPDATE FUNCTION "_check_raw_id_fields_item"
130125
# FROM FILE "refactor-benchmark/checks_BaseModelAdminChecks__check_raw_id_fields_item/checks.py"
@@ -134,11 +129,17 @@ def _update_command(self, cmd: UpdateCommand):
134129
# ''';
135130
# target = IdentifierFromFile(file_path='refactor-benchmark/checks_BaseModelAdminChecks__check_raw_id_fields_item/checks.py', identifier_type=<MarkerType.FUNCTION: 'function'>, name='_check_raw_id_fields_item', where_clause=None, offset=None)
136131
# action = ReplaceClause(region=Marker(type=<MarkerType.LINE: line>, value=def _check_raw_id_fields_item(self, obj, field_name, label):, offset=None))
137-
if search_range.line_count and not isinstance(action.region if hasattr(action, 'region') else None, Segment):
138-
marker, search_range = find_marker_or_segment(action, lines, search_range)
139-
search_range = restrict_search_range_for_marker(
140-
marker, action, lines, search_range, identifier_finder
141-
)
132+
if search_range.line_count:
133+
match action:
134+
case RegionClause(region=Segment()):
135+
pass
136+
case RegionClause(region=Marker()) if action.region.type in [MarkerType.FUNCTION, MarkerType.METHOD, MarkerType.CLASS]:
137+
pass
138+
case _:
139+
marker, search_range = find_marker_or_segment(action, lines, search_range)
140+
search_range = restrict_search_range_for_marker(
141+
marker, action, lines, search_range, identifier_finder
142+
)
142143

143144
match content:
144145
case str() | [str(), *_] | (str(), *_):
@@ -263,7 +264,7 @@ def find_index_range_for_region(region: BodyOrWhole | Marker | Segment | Relativ
263264
pass
264265
case _:
265266
# TODO transform to RangeSpec
266-
mos = find_identifier(("TODO?.py", lines))(mos).body
267+
mos = IdentifierFinder("TODO?.py", lines, RangeSpec.EMPTY)(mos, search_range).body
267268
index_range = mos.to_search_range(
268269
lines,
269270
search_range.start if search_range else 0,
@@ -296,28 +297,38 @@ def find_marker_or_segment(
296297
return marker, search_range
297298

298299

299-
def restrict_search_range(action, target, identifier_finder: IdentifierFinder, lines: Sequence[str]) -> RangeSpec:
300-
match target:
301-
case IdentifierFromFile() as identifier_from_file:
302-
identifier_marker = identifier_from_file.as_marker
303-
identifier_boundaries = identifier_finder(identifier_marker)
304-
if not identifier_boundaries:
305-
raise ValueError(f"'{identifier_marker}' not found")
306-
match action:
307-
case RegionClause(region=region):
308-
match region:
309-
case BodyOrWhole() | RelativePositionType():
310-
return identifier_boundaries.location_to_search_range(region)
311-
case Marker() as inner_marker:
312-
match identifier_finder(inner_marker, identifier_boundaries.whole):
313-
case IdentifierBoundaries() as inner_boundaries:
314-
return inner_boundaries.whole
315-
case RangeSpec() as inner_range_spec:
316-
return inner_range_spec
317-
case Segment() as segment:
318-
return segment.to_search_range(lines, identifier_boundaries.whole)
319-
case _ as invalid:
320-
raise ValueError(f'Unsupported region type: {type(invalid)}')
300+
def restrict_search_range(
301+
region: Region, parent_restriction: any,
302+
identifier_finder: IdentifierFinder, lines: Sequence[str]
303+
) -> RangeSpec:
304+
identifier_boundaries = None
305+
match parent_restriction:
306+
case IdentifierFromFile():
307+
identifier_boundaries = identifier_finder(parent_restriction.as_marker)
308+
match region:
309+
case BodyOrWhole() | RelativePositionType():
310+
match parent_restriction:
311+
case IdentifierFromFile():
312+
match identifier_boundaries:
313+
case None:
314+
raise ValueError(f"'{parent_restriction}' not found")
315+
case SingleFileClause():
316+
return RangeSpec.EMPTY
317+
case None:
318+
raise ValueError(f"'{region}' requires parent_restriction")
319+
case _:
320+
raise ValueError(f"'{region}' isn't compatible with {parent_restriction}")
321+
return identifier_boundaries.location_to_search_range(region)
322+
case Marker() as inner_marker:
323+
match identifier_finder(inner_marker, identifier_boundaries.whole if identifier_boundaries is not None else None):
324+
case IdentifierBoundaries() as inner_boundaries:
325+
return inner_boundaries.location_to_search_range(BodyOrWhole.WHOLE)
326+
case RangeSpec() as inner_range_spec:
327+
return inner_range_spec
328+
case Segment() as segment:
329+
return segment.to_search_range(lines, identifier_boundaries.whole if identifier_boundaries is not None else None)
330+
case _ as invalid:
331+
raise ValueError(f'Unsupported region type: {type(invalid)}')
321332
return RangeSpec.EMPTY
322333

323334

src/text_manipulation/range_spec.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
import re
1313
from collections.abc import Sequence
14-
from typing import NamedTuple
14+
from typing import NamedTuple, TypeAlias
1515
from functools import total_ordering
1616

1717

1818
from cedarscript_ast_parser import Marker, RelativeMarker, RelativePositionType, MarkerType, BodyOrWhole
19+
1920
from .indentation_kit import get_line_indent_count
2021

2122
MATCH_TYPES = ('exact', 'stripped', 'normalized', 'partial')
@@ -43,20 +44,43 @@ def __str__(self):
4344
return (f'{self.start}:{self.end}' if self.as_index is None else f'%{self.as_index}') + f'@{self.indent}'
4445

4546
def __lt__(self, other):
46-
"""Compare if this range is strictly before another range."""
47-
return self.end < other.start
47+
"""Compare if this range is strictly before another range or int."""
48+
match other:
49+
case int():
50+
return self.end <= other
51+
case RangeSpec():
52+
return self.end <= other.start
4853

4954
def __le__(self, other):
5055
"""Compare if this range is before or adjacent to another range."""
51-
return self.end <= other.start
56+
match other:
57+
case int():
58+
return self.end <= other - 1
59+
case RangeSpec():
60+
return self.end <= other.start - 1
5261

5362
def __gt__(self, other):
5463
"""Compare if this range is strictly after another range."""
55-
return self.start > other.end
64+
match other:
65+
case int():
66+
return self.start > other
67+
case RangeSpec():
68+
return self.start >= other.end
5669

5770
def __ge__(self, other):
5871
"""Compare if this range is after or adjacent to another range."""
59-
return self.start >= other.end
72+
match other:
73+
case int():
74+
return self.start >= other
75+
case RangeSpec():
76+
return self.start >= other.end - 1
77+
78+
def __contains__(self, item):
79+
match item:
80+
case int():
81+
return self.start <= item < self.end
82+
case RangeSpec():
83+
return self == RangeSpec.EMPTY or item != RangeSpec.EMPTY and self.start <= item.start and item.end <= self.end
6084

6185
@property
6286
def line_count(self):
@@ -235,6 +259,12 @@ def from_line_marker(
235259
RangeSpec.EMPTY = RangeSpec(0, -1, 0)
236260

237261

262+
class ParentInfo(NamedTuple):
263+
parent_name: str
264+
parent_type: str
265+
266+
ParentRestriction: TypeAlias = RangeSpec | str | None
267+
238268
class IdentifierBoundaries(NamedTuple):
239269
"""
240270
Represents the boundaries of an identifier in code, including its whole range and body range.
@@ -251,6 +281,7 @@ class IdentifierBoundaries(NamedTuple):
251281
body: RangeSpec | None = None
252282
docstring: RangeSpec | None = None
253283
decorators: list[RangeSpec] = []
284+
parents: list[ParentInfo] = []
254285

255286
def __str__(self):
256287
return f'IdentifierBoundaries({self.whole} (BODY: {self.body}) )'
@@ -270,6 +301,18 @@ def end_line(self) -> int:
270301
"""Return the 1-indexed end line of the whole identifier."""
271302
return self.whole.end
272303

304+
def match_parent(self, parent_restriction: ParentRestriction) -> bool:
305+
match parent_restriction:
306+
case None:
307+
return True
308+
case RangeSpec():
309+
return self.whole in parent_restriction
310+
case str() as parent_name:
311+
# TODO Implement advanced query syntax
312+
return parent_name in [p.parent_name for p in self.parents]
313+
case _:
314+
raise ValueError(f'Invalid parent restriction: {parent_restriction}')
315+
273316
def location_to_search_range(self, location: BodyOrWhole | RelativePositionType) -> RangeSpec:
274317
"""
275318
Convert a location specifier to a RangeSpec for searching.

0 commit comments

Comments
 (0)