Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def generate(self, request: EmbedRequest) -> EmbedResponse:
Returns:
EmbedResponse
"""
contents = self._build_contents(request)
contents = await self._build_contents(request)
config = self._genkit_to_googleai_cfg(request)
response = await self._client.aio.models.embed_content(
model=self._version,
Expand All @@ -105,7 +105,7 @@ async def generate(self, request: EmbedRequest) -> EmbedResponse:
embeddings = [Embedding(embedding=em.values or []) for em in (response.embeddings or [])]
return EmbedResponse(embeddings=embeddings)

def _build_contents(self, request: EmbedRequest) -> list[genai.types.Content]:
async def _build_contents(self, request: EmbedRequest) -> list[genai.types.Content]:
"""Build google-genai request contents from Genkit request.

Args:
Expand All @@ -118,7 +118,7 @@ def _build_contents(self, request: EmbedRequest) -> list[genai.types.Content]:
for doc in request.input:
content_parts: list[genai.types.Part] = []
for p in doc.content:
converted = PartConverter.to_gemini(p)
converted = await PartConverter.to_gemini(p)
if isinstance(converted, list):
content_parts.extend(converted)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@
from genkit.codec import dump_dict, dump_json
from genkit.core.error import GenkitError, StatusName
from genkit.core.tracing import tracer
from genkit.core.typing import (
Candidate,
FinishReason,
)
from genkit.lang.deprecations import (
deprecated_enum_metafactory,
)
Expand All @@ -170,6 +174,7 @@
GenerationUsage,
Message,
ModelInfo,
Part,
Role,
Stage,
Supports,
Expand Down Expand Up @@ -1219,7 +1224,7 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen

# TODO(#4361): Do not move - this method mutates `request` by extracting system
# prompts into configuration object
request_cfg = self._genkit_to_googleai_cfg(request=request)
request_cfg = await self._genkit_to_googleai_cfg(request=request)

# TTS models require response_modalities: ["AUDIO"]
if is_tts_model(model_name):
Expand All @@ -1229,11 +1234,6 @@ async def generate(self, request: GenerateRequest, ctx: ActionRunContext) -> Gen

# Image models require response_modalities: ["TEXT", "IMAGE"]
if is_image_model(model_name):
if request.tools:
raise ValueError(
f'Model {model_name} does not support tools. '
'Please remove the tools config or use a model that supports tools.'
)
if not request_cfg:
request_cfg = genai_types.GenerateContentConfig()
request_cfg.response_modalities = ['TEXT', 'IMAGE']
Expand Down Expand Up @@ -1397,17 +1397,65 @@ async def _generate(
) from e
span.set_attribute('genkit:output', dump_json(response))

content = self._contents_from_response(response)
content = await self._contents_from_response(response)

# Ensure we always have at least one content item to avoid UI errors
if not content:
content = [TextPart(text='')]
content = [Part(root=TextPart(text=''))]

finish_reason = FinishReason.OTHER
candidates = []
if response.candidates:
for i, c in enumerate(response.candidates):
c_content = []
if c.content and c.content.parts:
for j, part in enumerate(c.content.parts):
converted = PartConverter.from_gemini(part=part, ref=str(j))
if converted:
c_content.append(converted)

if not c_content:
c_content = [Part(root=TextPart(text=''))]

c_finish_reason = FinishReason.OTHER
if c.finish_reason:
fr_name = c.finish_reason.name
if fr_name == 'STOP':
c_finish_reason = FinishReason.STOP
elif fr_name == 'MAX_TOKENS':
c_finish_reason = FinishReason.LENGTH
elif fr_name in ['SAFETY', 'RECITATION', 'BLOCKLIST', 'PROHIBITED_CONTENT', 'SPII']:
c_finish_reason = FinishReason.BLOCKED
elif fr_name == 'OTHER':
c_finish_reason = FinishReason.OTHER

if i == 0:
finish_reason = c_finish_reason

candidates.append(
Candidate(
index=float(i),
message=Message(role=Role.MODEL, content=c_content),
finish_reason=c_finish_reason,
)
)

return GenerateResponse(
message=Message(
content=content, # type: ignore[arg-type] - content is list[Part] after conversion
content=content,
role=Role.MODEL,
)
),
finish_reason=finish_reason,
candidates=candidates,
usage=GenerationUsage(
input_tokens=float(response.usage_metadata.prompt_token_count or 0)
if response.usage_metadata
else None,
output_tokens=float(response.usage_metadata.candidates_token_count or 0)
if response.usage_metadata
else None,
total_tokens=float(response.usage_metadata.total_token_count or 0) if response.usage_metadata else None,
),
)

async def _streaming_generate(
Expand Down Expand Up @@ -1441,7 +1489,7 @@ async def _streaming_generate(
)
client = client or self._client
try:
generator = client.aio.models.generate_content_stream(
generator = await client.aio.models.generate_content_stream(
model=model_name,
contents=cast(genai_types.ContentListUnion, request_contents),
config=request_cfg,
Expand All @@ -1463,8 +1511,8 @@ async def _streaming_generate(
cause=e,
) from e
accumulated_content = []
async for response_chunk in await generator:
content = self._contents_from_response(response_chunk)
async for response_chunk in generator:
content = await self._contents_from_response(response_chunk)
if content: # Only process if we have content
accumulated_content.extend(content)
ctx.send_chunk(
Expand Down Expand Up @@ -1520,7 +1568,7 @@ async def _build_messages(
continue
content_parts: list[genai_types.Part] = []
for p in msg.content:
converted = PartConverter.to_gemini(p)
converted = await PartConverter.to_gemini(p)
if isinstance(converted, list):
content_parts.extend(converted)
else:
Expand All @@ -1540,7 +1588,7 @@ async def _build_messages(

return request_contents, cache

def _contents_from_response(self, response: genai_types.GenerateContentResponse) -> list:
async def _contents_from_response(self, response: genai_types.GenerateContentResponse) -> list:
"""Retrieve contents from google-genai response.

Args:
Expand All @@ -1561,15 +1609,21 @@ def _contents_from_response(self, response: genai_types.GenerateContentResponse)
# Ensure we always return a list, even if empty
return content if content else []

def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.GenerateContentConfig | None:
"""Translate GenerationCommonConfig to Google Ai GenerateContentConfig.
async def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.GenerateContentConfig | None:
"""Converts a Genkit GenerateRequest to a Gemini GenerateContentConfig."""
system_instruction: list[genai.types.Part] = []

Args:
request: Genkit request.
# 1. System messages
system_messages = list(filter(lambda m: m.role == Role.SYSTEM, request.messages))
for m in system_messages:
if m.content:
for p in m.content:
converted = await PartConverter.to_gemini(p)
if isinstance(converted, list):
system_instruction.extend(converted)
else:
system_instruction.append(converted)

Returns:
Google Ai request config or None.
"""
cfg = None
tools = []

Expand All @@ -1578,13 +1632,11 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener
if isinstance(request_config, GeminiConfigSchema):
cfg = request_config
elif isinstance(request_config, GenerationCommonConfig):
cfg = genai_types.GenerateContentConfig(
max_output_tokens=request_config.max_output_tokens,
top_k=request_config.top_k,
top_p=request_config.top_p,
temperature=request_config.temperature,
stop_sequences=request_config.stop_sequences,
)
dumped = request_config.model_dump(exclude_none=True)
if dumped:
cfg = genai_types.GenerateContentConfig(**dumped)
else:
cfg = None
elif isinstance(request_config, dict):
if 'image_config' in request_config:
cfg = GeminiImageConfigSchema(**request_config)
Expand Down Expand Up @@ -1644,64 +1696,39 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener
if key in dumped_config:
del dumped_config[key]

if 'image_config' in dumped_config and isinstance(dumped_config['image_config'], dict):
valid_image_keys = {
'aspect_ratio',
'image_size',
'person_generation',
'output_mime_type',
'output_compression_quality',
}
dumped_config['image_config'] = {
k: v for k, v in dumped_config['image_config'].items() if k in valid_image_keys
}

# Check if image_config is actually supported by the installed SDK version
if (
'image_config' in dumped_config
and 'image_config' not in genai_types.GenerateContentConfig.model_fields
):
del dumped_config['image_config']

cfg = genai_types.GenerateContentConfig(**dumped_config)

if request.output:
if not cfg:
cfg = genai_types.GenerateContentConfig()
# Check for SDK support of newer fields
for key in ['image_config', 'thinking_config', 'response_modalities']:
if key in dumped_config and key not in genai_types.GenerateContentConfig.model_fields:
del dumped_config[key]

response_mime_type = 'application/json' if request.output.format == 'json' and not request.tools else None
cfg.response_mime_type = response_mime_type
if dumped_config:
cfg = genai_types.GenerateContentConfig(**dumped_config)
else:
cfg = None

if request.output.schema and request.output.constrained:
cfg.response_schema = self._convert_schema_property(request.output.schema)
# Tools from top-level field and config-level fields
tools.extend(self._get_tools(request))

if request.tools:
if not cfg:
if cfg is not None or tools or system_instruction or request.output:
if cfg is None:
cfg = genai_types.GenerateContentConfig()

tools.extend(self._get_tools(request))
if request.output:
response_mime_type = (
'application/json' if request.output.format == 'json' and not request.tools else None
)
cfg.response_mime_type = response_mime_type

if tools:
if not cfg:
cfg = genai_types.GenerateContentConfig()
cfg.tools = tools
if request.output.schema and request.output.constrained:
cfg.response_schema = self._convert_schema_property(request.output.schema)

system_messages = list(filter(lambda m: m.role == Role.SYSTEM, request.messages))
if system_messages:
system_parts = []
if not cfg:
cfg = genai.types.GenerateContentConfig()

for msg in system_messages:
for p in msg.content:
converted = PartConverter.to_gemini(p)
if isinstance(converted, list):
system_parts.extend(converted)
else:
system_parts.append(converted)
cfg.system_instruction = genai.types.Content(parts=system_parts)
if tools:
cfg.tools = tools

cfg.system_instruction = genai_types.Content(parts=system_instruction) if system_instruction else None
return cfg

return cfg
return None

def _create_usage_stats(self, request: GenerateRequest, response: GenerateResponse) -> GenerationUsage:
"""Create usage statistics.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from google import genai

from genkit.core.http_client import get_cached_client
from genkit.core.typing import DocumentPart, Metadata
from genkit.types import (
CustomPart,
Expand Down Expand Up @@ -62,7 +63,7 @@ class PartConverter:
DATA = 'data:'

@classmethod
def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[genai.types.Part]:
async def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[genai.types.Part]:
"""Maps a Genkit Part to a Gemini Part.

This method inspects the root type of the Genkit Part and converts it
Expand Down Expand Up @@ -164,6 +165,21 @@ def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[genai.t
)
)

if url.startswith('http'):
try:
data, mime_type = await cls._download_image(url)
# If mime type wasn't in headers, fallback to existing or default
mime_type = mime_type or part.root.media.content_type or 'image/jpeg'
return genai.types.Part(
inline_data=genai.types.Blob(
mime_type=mime_type,
data=data,
)
)
except Exception: # noqa: S110 - intentionally silent, fallback to file_uri
# Fallback to file_uri if download fails
pass

return genai.types.Part(
file_data=genai.types.FileData(
mime_type=part.root.media.content_type,
Expand Down Expand Up @@ -240,7 +256,7 @@ def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part:
ref=ref or getattr(part.function_call, 'id', None),
# restore slashes
name=(part.function_call.name or '').replace('__', '/'),
input=part.function_call.args,
input=part.function_call.args if part.function_call.args is not None else {},
),
metadata=cls._encode_thought_signature(part.thought_signature),
)
Expand Down Expand Up @@ -305,3 +321,18 @@ def _encode_thought_signature(cls, thought_signature: bytes | None) -> Metadata
if thought_signature:
return Metadata(root={'thoughtSignature': base64.b64encode(thought_signature).decode('utf-8')})
return None

@classmethod
async def _download_image(cls, url: str) -> tuple[bytes, str | None]:
    """Downloads an image from a URL.

    Args:
        url: The URL to download.

    Returns:
        A tuple of ``(content_bytes, mime_type)``. ``mime_type`` is the bare
        media type from the response's Content-Type header with any
        parameters (e.g. ``; charset=utf-8``) stripped, or ``None`` when the
        header is absent/empty. Stripping matters because the caller passes
        this value directly as a ``Blob.mime_type``, which expects a plain
        media type, not a full header value.

    Raises:
        Exception: ``raise_for_status()`` raises on 4xx/5xx responses
            (presumably ``httpx.HTTPStatusError`` — the client comes from
            ``get_cached_client``; confirm its type).
    """
    client = get_cached_client(cache_key='google_genai_media')
    response = await client.get(url, timeout=60.0)
    response.raise_for_status()
    content_type = response.headers.get('content-type')
    if content_type:
        # 'image/png; charset=binary' -> 'image/png'; an all-parameter or
        # blank header degrades to None so callers can apply their fallback.
        content_type = content_type.split(';', 1)[0].strip() or None
    return response.content, content_type
5 changes: 3 additions & 2 deletions py/plugins/google-genai/test/google_plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,8 @@ def test_config_schema_extra_fields() -> None:
assert config.model_dump()['new_experimental_param'] == 'test'


def test_system_prompt_handling() -> None:
@pytest.mark.asyncio
async def test_system_prompt_handling() -> None:
"""Test that system prompts are correctly extracted to config."""
mock_client = MagicMock(spec=genai.Client)
model = GeminiModel(version='gemini-1.5-flash', client=mock_client)
Expand All @@ -739,7 +740,7 @@ def test_system_prompt_handling() -> None:
config=None,
)

cfg = model._genkit_to_googleai_cfg(request)
cfg = await model._genkit_to_googleai_cfg(request)

assert cfg is not None
assert cfg.system_instruction is not None
Expand Down
Loading
Loading