@@ -43,7 +43,7 @@ def __init__(self, model_base: Optional[str] = None):
4343 self .encodings = encodings
4444 self ._c_index : int = encodings ._label2int ['C' ]
4545
46- def __call__ (self , string_or_list : Union [str , List [str ]], return_tokens : bool = False ) -> Union [
46+ def __call__ (self , string_or_list : Union [str , List [str ]], return_tokens : bool = False , cutoff : int = 5 ) -> Union [
4747 Tuple [List [str ], List [List [Tuple [str , int , int , str ]]]], List [str ]]:
4848 if isinstance (string_or_list , str ):
4949 tokens = [string_or_list ]
@@ -58,7 +58,7 @@ def __call__(self, string_or_list: Union[str, List[str]], return_tokens: bool =
5858 new_strings : List [str ] = []
5959
6060 for iBatch in range (p_ts .shape [0 ]):
61- new_str , toks = self ._extract_tokens (tokens [iBatch ], p_ts [iBatch ])
61+ new_str , toks = self ._extract_tokens (tokens [iBatch ], p_ts [iBatch ], cutoff = cutoff )
6262 new_strings .append (new_str )
6363 ext_tokens .append (toks )
6464
@@ -108,7 +108,7 @@ def _extract_tokens_2class(self, string: str, pred: NDArray[Int64]) -> Tuple[str
108108 new_str += string [last_pos :]
109109 return new_str , final_toks
110110
111- def _extract_tokens (self , string : str , pred : NDArray [Int64 ]) -> Tuple [str , List [Tuple [str , int , int , str ]]]:
111+ def _extract_tokens (self , string : str , pred : NDArray [Int64 ], cutoff : int = 5 ) -> Tuple [str , List [Tuple [str , int , int , str ]]]:
112112 mask = ''
113113 numbers = {str (ii ): 1 for ii in range (10 )}
114114
@@ -168,14 +168,14 @@ def _extract_tokens(self, string: str, pred: NDArray[Int64]) -> Tuple[str, List[
168168 # filter small tokens
169169 final_toks : List [Tuple [str , int , int , str ]] = []
170170 for token in tokens :
171- if token [2 ] - token [1 ] > 5 :
171+ if token [2 ] - token [1 ] > cutoff :
172172 final_toks .append (token )
173173 # compose new string
174174 new_str : str = ''
175175 last_pos = 0
176176
177- from ipdb import set_trace
178- set_trace ()
177+ # from ipdb import set_trace
178+ # set_trace()
179179 for token in final_toks :
180180 if token [1 ] > last_pos :
181181 new_str += string [last_pos :token [1 ]]
0 commit comments