From ab39ed11fcc0e56dfa8d05b42e03c0d0417092f6 Mon Sep 17 00:00:00 2001 From: Mengqin Shen Date: Thu, 5 Feb 2026 21:56:00 -0800 Subject: [PATCH 1/3] fix(py): fix model config consistencies for gemini --- .../genkit/src/genkit/ai/_base_async.py | 72 +++-------- .../plugins/google_genai/models/gemini.py | 116 ++++++++++++------ .../plugins/google_genai/models/utils.py | 37 +++++- py/samples/google-genai-hello/src/main.py | 6 +- 4 files changed, 129 insertions(+), 102 deletions(-) diff --git a/py/packages/genkit/src/genkit/ai/_base_async.py b/py/packages/genkit/src/genkit/ai/_base_async.py index 8f849c65b2..c232b57ff8 100644 --- a/py/packages/genkit/src/genkit/ai/_base_async.py +++ b/py/packages/genkit/src/genkit/ai/_base_async.py @@ -16,9 +16,8 @@ """Asynchronous server gateway interface implementation for Genkit.""" -import signal from collections.abc import Coroutine -from typing import Any, TypeVar +from typing import Any, TypeVar, cast import anyio import httpx @@ -127,7 +126,7 @@ async def dev_runner() -> T: assert spec is not None # Capture spec in local var for nested functions (pyrefly doesn't narrow closures) server_spec: ServerSpec = spec - user_result: T | None = None + user_result: T = None # type: ignore[assignment] user_task_finished_event = anyio.Event() async def run_user_coro_wrapper() -> None: @@ -135,64 +134,25 @@ async def run_user_coro_wrapper() -> None: nonlocal user_result try: user_result = await coro - logger.debug('User coroutine completed successfully.') except Exception as err: - # Log error but don't necessarily stop the server logger.error(f'User coroutine failed: {err}', exc_info=True) - # Store exception? Or let TaskGroup handle it if critical? - # Depending on desired behavior, could raise here to stop everything. - pass # Continue running server for now finally: user_task_finished_event.set() reflection_server = _make_reflection_server(self.registry, server_spec) - # Setup signal handlers for graceful shutdown (parity with JS) - - # Actually, anyio.run handles Ctrl+C (SIGINT) by raising KeyboardInterrupt/CancelledError - # For SIGTERM, we might need to be explicit if we run in a container/process manager. - # JS uses: process.on('SIGTERM', shutdown); process.on('SIGINT', shutdown); - - # Since anyio/asyncio handles SIGINT well, let's add a task to catch SIGTERM - async def handle_sigterm(tg_to_cancel: anyio.abc.TaskGroup) -> None: # type: ignore[name-defined] - with anyio.open_signal_receiver(signal.SIGTERM) as signals: - async for _signum in signals: - logger.info('Received SIGTERM, cancelling tasks...') - tg_to_cancel.cancel_scope.cancel() - return - try: # Use lazy_write=True to prevent race condition where file exists before server is up async with RuntimeManager(server_spec, lazy_write=True) as runtime_manager: - # We use anyio.TaskGroup because it is compatible with - # asyncio's event loop and works with Python 3.10 - # (asyncio.TaskGroup was added in 3.11, and we can switch to - # that when we drop support for 3.10). async with anyio.create_task_group() as tg: # Start reflection server in the background. - tg.start_soon(reflection_server.serve, name='genkit-reflection-server') - await logger.ainfo(f'Started Genkit reflection server at {server_spec.url}') - - # Start SIGTERM handler - tg.start_soon(handle_sigterm, tg, name='genkit-sigterm-handler') - - # Wait for server to be responsive - # We need to loop and poll the health endpoint or wait for uvicorn to be ready - # Since uvicorn run is blocking (but we are in a task), we can't easily hook into its startup - # unless we use uvicorn's server object directly which we do. - # reflection_server.started is set when uvicorn starts. - - # Simple polling loop + tg.start_soon(reflection_server.serve) + logger.info(f'Started Genkit reflection server at {server_spec.url}') + # Wait for the server to be healthy before starting the user task. max_retries = 20 # 2 seconds total roughly for _i in range(max_retries): try: - # TODO(#4334): Use async http client if available to avoid blocking loop? - # But we are in dev mode, so maybe okay. - # Actually we should use anyio.to_thread to avoid blocking event loop - # or assume standard lib urllib is fast enough for localhost. - - # Use httpx async client to avoid blocking the event loop health_url = f'{server_spec.url}/api/__health' async with httpx.AsyncClient(timeout=0.5) as client: response = await client.get(health_url) @@ -207,29 +167,25 @@ async def handle_sigterm(tg_to_cancel: anyio.abc.TaskGroup) -> None: # type: ig _ = runtime_manager.write_runtime_file() # Start the (potentially short-lived) user coroutine wrapper - tg.start_soon(run_user_coro_wrapper, name='genkit-user-coroutine') - await logger.ainfo('Started Genkit user coroutine') + tg.start_soon(run_user_coro_wrapper) + logger.info('Started Genkit user coroutine') # Block here until the task group is canceled (e.g. Ctrl+C) # or a task raises an unhandled exception. It should not # exit just because the user coroutine finishes. - - except anyio.get_cancelled_exc_class(): - logger.info('Development server task group cancelled (e.g., Ctrl+C).') - raise + await anyio.Event().wait() except Exception: logger.exception('Development server task group error') raise - # After the TaskGroup finishes (error or cancelation). + # After the TaskGroup finishes (normally or by task completion). if user_task_finished_event.is_set(): - await logger.adebug('User coroutine finished before TaskGroup exit.') - if user_result is None: - raise RuntimeError('User coroutine finished without a result.') - return user_result + if user_result is not None: + return user_result + else: + raise RuntimeError('User coroutine finished without a result (likely cancelled).') - await logger.adebug('User coroutine did not finish before TaskGroup exit.') - raise RuntimeError('User coroutine did not finish before TaskGroup exit.') + return None # type: ignore[return-value] return anyio.run(dev_runner) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index bbcafe7117..999610ebf3 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -161,6 +161,10 @@ deprecated_enum_metafactory, ) from genkit.plugins.google_genai.models.utils import PartConverter +from genkit.core.typing import ( + Candidate, + FinishReason, +) from genkit.types import ( Constrained, GenerateRequest, @@ -1219,7 +1223,7 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen # TODO(#4361): Do not move - this method mutates `request` by extracting system # prompts into configuration object - request_cfg = self._genkit_to_googleai_cfg(request=request) + request_cfg = await self._genkit_to_googleai_cfg(request=request) # TTS models require response_modalities: ["AUDIO"] if is_tts_model(model_name): @@ -1229,11 +1233,6 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen # Image models require response_modalities: ["TEXT", "IMAGE"] if is_image_model(model_name): - if request.tools: - raise ValueError( - f'Model {model_name} does not support tools. ' - 'Please remove the tools config or use a model that supports tools.' - ) if not request_cfg: request_cfg = genai_types.GenerateContentConfig() request_cfg.response_modalities = ['TEXT', 'IMAGE'] @@ -1397,17 +1396,67 @@ async def _generate( ) from e span.set_attribute('genkit:output', dump_json(response)) - content = self._contents_from_response(response) + content = await self._contents_from_response(response) # Ensure we always have at least one content item to avoid UI errors if not content: content = [TextPart(text='')] + finish_reason = FinishReason.OTHER + candidates = [] + if response.candidates: + for i, c in enumerate(response.candidates): + c_content = [] + if c.content and c.content.parts: + for j, part in enumerate(c.content.parts): + converted = PartConverter.from_gemini(part=part, ref=str(j)) + if converted: + c_content.append(converted) + + if not c_content: + c_content = [TextPart(text='')] + + c_finish_reason = FinishReason.OTHER + if c.finish_reason: + fr_name = c.finish_reason.name + if fr_name == 'STOP': + c_finish_reason = FinishReason.STOP + elif fr_name == 'MAX_TOKENS': + c_finish_reason = FinishReason.LENGTH + elif fr_name in ['SAFETY', 'RECITATION', 'BLOCKLIST', 'PROHIBITED_CONTENT', 'SPII']: + c_finish_reason = FinishReason.BLOCKED + elif fr_name == 'OTHER': + c_finish_reason = FinishReason.OTHER + + if i == 0: + finish_reason = c_finish_reason + + candidates.append( + Candidate( + index=float(i), + message=Message(role=Role.MODEL, content=c_content), + finish_reason=c_finish_reason, + ) + ) + return GenerateResponse( message=Message( content=content, # type: ignore[arg-type] - content is list[Part] after conversion role=Role.MODEL, - ) + ), + finish_reason=finish_reason, + candidates=candidates, + usage=GenerationUsage( + input_tokens=float(response.usage_metadata.prompt_token_count or 0) + if response.usage_metadata + else None, + output_tokens=float(response.usage_metadata.candidates_token_count or 0) + if response.usage_metadata + else None, + total_tokens=float(response.usage_metadata.total_token_count or 0) + if response.usage_metadata + else None, + ), ) async def _streaming_generate( @@ -1520,7 +1569,7 @@ async def _build_messages( continue content_parts: list[genai_types.Part] = [] for p in msg.content: - converted = PartConverter.to_gemini(p) + converted = await PartConverter.to_gemini(p) if isinstance(converted, list): content_parts.extend(converted) else: @@ -1540,7 +1589,7 @@ async def _build_messages( return request_contents, cache - def _contents_from_response(self, response: genai_types.GenerateContentResponse) -> list: + async def _contents_from_response(self, response: genai_types.GenerateContentResponse) -> list: """Retrieve contents from google-genai response. Args: @@ -1561,15 +1610,21 @@ def _contents_from_response(self, response: genai_types.GenerateContentResponse) # Ensure we always return a list, even if empty return content if content else [] - def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.GenerateContentConfig | None: - """Translate GenerationCommonConfig to Google Ai GenerateContentConfig. + async def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.GenerateContentConfig | None: + """Converts a Genkit GenerateRequest to a Gemini GenerateContentConfig.""" + system_instruction: list[genai.types.Part] = [] - Args: - request: Genkit request. + # 1. System messages + system_messages = list(filter(lambda m: m.role == Role.SYSTEM, request.messages)) + for m in system_messages: + if m.content: + for p in m.content: + converted = await PartConverter.to_gemini(p) + if isinstance(converted, list): + system_instruction.extend(converted) + else: + system_instruction.append(converted) - Returns: - Google Ai request config or None. - """ cfg = None tools = [] @@ -1665,10 +1720,10 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener cfg = genai_types.GenerateContentConfig(**dumped_config) - if request.output: - if not cfg: - cfg = genai_types.GenerateContentConfig() + if not cfg: + cfg = genai_types.GenerateContentConfig() + if request.output: response_mime_type = 'application/json' if request.output.format == 'json' and not request.tools else None cfg.response_mime_type = response_mime_type @@ -1676,31 +1731,12 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener cfg.response_schema = self._convert_schema_property(request.output.schema) if request.tools: - if not cfg: - cfg = genai_types.GenerateContentConfig() - tools.extend(self._get_tools(request)) if tools: - if not cfg: - cfg = genai_types.GenerateContentConfig() cfg.tools = tools - system_messages = list(filter(lambda m: m.role == Role.SYSTEM, request.messages)) - if system_messages: - system_parts = [] - if not cfg: - cfg = genai.types.GenerateContentConfig() - - for msg in system_messages: - for p in msg.content: - converted = PartConverter.to_gemini(p) - if isinstance(converted, list): - system_parts.extend(converted) - else: - system_parts.append(converted) - cfg.system_instruction = genai.types.Content(parts=system_parts) - + cfg.system_instruction = system_instruction if system_instruction else None return cfg def _create_usage_stats(self, request: GenerateRequest, response: GenerateResponse) -> GenerationUsage: diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 304829a323..9381bbb37b 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -19,8 +19,10 @@ import base64 from typing import cast +import httpx from google import genai +from genkit.core.http_client import get_cached_client from genkit.core.typing import DocumentPart, Metadata from genkit.types import ( CustomPart, @@ -62,7 +64,7 @@ class PartConverter: DATA = 'data:' @classmethod - def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[genai.types.Part]: + async def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[genai.types.Part]: """Maps a Genkit Part to a Gemini Part. This method inspects the root type of the Genkit Part and converts it @@ -164,6 +166,22 @@ def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[genai.t ) ) + + if url.startswith('http'): + try: + data, mime_type = await cls._download_image(url) + # If mime type wasn't in headers, fallback to existing or default + mime_type = mime_type or part.root.media.content_type or 'image/jpeg' + return genai.types.Part( + inline_data=genai.types.Blob( + mime_type=mime_type, + data=data, + ) + ) + except Exception: + # Fallback to file_uri if download fails + pass + return genai.types.Part( file_data=genai.types.FileData( mime_type=part.root.media.content_type, @@ -240,7 +258,7 @@ def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part: ref=ref or getattr(part.function_call, 'id', None), # restore slashes name=(part.function_call.name or '').replace('__', '/'), - input=part.function_call.args, + input=part.function_call.args if part.function_call.args is not None else {}, ), metadata=cls._encode_thought_signature(part.thought_signature), ) @@ -305,3 +323,18 @@ def _encode_thought_signature(cls, thought_signature: bytes | None) -> Metadata if thought_signature: return Metadata(root={'thoughtSignature': base64.b64encode(thought_signature).decode('utf-8')}) return None + + @classmethod + async def _download_image(cls, url: str) -> tuple[bytes, str | None]: + """Downloads an image from a URL. + + Args: + url: The URL to download. + + Returns: + A tuple containing the image content (bytes) and its MIME type (str or None). + """ + client = get_cached_client(cache_key='google_genai_media') + response = await client.get(url, timeout=60.0) + response.raise_for_status() + return response.content, response.headers.get('content-type') diff --git a/py/samples/google-genai-hello/src/main.py b/py/samples/google-genai-hello/src/main.py index 412c9edb3f..1d77798caf 100755 --- a/py/samples/google-genai-hello/src/main.py +++ b/py/samples/google-genai-hello/src/main.py @@ -766,8 +766,10 @@ async def tool_calling(input: ToolCallingInput) -> str: async def main() -> None: """Main function - keep alive for Dev UI.""" - # Keep the process alive for Dev UI - _ = await asyncio.Event().wait() + await logger.ainfo('Starting main execution loop') + while True: + await asyncio.sleep(3600) + await logger.ainfo('Exiting main execution loop') if __name__ == '__main__': From a3855cc242a619c09eb773fd91f7479ce1a3e158 Mon Sep 17 00:00:00 2001 From: Mengqin Shen Date: Thu, 5 Feb 2026 22:31:50 -0800 Subject: [PATCH 2/3] fix(py): fix lint errors and tox test errors --- .../genkit/src/genkit/ai/_base_async.py | 72 +++++++++++++++---- .../plugins/google_genai/models/embedder.py | 6 +- .../plugins/google_genai/models/gemini.py | 23 +++--- .../plugins/google_genai/models/utils.py | 4 +- .../google-genai/test/google_plugin_test.py | 5 +- py/pyproject.toml | 2 +- 6 files changed, 77 insertions(+), 35 deletions(-) diff --git a/py/packages/genkit/src/genkit/ai/_base_async.py b/py/packages/genkit/src/genkit/ai/_base_async.py index c232b57ff8..8f849c65b2 100644 --- a/py/packages/genkit/src/genkit/ai/_base_async.py +++ b/py/packages/genkit/src/genkit/ai/_base_async.py @@ -16,8 +16,9 @@ """Asynchronous server gateway interface implementation for Genkit.""" +import signal from collections.abc import Coroutine -from typing import Any, TypeVar, cast +from typing import Any, TypeVar import anyio import httpx @@ -126,7 +127,7 @@ async def dev_runner() -> T: assert spec is not None # Capture spec in local var for nested functions (pyrefly doesn't narrow closures) server_spec: ServerSpec = spec - user_result: T = None # type: ignore[assignment] + user_result: T | None = None user_task_finished_event = anyio.Event() async def run_user_coro_wrapper() -> None: @@ -134,25 +135,64 @@ async def run_user_coro_wrapper() -> None: nonlocal user_result try: user_result = await coro + logger.debug('User coroutine completed successfully.') except Exception as err: + # Log error but don't necessarily stop the server logger.error(f'User coroutine failed: {err}', exc_info=True) + # Store exception? Or let TaskGroup handle it if critical? + # Depending on desired behavior, could raise here to stop everything. + pass # Continue running server for now finally: user_task_finished_event.set() reflection_server = _make_reflection_server(self.registry, server_spec) + # Setup signal handlers for graceful shutdown (parity with JS) + + # Actually, anyio.run handles Ctrl+C (SIGINT) by raising KeyboardInterrupt/CancelledError + # For SIGTERM, we might need to be explicit if we run in a container/process manager. + # JS uses: process.on('SIGTERM', shutdown); process.on('SIGINT', shutdown); + + # Since anyio/asyncio handles SIGINT well, let's add a task to catch SIGTERM + async def handle_sigterm(tg_to_cancel: anyio.abc.TaskGroup) -> None: # type: ignore[name-defined] + with anyio.open_signal_receiver(signal.SIGTERM) as signals: + async for _signum in signals: + logger.info('Received SIGTERM, cancelling tasks...') + tg_to_cancel.cancel_scope.cancel() + return + try: # Use lazy_write=True to prevent race condition where file exists before server is up async with RuntimeManager(server_spec, lazy_write=True) as runtime_manager: + # We use anyio.TaskGroup because it is compatible with + # asyncio's event loop and works with Python 3.10 + # (asyncio.TaskGroup was added in 3.11, and we can switch to + # that when we drop support for 3.10). async with anyio.create_task_group() as tg: # Start reflection server in the background. - tg.start_soon(reflection_server.serve) - logger.info(f'Started Genkit reflection server at {server_spec.url}') + tg.start_soon(reflection_server.serve, name='genkit-reflection-server') + await logger.ainfo(f'Started Genkit reflection server at {server_spec.url}') + + # Start SIGTERM handler + tg.start_soon(handle_sigterm, tg, name='genkit-sigterm-handler') + + # Wait for server to be responsive + # We need to loop and poll the health endpoint or wait for uvicorn to be ready + # Since uvicorn run is blocking (but we are in a task), we can't easily hook into its startup + # unless we use uvicorn's server object directly which we do. + # reflection_server.started is set when uvicorn starts. + + # Simple polling loop - # Wait for the server to be healthy before starting the user task. max_retries = 20 # 2 seconds total roughly for _i in range(max_retries): try: + # TODO(#4334): Use async http client if available to avoid blocking loop? + # But we are in dev mode, so maybe okay. + # Actually we should use anyio.to_thread to avoid blocking event loop + # or assume standard lib urllib is fast enough for localhost. + + # Use httpx async client to avoid blocking the event loop health_url = f'{server_spec.url}/api/__health' async with httpx.AsyncClient(timeout=0.5) as client: response = await client.get(health_url) @@ -167,25 +207,29 @@ async def run_user_coro_wrapper() -> None: _ = runtime_manager.write_runtime_file() # Start the (potentially short-lived) user coroutine wrapper - tg.start_soon(run_user_coro_wrapper) - logger.info('Started Genkit user coroutine') + tg.start_soon(run_user_coro_wrapper, name='genkit-user-coroutine') + await logger.ainfo('Started Genkit user coroutine') # Block here until the task group is canceled (e.g. Ctrl+C) # or a task raises an unhandled exception. It should not # exit just because the user coroutine finishes. - await anyio.Event().wait() + + except anyio.get_cancelled_exc_class(): + logger.info('Development server task group cancelled (e.g., Ctrl+C).') + raise except Exception: logger.exception('Development server task group error') raise - # After the TaskGroup finishes (normally or by task completion). + # After the TaskGroup finishes (error or cancelation). if user_task_finished_event.is_set(): - if user_result is not None: - return user_result - else: - raise RuntimeError('User coroutine finished without a result (likely cancelled).') + await logger.adebug('User coroutine finished before TaskGroup exit.') + if user_result is None: + raise RuntimeError('User coroutine finished without a result.') + return user_result - return None # type: ignore[return-value] + await logger.adebug('User coroutine did not finish before TaskGroup exit.') + raise RuntimeError('User coroutine did not finish before TaskGroup exit.') return anyio.run(dev_runner) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py index 32a8a670b7..b5addb7847 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py @@ -94,7 +94,7 @@ async def generate(self, request: EmbedRequest) -> EmbedResponse: Returns: EmbedResponse """ - contents = self._build_contents(request) + contents = await self._build_contents(request) config = self._genkit_to_googleai_cfg(request) response = await self._client.aio.models.embed_content( model=self._version, @@ -105,7 +105,7 @@ async def generate(self, request: EmbedRequest) -> EmbedResponse: embeddings = [Embedding(embedding=em.values or []) for em in (response.embeddings or [])] return EmbedResponse(embeddings=embeddings) - def _build_contents(self, request: EmbedRequest) -> list[genai.types.Content]: + async def _build_contents(self, request: EmbedRequest) -> list[genai.types.Content]: """Build google-genai request contents from Genkit request. Args: @@ -118,7 +118,7 @@ def _build_contents(self, request: EmbedRequest) -> list[genai.types.Content]: for doc in request.input: content_parts: list[genai.types.Part] = [] for p in doc.content: - converted = PartConverter.to_gemini(p) + converted = await PartConverter.to_gemini(p) if isinstance(converted, list): content_parts.extend(converted) else: diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 999610ebf3..b48f25c4d9 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -157,14 +157,14 @@ from genkit.codec import dump_dict, dump_json from genkit.core.error import GenkitError, StatusName from genkit.core.tracing import tracer -from genkit.lang.deprecations import ( - deprecated_enum_metafactory, -) -from genkit.plugins.google_genai.models.utils import PartConverter from genkit.core.typing import ( Candidate, FinishReason, ) +from genkit.lang.deprecations import ( + deprecated_enum_metafactory, +) +from genkit.plugins.google_genai.models.utils import PartConverter from genkit.types import ( Constrained, GenerateRequest, @@ -174,6 +174,7 @@ GenerationUsage, Message, ModelInfo, + Part, Role, Stage, Supports, @@ -1400,7 +1401,7 @@ async def _generate( # Ensure we always have at least one content item to avoid UI errors if not content: - content = [TextPart(text='')] + content = [Part(root=TextPart(text=''))] finish_reason = FinishReason.OTHER candidates = [] @@ -1414,7 +1415,7 @@ async def _generate( c_content.append(converted) if not c_content: - c_content = [TextPart(text='')] + c_content = [Part(root=TextPart(text=''))] c_finish_reason = FinishReason.OTHER if c.finish_reason: @@ -1441,7 +1442,7 @@ async def _generate( return GenerateResponse( message=Message( - content=content, # type: ignore[arg-type] - content is list[Part] after conversion + content=content, role=Role.MODEL, ), finish_reason=finish_reason, @@ -1453,9 +1454,7 @@ async def _generate( output_tokens=float(response.usage_metadata.candidates_token_count or 0) if response.usage_metadata else None, - total_tokens=float(response.usage_metadata.total_token_count or 0) - if response.usage_metadata - else None, + total_tokens=float(response.usage_metadata.total_token_count or 0) if response.usage_metadata else None, ), ) @@ -1513,7 +1512,7 @@ async def _streaming_generate( ) from e accumulated_content = [] async for response_chunk in await generator: - content = self._contents_from_response(response_chunk) + content = await self._contents_from_response(response_chunk) if content: # Only process if we have content accumulated_content.extend(content) ctx.send_chunk( @@ -1736,7 +1735,7 @@ async def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types if tools: cfg.tools = tools - cfg.system_instruction = system_instruction if system_instruction else None + cfg.system_instruction = genai_types.Content(parts=system_instruction) if system_instruction else None return cfg def _create_usage_stats(self, request: GenerateRequest, response: GenerateResponse) -> GenerationUsage: diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 9381bbb37b..eefa1cb3b4 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -19,7 +19,6 @@ import base64 from typing import cast -import httpx from google import genai from genkit.core.http_client import get_cached_client @@ -166,7 +165,6 @@ async def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[g ) ) - if url.startswith('http'): try: data, mime_type = await cls._download_image(url) @@ -178,7 +176,7 @@ async def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[g data=data, ) ) - except Exception: + except Exception: # noqa: S110 - intentionally silent, fallback to file_uri # Fallback to file_uri if download fails pass diff --git a/py/plugins/google-genai/test/google_plugin_test.py b/py/plugins/google-genai/test/google_plugin_test.py index e6db3feee6..88cd7ad738 100644 --- a/py/plugins/google-genai/test/google_plugin_test.py +++ b/py/plugins/google-genai/test/google_plugin_test.py @@ -726,7 +726,8 @@ def test_config_schema_extra_fields() -> None: assert config.model_dump()['new_experimental_param'] == 'test' -def test_system_prompt_handling() -> None: +@pytest.mark.asyncio +async def test_system_prompt_handling() -> None: """Test that system prompts are correctly extracted to config.""" mock_client = MagicMock(spec=genai.Client) model = GeminiModel(version='gemini-1.5-flash', client=mock_client) @@ -739,7 +740,7 @@ def test_system_prompt_handling() -> None: config=None, ) - cfg = model._genkit_to_googleai_cfg(request) + cfg = await model._genkit_to_googleai_cfg(request) assert cfg is not None assert cfg.system_instruction is not None diff --git a/py/pyproject.toml b/py/pyproject.toml index 6d718aef8d..cbd6260b0d 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -75,7 +75,7 @@ dev = [ lint = [ "bandit>=1.7.0", "deptry>=0.22.0", - "litestar>=2.0.0", # For web/typing.py type resolution + "litestar>=2.0.0", # For web/typing.py type resolution "mypy>=1.14.0", "pip-audit>=2.7.0", "pypdf>=6.6.2", From 30222744198565caa2e1a748d8346f93bc5db41a Mon Sep 17 00:00:00 2001 From: Mengqin Shen Date: Thu, 5 Feb 2026 23:04:51 -0800 Subject: [PATCH 3/3] fix(py): fix nox test errors --- .../plugins/google_genai/models/gemini.py | 88 +++++++++---------- 1 file changed, 40 insertions(+), 48 deletions(-) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index b48f25c4d9..cee58628a3 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1489,7 +1489,7 @@ async def _streaming_generate( ) client = client or self._client try: - generator = client.aio.models.generate_content_stream( + generator = await client.aio.models.generate_content_stream( model=model_name, contents=cast(genai_types.ContentListUnion, request_contents), config=request_cfg, @@ -1511,7 +1511,7 @@ async def _streaming_generate( cause=e, ) from e accumulated_content = [] - async for response_chunk in await generator: + async for response_chunk in generator: content = await self._contents_from_response(response_chunk) if content: # Only process if we have content accumulated_content.extend(content) @@ -1632,13 +1632,11 @@ async def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types if isinstance(request_config, GeminiConfigSchema): cfg = request_config elif isinstance(request_config, GenerationCommonConfig): - cfg = genai_types.GenerateContentConfig( - max_output_tokens=request_config.max_output_tokens, - top_k=request_config.top_k, - top_p=request_config.top_p, - temperature=request_config.temperature, - stop_sequences=request_config.stop_sequences, - ) + dumped = request_config.model_dump(exclude_none=True) + if dumped: + cfg = genai_types.GenerateContentConfig(**dumped) + else: + cfg = None elif isinstance(request_config, dict): if 'image_config' in request_config: cfg = GeminiImageConfigSchema(**request_config) @@ -1698,45 +1696,39 @@ async def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types if key in dumped_config: del dumped_config[key] - if 'image_config' in dumped_config and isinstance(dumped_config['image_config'], dict): - valid_image_keys = { - 'aspect_ratio', - 'image_size', - 'person_generation', - 'output_mime_type', - 'output_compression_quality', - } - dumped_config['image_config'] = { - k: v for k, v in dumped_config['image_config'].items() if k in valid_image_keys - } - - # Check if image_config is actually supported by the installed SDK version - if ( - 'image_config' in dumped_config - and 'image_config' not in genai_types.GenerateContentConfig.model_fields - ): - del dumped_config['image_config'] - - cfg = genai_types.GenerateContentConfig(**dumped_config) - - if not cfg: - cfg = genai_types.GenerateContentConfig() - - if request.output: - response_mime_type = 'application/json' if request.output.format == 'json' and not request.tools else None - cfg.response_mime_type = response_mime_type - - if request.output.schema and request.output.constrained: - cfg.response_schema = self._convert_schema_property(request.output.schema) - - if request.tools: - tools.extend(self._get_tools(request)) - - if tools: - cfg.tools = tools - - cfg.system_instruction = genai_types.Content(parts=system_instruction) if system_instruction else None - return cfg + # Check for SDK support of newer fields + for key in ['image_config', 'thinking_config', 'response_modalities']: + if key in dumped_config and key not in genai_types.GenerateContentConfig.model_fields: + del dumped_config[key] + + if dumped_config: + cfg = genai_types.GenerateContentConfig(**dumped_config) + else: + cfg = None + + # Tools from top-level field and config-level fields + tools.extend(self._get_tools(request)) + + if cfg is not None or tools or system_instruction or request.output: + if cfg is None: + cfg = genai_types.GenerateContentConfig() + + if request.output: + response_mime_type = ( + 'application/json' if request.output.format == 'json' and not request.tools else None + ) + cfg.response_mime_type = response_mime_type + + if request.output.schema and request.output.constrained: + cfg.response_schema = self._convert_schema_property(request.output.schema) + + if tools: + cfg.tools = tools + + cfg.system_instruction = genai_types.Content(parts=system_instruction) if system_instruction else None + return cfg + + return None def _create_usage_stats(self, request: GenerateRequest, response: GenerateResponse) -> GenerationUsage: """Create usage statistics.