11import logging
22from typing import Callable , TypeAlias , Sequence , NamedTuple , Iterable
33
4- from cedarscript_ast_parser import Marker , MarkerType
4+ from cedarscript_ast_parser import Marker , MarkerType , Segment
55from grep_ast import filename_to_lang
66from text_manipulation .indentation_kit import get_line_indent_count
77from text_manipulation .range_spec import IdentifierBoundaries , RangeSpec
1111
1212_log = logging .getLogger (__name__ )
1313
14- IdentifierFinder : TypeAlias = Callable [[Marker ], IdentifierBoundaries | None ]
14+ IdentifierFinder : TypeAlias = Callable [[Marker ], IdentifierBoundaries | RangeSpec | None ]
1515
1616
17- def find_identifier (source_info : tuple [str , str | Sequence [str ]]) -> IdentifierFinder :
17+ def find_identifier (source_info : tuple [str , str | Sequence [str ]], search_rage : RangeSpec = RangeSpec . EMPTY ) -> IdentifierFinder :
1818 file_path = source_info [0 ]
1919 source = source_info [1 ]
2020 if not isinstance (source , str ):
2121 source = '\n ' .join (source )
22- return _select_finder (file_path , source )
22+ return _select_finder (file_path , source , search_rage )
2323
2424
25- def _select_finder (file_path : str , source : str ) -> IdentifierFinder :
25+ def _select_finder (file_path : str , source : str , search_range : RangeSpec = RangeSpec . EMPTY ) -> IdentifierFinder :
2626 langstr = filename_to_lang (file_path )
2727 language = get_language (langstr )
2828 parser = get_parser (langstr )
2929 _log .info (f"[select_finder] Selected { language } " )
3030
3131 tree = parser .parse (bytes (source , "utf-8" ))
32+ source = source .splitlines ()
3233 query_info = LANG_TO_TREE_SITTER_QUERY [langstr ]
3334
34- def find_by_marker (marker : Marker ) -> IdentifierBoundaries | None :
35- return _find_identifier (language , source , tree , query_info , marker )
35+ def find_by_marker (mos : Marker | Segment ) -> IdentifierBoundaries | RangeSpec | None :
36+ match mos :
37+
38+ case Marker (MarkerType .LINE ) | Segment ():
39+ # TODO pass IdentifierFinder to enable identifiers as start and/or end of a segment
40+ return mos .to_search_range (source , search_range ).set_line_count (1 ) # returns RangeSpec
41+
42+ case Marker () as marker :
43+ # Returns IdentifierBoundaries
44+ return _find_identifier (language , source , tree , query_info , marker )
3645
3746 return find_by_marker
3847
@@ -66,7 +75,7 @@ def identifier(self):
6675 return self .node .text .decode ("utf-8" )
6776
6877
69- def associate_identifier_parts (captures : Iterable [CaptureInfo ], lines : list [str ]) -> list [IdentifierBoundaries ]:
78+ def associate_identifier_parts (captures : Iterable [CaptureInfo ], lines : Sequence [str ]) -> list [IdentifierBoundaries ]:
7079 identifier_map : dict [int , IdentifierBoundaries ] = {}
7180
7281 for capture in captures :
@@ -105,7 +114,7 @@ def find_parent_definition(node):
105114 return None
106115
107116
108- def _find_identifier (language , source : str , tree , query_scm : dict [str , dict [str , str ]], marker : Marker ) \
117+ def _find_identifier (language , source : Sequence [ str ] , tree , query_scm : dict [str , dict [str , str ]], marker : Marker ) \
109118 -> IdentifierBoundaries | None :
110119 """
111120 Find the starting line index of a specified function in the given lines.
@@ -121,7 +130,7 @@ def _find_identifier(language, source: str, tree, query_scm: dict[str, dict[str,
121130 candidates = language .query (query_scm [marker .type ].format (name = marker .value )).captures (tree .root_node )
122131 candidates : list [IdentifierBoundaries ] = capture2identifier_boundaries (
123132 candidates ,
124- source . splitlines ()
133+ source
125134 )
126135 except Exception as e :
127136 raise ValueError (f"Unable to capture nodes for { marker } : { e } " ) from e
@@ -145,7 +154,7 @@ def _find_identifier(language, source: str, tree, query_scm: dict[str, dict[str,
145154 return result
146155
147156
148- def capture2identifier_boundaries (captures , lines : list [str ]) -> list [IdentifierBoundaries ]:
157+ def capture2identifier_boundaries (captures , lines : Sequence [str ]) -> list [IdentifierBoundaries ]:
149158 captures = [CaptureInfo (c [1 ], c [0 ]) for c in captures if not c [1 ].startswith ('_' )]
150159 unique_captures = {}
151160 for capture in captures :
0 commit comments