Skip to content

Commit 4b8f1c4

Browse files
committed
Bugfix + exposing cutoff parameter on method call
1 parent 823b8f6 commit 4b8f1c4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

stringlifier/api.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def __init__(self, model_base: Optional[str] = None):
4343
self.encodings = encodings
4444
self._c_index: int = encodings._label2int['C']
4545

46-
def __call__(self, string_or_list: Union[str, List[str]], return_tokens: bool = False) -> Union[
46+
def __call__(self, string_or_list: Union[str, List[str]], return_tokens: bool = False, cutoff: int = 5) -> Union[
4747
Tuple[List[str], List[List[Tuple[str, int, int, str]]]], List[str]]:
4848
if isinstance(string_or_list, str):
4949
tokens = [string_or_list]
@@ -58,7 +58,7 @@ def __call__(self, string_or_list: Union[str, List[str]], return_tokens: bool =
5858
new_strings: List[str] = []
5959

6060
for iBatch in range(p_ts.shape[0]):
61-
new_str, toks = self._extract_tokens(tokens[iBatch], p_ts[iBatch])
61+
new_str, toks = self._extract_tokens(tokens[iBatch], p_ts[iBatch], cutoff=cutoff)
6262
new_strings.append(new_str)
6363
ext_tokens.append(toks)
6464

@@ -108,7 +108,7 @@ def _extract_tokens_2class(self, string: str, pred: NDArray[Int64]) -> Tuple[str
108108
new_str += string[last_pos:]
109109
return new_str, final_toks
110110

111-
def _extract_tokens(self, string: str, pred: NDArray[Int64]) -> Tuple[str, List[Tuple[str, int, int, str]]]:
111+
def _extract_tokens(self, string: str, pred: NDArray[Int64], cutoff: int = 5) -> Tuple[str, List[Tuple[str, int, int, str]]]:
112112
mask = ''
113113
numbers = {str(ii): 1 for ii in range(10)}
114114

@@ -168,14 +168,14 @@ def _extract_tokens(self, string: str, pred: NDArray[Int64]) -> Tuple[str, List[
168168
# filter small tokens
169169
final_toks: List[Tuple[str, int, int, str]] = []
170170
for token in tokens:
171-
if token[2] - token[1] > 5:
171+
if token[2] - token[1] > cutoff:
172172
final_toks.append(token)
173173
# compose new string
174174
new_str: str = ''
175175
last_pos = 0
176176

177-
from ipdb import set_trace
178-
set_trace()
177+
# from ipdb import set_trace
178+
# set_trace()
179179
for token in final_toks:
180180
if token[1] > last_pos:
181181
new_str += string[last_pos:token[1]]

0 commit comments

Comments
 (0)