@@ -1552,7 +1552,7 @@ def logit_bias_processor(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens
-                logits = self._scores[token_offset - 1, :].tolist()
+                logits = self._scores[token_offset - 1, :]
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -1671,7 +1671,7 @@ def logit_bias_processor(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self._scores[token_offset, :].tolist()
+                logits = self._scores[token_offset, :]
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
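Both hunks above make the same change: they stop converting the score row to a Python list before handing it to `logits_to_logprobs`. The method is built from NumPy ufuncs (`np.amax`, `np.subtract`, `np.exp`), which accept an array row directly, so the `.tolist()` call only added an array → list → array round trip on every streamed token. A minimal sketch of the equivalence, using a simplified stand-in for the method and hypothetical shapes (neither is from the source):

```python
import numpy as np

def log_softmax(logits, axis=-1):
    # Simplified stand-in for Llama.logits_to_logprobs (illustration only).
    maxs = np.amax(logits, axis=axis, keepdims=True)
    shifted = np.subtract(logits, maxs, dtype=np.single)
    with np.errstate(divide="ignore"):
        log_sum = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
    return shifted - log_sum

scores = np.random.rand(8, 32000).astype(np.single)  # stand-in for self._scores
row = scores[3, :]

# List input and array input produce identical log-probs; passing the row
# directly just skips the copy that .tolist() forced on every token.
assert np.allclose(log_softmax(row.tolist()), log_softmax(row))
```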
@@ -1785,9 +1785,8 @@ def logit_bias_processor(
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
-            all_logprobs = [
-                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
-            ][token_offset:]
+            all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
+            # TODO: may be able to change this loop to use np.take_along_axis
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
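This hunk goes further: instead of calling `logits_to_logprobs` once per row on a `.tolist()` copy, it applies the function to the whole `self._scores` matrix in one vectorized call and slices afterwards. The new TODO suggests the remaining per-token loop could be vectorized as well; a sketch of what that gather might look like with `np.take_along_axis`, using hypothetical names and shapes (not from the source):

```python
import numpy as np

n, v = 5, 100  # hypothetical: n sampled tokens over a vocabulary of size v
all_logprobs = np.log(np.random.dirichlet(np.ones(v), size=n)).astype(np.single)
all_tokens = np.array([3, 17, 42, 8, 99])  # chosen token id at each step

# Rather than looping over zip(all_tokens, all_logprobs) and indexing each
# row, gather every chosen token's log-prob in a single call:
chosen = np.take_along_axis(all_logprobs, all_tokens[:, None], axis=-1)[:, 0]
assert chosen.shape == (n,)
```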
@@ -2282,7 +2281,7 @@ def token_nl(self) -> int:

     @staticmethod
     def logits_to_logprobs(
-        logits: Union[List, npt.NDArray[np.single]], axis: int = -1
+        logits: Union[npt.NDArray[np.single], List], axis: int = -1
     ) -> npt.NDArray[np.single]:
         # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
         logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
@@ -2293,7 +2292,7 @@ def logits_to_logprobs(
         subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
         exp = np.exp(subtract_maxs)
         # Suppress warnings about log of zero
-        with np.errstate(divide='ignore'):
+        with np.errstate(divide="ignore"):
             summed = np.sum(exp, axis=axis, keepdims=True)
             out = np.log(summed)
         return subtract_maxs - out
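For context, `logits_to_logprobs` is a numerically stable log-softmax: subtracting the per-row maximum before exponentiating keeps `np.exp` from overflowing, since log softmax(x)_i = (x_i - max(x)) - log Σ_j exp(x_j - max(x)). A quick check of the same computation against SciPy's reference implementation (SciPy is used here only for comparison; it is not a dependency of this change):

```python
import numpy as np
from scipy.special import log_softmax

# Values this large would overflow a naive exp() in float32.
x = np.array([1000.0, 1001.0, 1002.0], dtype=np.single)

maxs = np.amax(x, axis=-1, keepdims=True)
shifted = np.subtract(x, maxs, dtype=np.single)
with np.errstate(divide="ignore"):  # log(0) -> -inf without a warning
    out = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

assert np.allclose(shifted - out, log_softmax(x))
```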