Skip to content

Commit eeedc6b

Browse files
committed
Bugfix for empty strings and lists
1 parent 4b8f1c4 commit eeedc6b

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

stringlifier/api.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def __call__(self, string_or_list: Union[str, List[str]], return_tokens: bool =
5050
else:
5151
tokens = string_or_list
5252

53+
max_len = max([len(s) for s in tokens])
54+
if max_len == 0:
55+
if return_tokens:
56+
return [''], []
57+
else:
58+
return ['']
59+
5360
with torch.no_grad():
5461
p_ts = self.classifier(tokens)
5562

@@ -108,14 +115,15 @@ def _extract_tokens_2class(self, string: str, pred: NDArray[Int64]) -> Tuple[str
108115
new_str += string[last_pos:]
109116
return new_str, final_toks
110117

111-
def _extract_tokens(self, string: str, pred: NDArray[Int64], cutoff: int = 5) -> Tuple[str, List[Tuple[str, int, int, str]]]:
118+
def _extract_tokens(self, string: str, pred: NDArray[Int64], cutoff: int = 5) -> Tuple[
119+
str, List[Tuple[str, int, int, str]]]:
112120
mask = ''
113121
numbers = {str(ii): 1 for ii in range(10)}
114122

115123
for ii in range(len(pred)):
116124
p = pred[ii]
117125
cls = self.encodings._label_list[p]
118-
if cls == 'C' and string[ii] in numbers:
126+
if ii < len(string) and cls == 'C' and string[ii] in numbers:
119127
mask += 'N'
120128
else:
121129
mask += cls

0 commit comments

Comments
 (0)