Skip to content

Commit 36dd459

Browse files
committed
Fix mypy errors
1 parent c380a1a commit 36dd459

File tree

2 files changed

+56
-25
lines changed

2 files changed

+56
-25
lines changed

examples/google_adk/calendar_agent/adk_agent_executor.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from google.adk import Runner
99
from google.adk.auth import AuthConfig
1010
from google.adk.events import Event
11+
from google.adk.session_service import (
12+
Session,
13+
)
1114
from google.genai import types
1215

1316
from a2a.server.agent_execution import AgentExecutor
@@ -44,6 +47,7 @@ class ADKAgentExecutor(AgentExecutor):
4447
"""An AgentExecutor that runs an ADK-based Agent."""
4548

4649
_awaiting_auth: dict[str, asyncio.Future]
50+
_running_sessions: dict[str, Session]
4751

4852
def __init__(self, runner: Runner, card: AgentCard):
4953
self.runner = runner
@@ -53,7 +57,7 @@ def __init__(self, runner: Runner, card: AgentCard):
5357

5458
def _run_agent(
5559
self, session_id, new_message: types.Content
56-
) -> AsyncGenerator[Event, None]:
60+
) -> AsyncGenerator[Event]:
5761
return self.runner.run_async(
5862
session_id=session_id, user_id='self', new_message=new_message
5963
)
@@ -97,12 +101,17 @@ async def _process_request(
97101
# is received.
98102
break
99103
if event.is_final_response():
100-
parts = convert_genai_parts_to_a2a(event.content.parts)
101-
logger.debug('Yielding final response: %s', parts)
102-
task_updater.add_artifact(parts)
103-
task_updater.complete()
104+
if event.content and event.content.parts:
105+
parts = convert_genai_parts_to_a2a(event.content.parts)
106+
logger.debug('Yielding final response: %s', parts)
107+
task_updater.add_artifact(parts)
108+
task_updater.complete()
104109
break
105-
if not event.get_function_calls():
110+
if (
111+
not event.get_function_calls()
112+
and event.content
113+
and event.content.parts
114+
):
106115
logger.debug('Yielding update response')
107116
task_updater.update_status(
108117
TaskState.working,
@@ -228,12 +237,17 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue):
228237
async def on_auth_callback(self, state: str, uri: str):
229238
self._awaiting_auth[state].set_result(uri)
230239

231-
def _upsert_session(self, session_id: str):
232-
return self.runner.session_service.get_session(
233-
app_name=self.runner.app_name, user_id='self', session_id=session_id
234-
) or self.runner.session_service.create_session(
240+
def _upsert_session(self, session_id: str) -> Session:
241+
session = self.runner.session_service.get_session(
235242
app_name=self.runner.app_name, user_id='self', session_id=session_id
236243
)
244+
if session is None:
245+
session = self.runner.session_service.create_session(
246+
app_name=self.runner.app_name,
247+
user_id='self',
248+
session_id=session_id,
249+
)
250+
return session
237251

238252

239253
def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]:
@@ -243,32 +257,43 @@ def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]:
243257

244258
def convert_a2a_part_to_genai(part: Part) -> types.Part:
245259
"""Convert a single A2A Part type into a Google Gen AI Part type."""
246-
part = part.root
247-
if isinstance(part, TextPart):
248-
return types.Part(text=part.text)
249-
if isinstance(part, FilePart):
250-
if isinstance(part.file, FileWithUri):
260+
if isinstance(part.root, TextPart):
261+
return types.Part(text=part.root.text)
262+
if isinstance(part.root, FilePart):
263+
file_data = part.root.file
264+
if isinstance(file_data, FileWithUri):
251265
return types.Part(
252266
file_data=types.FileData(
253-
file_uri=part.file.uri, mime_type=part.file.mime_type
267+
file_uri=file_data.uri,
268+
mime_type=file_data.mime_type,
254269
)
255270
)
256-
if isinstance(part.file, FileWithBytes):
271+
if isinstance(file_data, FileWithBytes):
257272
return types.Part(
258273
inline_data=types.Blob(
259-
data=part.file.bytes, mime_type=part.file.mime_type
274+
data=file_data.bytes,
275+
mime_type=file_data.mime_type,
260276
)
261277
)
262-
raise ValueError(f'Unsupported file type: {type(part.file)}')
263-
raise ValueError(f'Unsupported part type: {type(part)}')
278+
raise ValueError(f'Unsupported file type: {type(file_data)}')
279+
raise ValueError(f'Unsupported part root type: {type(part.root)}')
264280

265281

266282
def convert_genai_parts_to_a2a(parts: list[types.Part]) -> list[Part]:
267283
"""Convert a list of Google Gen AI Part types into a list of A2A Part types."""
284+
if not parts:
285+
return []
268286
return [
269287
convert_genai_part_to_a2a(part)
270288
for part in parts
271-
if (part.text or part.file_data or part.inline_data)
289+
if part
290+
and (
291+
part.text
292+
or part.file_data
293+
or part.inline_data
294+
or part.function_call
295+
or part.function_response
296+
)
272297
]
273298

274299

@@ -295,7 +320,9 @@ def convert_genai_part_to_a2a(part: types.Part) -> Part:
295320
raise ValueError(f'Unsupported part type: {part}')
296321

297322

298-
def get_auth_request_function_call(event: Event) -> types.FunctionCall:
323+
def get_auth_request_function_call(
324+
event: Event,
325+
) -> types.FunctionCall | None:
299326
"""Get the special auth request function call from the event."""
300327
if not (event.content and event.content.parts):
301328
return None
@@ -315,8 +342,12 @@ def get_auth_config(
315342
auth_request_function_call: types.FunctionCall,
316343
) -> AuthConfig:
317344
"""Extracts the AuthConfig object from the arguments of the auth request function call."""
318-
if not auth_request_function_call.args or not (
319-
auth_config := auth_request_function_call.args.get('auth_config')
345+
if (
346+
not auth_request_function_call.args
347+
or not isinstance(auth_request_function_call.args, dict)
348+
or not (
349+
auth_config := auth_request_function_call.args.get('auth_config')
350+
)
320351
):
321352
raise ValueError(
322353
f'Cannot get auth config from function call: {auth_request_function_call}'

src/a2a/client/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def send_message_streaming(
133133
request: SendStreamingMessageRequest,
134134
*,
135135
http_kwargs: dict[str, Any] | None = None,
136-
) -> AsyncGenerator[SendStreamingMessageResponse, None]:
136+
) -> AsyncGenerator[SendStreamingMessageResponse]:
137137
if not request.id:
138138
request.id = str(uuid4())
139139

0 commit comments

Comments
 (0)