Skip to content

Commit f7bd3c1

Browse files
seanzhougoogle authored and copybara-github committed
feat: Expose log probs of candidates in LlmResponse
fixes #2764 PiperOrigin-RevId: 807516910
1 parent 1ce043a commit f7bd3c1

File tree

2 files changed

+253
-0
lines changed

2 files changed

+253
-0
lines changed

src/google/adk/models/llm_response.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class LlmResponse(BaseModel):
4242
custom_metadata: The custom metadata of the LlmResponse.
4343
input_transcription: Audio transcription of user input.
4444
output_transcription: Audio transcription of model output.
45+
avg_logprobs: Average log probability of the generated tokens.
46+
logprobs_result: Detailed log probabilities for chosen and top candidate tokens.
4547
"""
4648

4749
model_config = ConfigDict(
@@ -109,6 +111,12 @@ class LlmResponse(BaseModel):
109111
output_transcription: Optional[types.Transcription] = None
110112
"""Audio transcription of model output."""
111113

114+
avg_logprobs: Optional[float] = None
115+
"""Average log probability of the generated tokens."""
116+
117+
logprobs_result: Optional[types.LogprobsResult] = None
118+
"""Detailed log probabilities for chosen and top candidate tokens."""
119+
112120
@staticmethod
113121
def create(
114122
generate_content_response: types.GenerateContentResponse,
@@ -131,13 +139,17 @@ def create(
131139
grounding_metadata=candidate.grounding_metadata,
132140
usage_metadata=usage_metadata,
133141
finish_reason=candidate.finish_reason,
142+
avg_logprobs=candidate.avg_logprobs,
143+
logprobs_result=candidate.logprobs_result,
134144
)
135145
else:
136146
return LlmResponse(
137147
error_code=candidate.finish_reason,
138148
error_message=candidate.finish_message,
139149
usage_metadata=usage_metadata,
140150
finish_reason=candidate.finish_reason,
151+
avg_logprobs=candidate.avg_logprobs,
152+
logprobs_result=candidate.logprobs_result,
141153
)
142154
else:
143155
if generate_content_response.prompt_feedback:
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for LlmResponse, including log probabilities feature."""
16+
17+
from google.adk.models.llm_response import LlmResponse
18+
from google.genai import types
19+
20+
21+
def test_llm_response_create_with_logprobs():
  """Test LlmResponse.create() extracts logprobs from candidate."""
  expected_avg_logprobs = -0.75
  expected_logprobs_result = types.LogprobsResult(
      chosen_candidates=[], top_candidates=[]
  )

  candidate = types.Candidate(
      content=types.Content(parts=[types.Part(text='Response text')]),
      finish_reason=types.FinishReason.STOP,
      avg_logprobs=expected_avg_logprobs,
      logprobs_result=expected_logprobs_result,
  )
  response = LlmResponse.create(
      types.GenerateContentResponse(candidates=[candidate])
  )

  # Both logprobs fields must be copied through verbatim, alongside the
  # regular content/finish_reason handling.
  assert response.avg_logprobs == expected_avg_logprobs
  assert response.logprobs_result == expected_logprobs_result
  assert response.content.parts[0].text == 'Response text'
  assert response.finish_reason == types.FinishReason.STOP
45+
46+
47+
def test_llm_response_create_without_logprobs():
  """Test LlmResponse.create() handles missing logprobs gracefully."""
  candidate = types.Candidate(
      content=types.Content(parts=[types.Part(text='Response text')]),
      finish_reason=types.FinishReason.STOP,
      avg_logprobs=None,
      logprobs_result=None,
  )
  response = LlmResponse.create(
      types.GenerateContentResponse(candidates=[candidate])
  )

  # Absent logprobs on the candidate must surface as None, not raise.
  assert response.avg_logprobs is None
  assert response.logprobs_result is None
  assert response.content.parts[0].text == 'Response text'
65+
66+
67+
def test_llm_response_create_error_case_with_logprobs():
  """Test LlmResponse.create() includes logprobs in error cases."""
  expected_avg_logprobs = -2.1

  blocked_candidate = types.Candidate(
      content=None,  # No content - error case
      finish_reason=types.FinishReason.SAFETY,
      finish_message='Safety filter triggered',
      avg_logprobs=expected_avg_logprobs,
      logprobs_result=None,
  )
  response = LlmResponse.create(
      types.GenerateContentResponse(candidates=[blocked_candidate])
  )

  # Even the error branch of create() should carry the logprobs through.
  assert response.avg_logprobs == expected_avg_logprobs
  assert response.logprobs_result is None
  assert response.error_code == types.FinishReason.SAFETY
  assert response.error_message == 'Safety filter triggered'
89+
90+
91+
def test_llm_response_create_no_candidates():
  """Test LlmResponse.create() with no candidates."""
  feedback = types.GenerateContentResponsePromptFeedback(
      block_reason=types.BlockedReason.SAFETY,
      block_reason_message='Prompt blocked for safety',
  )
  response = LlmResponse.create(
      types.GenerateContentResponse(candidates=[], prompt_feedback=feedback)
  )

  # No candidates means no logprobs
  assert response.avg_logprobs is None
  assert response.logprobs_result is None
  assert response.error_code == types.BlockedReason.SAFETY
  assert response.error_message == 'Prompt blocked for safety'
108+
109+
110+
def test_llm_response_create_with_concrete_logprobs_result():
  """Test LlmResponse.create() with detailed logprobs_result containing actual token data."""

  def _lp(token, log_probability, token_id):
    # Local helper: builds one token-level logprob entry; keeps the
    # fixtures below readable instead of repeating the constructor.
    return types.LogprobsResultCandidate(
        token=token, log_probability=log_probability, token_id=token_id
    )

  # Create realistic logprobs data: the chosen token at each step...
  chosen_candidates = [
      _lp('The', -0.1, 123),
      _lp(' capital', -0.5, 456),
      _lp(' of', -0.2, 789),
  ]

  # ...and the top alternatives considered at the first two steps.
  top_candidates = [
      types.LogprobsResultTopCandidates(
          candidates=[
              _lp('The', -0.1, 123),
              _lp('A', -2.3, 124),
              _lp('This', -3.1, 125),
          ]
      ),
      types.LogprobsResultTopCandidates(
          candidates=[
              _lp(' capital', -0.5, 456),
              _lp(' city', -1.2, 457),
              _lp(' main', -2.8, 458),
          ]
      ),
  ]

  # Rounded mean of (-0.1, -0.5, -0.2); the exact mean is -0.2666...
  # LlmResponse.create() only passes this field through, so the precise
  # value is irrelevant to what is under test.
  avg_logprobs = -0.27
  logprobs_result = types.LogprobsResult(
      chosen_candidates=chosen_candidates, top_candidates=top_candidates
  )

  generate_content_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=types.Content(
                  parts=[types.Part(text='The capital of France is Paris.')]
              ),
              finish_reason=types.FinishReason.STOP,
              avg_logprobs=avg_logprobs,
              logprobs_result=logprobs_result,
          )
      ]
  )

  response = LlmResponse.create(generate_content_response)

  assert response.avg_logprobs == avg_logprobs
  assert response.logprobs_result is not None

  # Test chosen candidates
  assert len(response.logprobs_result.chosen_candidates) == 3
  assert response.logprobs_result.chosen_candidates[0].token == 'The'
  assert response.logprobs_result.chosen_candidates[0].log_probability == -0.1
  assert response.logprobs_result.chosen_candidates[0].token_id == 123
  assert response.logprobs_result.chosen_candidates[1].token == ' capital'
  assert response.logprobs_result.chosen_candidates[1].log_probability == -0.5
  assert response.logprobs_result.chosen_candidates[1].token_id == 456

  # Test top candidates
  assert len(response.logprobs_result.top_candidates) == 2
  assert (
      len(response.logprobs_result.top_candidates[0].candidates) == 3
  )  # 3 alternatives for first token
  assert response.logprobs_result.top_candidates[0].candidates[0].token == 'The'
  assert (
      response.logprobs_result.top_candidates[0].candidates[0].token_id == 123
  )
  assert response.logprobs_result.top_candidates[0].candidates[1].token == 'A'
  assert (
      response.logprobs_result.top_candidates[0].candidates[1].token_id == 124
  )
  assert (
      response.logprobs_result.top_candidates[0].candidates[2].token == 'This'
  )
  assert (
      response.logprobs_result.top_candidates[0].candidates[2].token_id == 125
  )
205+
206+
207+
def test_llm_response_create_with_partial_logprobs_result():
  """Test LlmResponse.create() with logprobs_result having only chosen_candidates."""
  logprobs_result = types.LogprobsResult(
      chosen_candidates=[
          types.LogprobsResultCandidate(
              token='Hello', log_probability=-0.05, token_id=111
          ),
          types.LogprobsResultCandidate(
              token=' world', log_probability=-0.8, token_id=222
          ),
      ],
      top_candidates=[],  # Empty top candidates
  )

  response = LlmResponse.create(
      types.GenerateContentResponse(
          candidates=[
              types.Candidate(
                  content=types.Content(
                      parts=[types.Part(text='Hello world')]
                  ),
                  finish_reason=types.FinishReason.STOP,
                  avg_logprobs=-0.425,  # Average of -0.05 and -0.8
                  logprobs_result=logprobs_result,
              )
          ]
      )
  )

  # A partially-populated LogprobsResult must round-trip unchanged.
  assert response.avg_logprobs == -0.425
  assert response.logprobs_result is not None
  assert len(response.logprobs_result.chosen_candidates) == 2
  assert len(response.logprobs_result.top_candidates) == 0
  assert response.logprobs_result.chosen_candidates[0].token == 'Hello'
  assert response.logprobs_result.chosen_candidates[1].token == ' world'

0 commit comments

Comments
 (0)