Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 30 additions & 25 deletions src/google/adk/models/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
from typing import TYPE_CHECKING
from typing import Union

from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropicVertex
from anthropic import AnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
Expand All @@ -42,7 +41,7 @@
if TYPE_CHECKING:
from .llm_request import LlmRequest

__all__ = ["AnthropicLlm", "Claude"]
__all__ = ["Claude"]

logger = logging.getLogger("google_adk." + __name__)

Expand Down Expand Up @@ -77,13 +76,23 @@ def _is_image_part(part: types.Part) -> bool:
)


def _is_pdf_part(part: types.Part) -> bool:
"""Check if the part contains PDF data."""
return (
part.inline_data
and part.inline_data.mime_type
and part.inline_data.mime_type == "application/pdf"
)


def part_to_message_block(
part: types.Part,
) -> Union[
anthropic_types.TextBlockParam,
anthropic_types.ImageBlockParam,
anthropic_types.ToolUseBlockParam,
anthropic_types.ToolResultBlockParam,
anthropic_types.DocumentBlockParam, # For PDF document blocks
]:
if part.text:
return anthropic_types.TextBlockParam(text=part.text, type="text")
Expand Down Expand Up @@ -135,6 +144,18 @@ def part_to_message_block(
type="base64", media_type=part.inline_data.mime_type, data=data
),
)
elif _is_pdf_part(part):
# Handle PDF documents - Anthropic supports PDFs as document blocks
# PDFs are encoded as base64 and sent with document type
data = base64.b64encode(part.inline_data.data).decode()
return anthropic_types.DocumentBlockParam(
type="document",
source={
"type": "base64",
"media_type": part.inline_data.mime_type,
"data": data,
},
)
elif part.executable_code:
return anthropic_types.TextBlockParam(
type="text",
Expand Down Expand Up @@ -265,15 +286,15 @@ def function_declaration_to_tool_param(
)


class AnthropicLlm(BaseLlm):
"""Integration with Claude models via the Anthropic API.
class Claude(BaseLlm):
"""Integration with Claude models served from Vertex AI.

Attributes:
model: The name of the Claude model.
max_tokens: The maximum number of tokens to generate.
"""

model: str = "claude-sonnet-4-20250514"
model: str = "claude-3-5-sonnet-v2@20241022"
max_tokens: int = 8192

@classmethod
Expand Down Expand Up @@ -305,7 +326,7 @@ async def generate_content_async(
else NOT_GIVEN
)
# TODO(b/421255973): Enable streaming for anthropic models.
message = await self._anthropic_client.messages.create(
message = self._anthropic_client.messages.create(
model=llm_request.model,
system=llm_request.config.system_instruction,
messages=messages,
Expand All @@ -316,23 +337,7 @@ async def generate_content_async(
yield message_to_generate_content_response(message)

@cached_property
def _anthropic_client(self) -> AsyncAnthropic:
return AsyncAnthropic()


class Claude(AnthropicLlm):
"""Integration with Claude models served from Vertex AI.

Attributes:
model: The name of the Claude model.
max_tokens: The maximum number of tokens to generate.
"""

model: str = "claude-3-5-sonnet-v2@20241022"

@cached_property
@override
def _anthropic_client(self) -> AsyncAnthropicVertex:
def _anthropic_client(self) -> AnthropicVertex:
if (
"GOOGLE_CLOUD_PROJECT" not in os.environ
or "GOOGLE_CLOUD_LOCATION" not in os.environ
Expand All @@ -342,7 +347,7 @@ def _anthropic_client(self) -> AsyncAnthropicVertex:
" Anthropic on Vertex."
)

return AsyncAnthropicVertex(
return AnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
)
65 changes: 33 additions & 32 deletions tests/unittests/models/test_anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from anthropic import types as anthropic_types
from google.adk import version as adk_version
from google.adk.models import anthropic_llm
from google.adk.models.anthropic_llm import AnthropicLlm
from google.adk.models.anthropic_llm import Claude
from google.adk.models.anthropic_llm import content_to_message_param
from google.adk.models.anthropic_llm import function_declaration_to_tool_param
Expand Down Expand Up @@ -360,37 +359,6 @@ async def mock_coro():
assert responses[0].content.parts[0].text == "Hello, how can I help you?"


@pytest.mark.asyncio
async def test_anthropic_llm_generate_content_async(
llm_request, generate_content_response, generate_llm_response
):
anthropic_llm_instance = AnthropicLlm(model="claude-sonnet-4-20250514")
with mock.patch.object(
anthropic_llm_instance, "_anthropic_client"
) as mock_client:
with mock.patch.object(
anthropic_llm,
"message_to_generate_content_response",
return_value=generate_llm_response,
):
# Create a mock coroutine that returns the generate_content_response.
async def mock_coro():
return generate_content_response

# Assign the coroutine to the mocked method
mock_client.messages.create.return_value = mock_coro()

responses = [
resp
async for resp in anthropic_llm_instance.generate_content_async(
llm_request, stream=False
)
]
assert len(responses) == 1
assert isinstance(responses[0], LlmResponse)
assert responses[0].content.parts[0].text == "Hello, how can I help you?"


@pytest.mark.asyncio
async def test_generate_content_async_with_max_tokens(
llm_request, generate_content_response, generate_llm_response
Expand Down Expand Up @@ -497,6 +465,39 @@ def test_part_to_message_block_with_multiple_content_items():
assert result["content"] == "First part\nSecond part"


def test_part_to_message_block_with_pdf():
"""Test that part_to_message_block handles PDF documents."""
import base64

from anthropic import types as anthropic_types
from google.adk.models.anthropic_llm import part_to_message_block

# Create a PDF part with inline data
pdf_data = (
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\nxref\n0"
b" 1\ntrailer\n<<\n/Root 1 0 R\n>>\n%%EOF"
)
pdf_part = types.Part(
inline_data=types.Blob(
mime_type="application/pdf",
data=pdf_data,
)
)

result = part_to_message_block(pdf_part)

# PDF should be returned as DocumentBlockParam (TypedDict, which is a dict)
assert isinstance(result, dict)
# Verify it matches DocumentBlockParam structure
assert result["type"] == "document"
assert "source" in result
assert result["source"]["type"] == "base64"
assert result["source"]["media_type"] == "application/pdf"
# Verify the data is base64 encoded and can be decoded back
decoded_data = base64.b64decode(result["source"]["data"])
assert decoded_data == pdf_data


content_to_message_param_test_cases = [
(
"user_role_with_text_and_image",
Expand Down