Skip to content

Commit 0a07a66

Browse files
GWealecopybara-github
authored andcommitted
feat: Change service creation and add app name mapping for sessions
This change refactors how session, memory, and artifact services are created in the fast_api server, using the shared service_factory. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 839997110
1 parent e9182e5 commit 0a07a66

File tree

7 files changed

+121
-60
lines changed

7 files changed

+121
-60
lines changed

src/google/adk/cli/cli.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,22 @@ async def run_cli(
159159
load_services_module(str(agent_root))
160160
user_id = 'test_user'
161161

162+
agents_dir = str(agent_parent_path)
163+
agent_loader = AgentLoader(agents_dir=agents_dir)
164+
agent_or_app = agent_loader.load_agent(agent_folder_name)
165+
session_app_name = (
166+
agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name
167+
)
168+
app_name_to_dir = None
169+
if isinstance(agent_or_app, App) and agent_or_app.name != agent_folder_name:
170+
app_name_to_dir = {agent_or_app.name: agent_folder_name}
171+
162172
# Create session and artifact services using factory functions
163173
# Sessions persist under <agents_dir>/<agent>/.adk/session.db by default.
164174
session_service = create_session_service_from_options(
165175
base_dir=agent_parent_path,
166176
session_service_uri=session_service_uri,
177+
app_name_to_dir=app_name_to_dir,
167178
)
168179

169180
artifact_service = create_artifact_service_from_options(
@@ -172,13 +183,6 @@ async def run_cli(
172183
)
173184

174185
credential_service = InMemoryCredentialService()
175-
agents_dir = str(agent_parent_path)
176-
agent_or_app = AgentLoader(agents_dir=agents_dir).load_agent(
177-
agent_folder_name
178-
)
179-
session_app_name = (
180-
agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name
181-
)
182186
if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'):
183187
envs.load_dotenv_for_agent(agent_folder_name, agents_dir)
184188

src/google/adk/cli/fast_api.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,24 @@
3535
from starlette.types import Lifespan
3636
from watchdog.observers import Observer
3737

38-
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
3938
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
4039
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
4140
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
42-
from ..memory.in_memory_memory_service import InMemoryMemoryService
4341
from ..runners import Runner
44-
from ..sessions.in_memory_session_service import InMemorySessionService
4542
from .adk_web_server import AdkWebServer
46-
from .service_registry import get_service_registry
4743
from .service_registry import load_services_module
4844
from .utils import envs
4945
from .utils import evals
5046
from .utils.agent_change_handler import AgentChangeEventHandler
5147
from .utils.agent_loader import AgentLoader
48+
from .utils.service_factory import create_artifact_service_from_options
49+
from .utils.service_factory import create_memory_service_from_options
50+
from .utils.service_factory import create_session_service_from_options
5251

5352
logger = logging.getLogger("google_adk." + __name__)
5453

5554
_LAZY_SERVICE_IMPORTS: dict[str, str] = {
5655
"AgentLoader": ".utils.agent_loader",
57-
"InMemoryArtifactService": "..artifacts.in_memory_artifact_service",
58-
"InMemoryMemoryService": "..memory.in_memory_memory_service",
59-
"InMemorySessionService": "..sessions.in_memory_session_service",
6056
"LocalEvalSetResultsManager": "..evaluation.local_eval_set_results_manager",
6157
"LocalEvalSetsManager": "..evaluation.local_eval_sets_manager",
6258
}
@@ -112,48 +108,31 @@ def get_fast_api_app(
112108
# Load services.py from agents_dir for custom service registration.
113109
load_services_module(agents_dir)
114110

115-
service_registry = get_service_registry()
116-
117111
# Build the Memory service
118-
if memory_service_uri:
119-
memory_service = service_registry.create_memory_service(
120-
memory_service_uri, agents_dir=agents_dir
112+
try:
113+
memory_service = create_memory_service_from_options(
114+
base_dir=agents_dir,
115+
memory_service_uri=memory_service_uri,
121116
)
122-
if not memory_service:
123-
raise click.ClickException(
124-
"Unsupported memory service URI: %s" % memory_service_uri
125-
)
126-
else:
127-
memory_service = InMemoryMemoryService()
117+
except ValueError as exc:
118+
raise click.ClickException(str(exc)) from exc
128119

129120
# Build the Session service
130-
if session_service_uri:
131-
session_kwargs = session_db_kwargs or {}
132-
session_service = service_registry.create_session_service(
133-
session_service_uri, agents_dir=agents_dir, **session_kwargs
134-
)
135-
if not session_service:
136-
# Fallback to DatabaseSessionService if the service registry doesn't
137-
# support the session service URI scheme.
138-
from ..sessions.database_session_service import DatabaseSessionService
139-
140-
session_service = DatabaseSessionService(
141-
db_url=session_service_uri, **session_kwargs
142-
)
143-
else:
144-
session_service = InMemorySessionService()
121+
session_service = create_session_service_from_options(
122+
base_dir=agents_dir,
123+
session_service_uri=session_service_uri,
124+
session_db_kwargs=session_db_kwargs,
125+
)
145126

146127
# Build the Artifact service
147-
if artifact_service_uri:
148-
artifact_service = service_registry.create_artifact_service(
149-
artifact_service_uri, agents_dir=agents_dir
128+
try:
129+
artifact_service = create_artifact_service_from_options(
130+
base_dir=agents_dir,
131+
artifact_service_uri=artifact_service_uri,
132+
strict_uri=True,
150133
)
151-
if not artifact_service:
152-
raise click.ClickException(
153-
"Unsupported artifact service URI: %s" % artifact_service_uri
154-
)
155-
else:
156-
artifact_service = InMemoryArtifactService()
134+
except ValueError as exc:
135+
raise click.ClickException(str(exc)) from exc
157136

158137
# Build the Credential service
159138
credential_service = InMemoryCredentialService()

src/google/adk/cli/utils/local_storage.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import asyncio
1818
import logging
1919
from pathlib import Path
20+
from typing import Mapping
2021
from typing import Optional
2122

2223
from typing_extensions import override
@@ -61,6 +62,7 @@ def create_local_session_service(
6162
*,
6263
base_dir: Path | str,
6364
per_agent: bool = False,
65+
app_name_to_dir: Optional[Mapping[str, str]] = None,
6466
) -> BaseSessionService:
6567
"""Creates a local SQLite-backed session service.
6668
@@ -69,6 +71,8 @@ def create_local_session_service(
6971
per_agent: If True, creates a PerAgentDatabaseSessionService that stores
7072
sessions in each agent's .adk folder. If False, creates a single
7173
SqliteSessionService at base_dir/.adk/session.db.
74+
app_name_to_dir: Optional mapping from logical app name to on-disk agent
75+
folder name. Only used when per_agent is True; defaults to identity.
7276
7377
Returns:
7478
A BaseSessionService instance backed by SQLite.
@@ -78,7 +82,10 @@ def create_local_session_service(
7882
"Using per-agent session storage rooted at %s",
7983
base_dir,
8084
)
81-
return PerAgentDatabaseSessionService(agents_root=base_dir)
85+
return PerAgentDatabaseSessionService(
86+
agents_root=base_dir,
87+
app_name_to_dir=app_name_to_dir,
88+
)
8289

8390
return create_local_database_session_service(base_dir=base_dir)
8491

@@ -108,23 +115,26 @@ def __init__(
108115
self,
109116
*,
110117
agents_root: Path | str,
118+
app_name_to_dir: Optional[Mapping[str, str]] = None,
111119
):
112120
self._agents_root = Path(agents_root).resolve()
121+
self._app_name_to_dir = dict(app_name_to_dir or {})
113122
self._services: dict[str, BaseSessionService] = {}
114123
self._service_lock = asyncio.Lock()
115124

116125
async def _get_service(self, app_name: str) -> BaseSessionService:
117126
async with self._service_lock:
118-
service = self._services.get(app_name)
127+
storage_name = self._app_name_to_dir.get(app_name, app_name)
128+
service = self._services.get(storage_name)
119129
if service is not None:
120130
return service
121131
folder = dot_adk_folder_for_agent(
122-
agents_root=self._agents_root, app_name=app_name
132+
agents_root=self._agents_root, app_name=storage_name
123133
)
124134
service = create_local_database_session_service(
125135
base_dir=folder.agent_dir,
126136
)
127-
self._services[app_name] = service
137+
self._services[storage_name] = service
128138
return service
129139

130140
@override

src/google/adk/cli/utils/service_factory.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def create_session_service_from_options(
3333
base_dir: Path | str,
3434
session_service_uri: Optional[str] = None,
3535
session_db_kwargs: Optional[dict[str, Any]] = None,
36+
app_name_to_dir: Optional[dict[str, str]] = None,
3637
) -> BaseSessionService:
3738
"""Creates a session service based on CLI/web options."""
3839
base_path = Path(base_dir)
@@ -64,7 +65,11 @@ def create_session_service_from_options(
6465
return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs)
6566

6667
# Default to per-agent local SQLite storage in <agents_root>/<agent>/.adk/.
67-
return create_local_session_service(base_dir=base_path, per_agent=True)
68+
return create_local_session_service(
69+
base_dir=base_path,
70+
per_agent=True,
71+
app_name_to_dir=app_name_to_dir,
72+
)
6873

6974

7075
def create_memory_service_from_options(
@@ -96,6 +101,7 @@ def create_artifact_service_from_options(
96101
*,
97102
base_dir: Path | str,
98103
artifact_service_uri: Optional[str] = None,
104+
strict_uri: bool = False,
99105
) -> BaseArtifactService:
100106
"""Creates an artifact service based on CLI/web options."""
101107
base_path = Path(base_dir)
@@ -108,6 +114,10 @@ def create_artifact_service_from_options(
108114
agents_dir=str(base_path),
109115
)
110116
if service is None:
117+
if strict_uri:
118+
raise ValueError(
119+
f"Unsupported artifact service URI: {artifact_service_uri}"
120+
)
111121
logger.warning(
112122
"Unsupported artifact service URI: %s, falling back to in-memory",
113123
artifact_service_uri,

tests/unittests/cli/test_fast_api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,15 +416,15 @@ def test_app(
416416
with (
417417
patch("signal.signal", return_value=None),
418418
patch(
419-
"google.adk.cli.fast_api.InMemorySessionService",
419+
"google.adk.cli.fast_api.create_session_service_from_options",
420420
return_value=mock_session_service,
421421
),
422422
patch(
423-
"google.adk.cli.fast_api.InMemoryArtifactService",
423+
"google.adk.cli.fast_api.create_artifact_service_from_options",
424424
return_value=mock_artifact_service,
425425
),
426426
patch(
427-
"google.adk.cli.fast_api.InMemoryMemoryService",
427+
"google.adk.cli.fast_api.create_memory_service_from_options",
428428
return_value=mock_memory_service,
429429
),
430430
patch(
@@ -556,15 +556,15 @@ def test_app_with_a2a(
556556
with (
557557
patch("signal.signal", return_value=None),
558558
patch(
559-
"google.adk.cli.fast_api.InMemorySessionService",
559+
"google.adk.cli.fast_api.create_session_service_from_options",
560560
return_value=mock_session_service,
561561
),
562562
patch(
563-
"google.adk.cli.fast_api.InMemoryArtifactService",
563+
"google.adk.cli.fast_api.create_artifact_service_from_options",
564564
return_value=mock_artifact_service,
565565
),
566566
patch(
567-
"google.adk.cli.fast_api.InMemoryMemoryService",
567+
"google.adk.cli.fast_api.create_memory_service_from_options",
568568
return_value=mock_memory_service,
569569
),
570570
patch(

tests/unittests/cli/utils/test_local_storage.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pathlib import Path
1818

1919
from google.adk.cli.utils.local_storage import create_local_database_session_service
20+
from google.adk.cli.utils.local_storage import create_local_session_service
2021
from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService
2122
from google.adk.sessions.sqlite_session_service import SqliteSessionService
2223
import pytest
@@ -48,6 +49,29 @@ async def test_per_agent_session_service_creates_scoped_dot_adk(
4849
assert agent_b_sessions.sessions[0].app_name == "agent_b"
4950

5051

52+
@pytest.mark.asyncio
53+
async def test_per_agent_session_service_respects_app_name_alias(
54+
tmp_path: Path,
55+
) -> None:
56+
folder_name = "agent_folder"
57+
logical_name = "custom_app"
58+
(tmp_path / folder_name).mkdir()
59+
60+
service = create_local_session_service(
61+
base_dir=tmp_path,
62+
per_agent=True,
63+
app_name_to_dir={logical_name: folder_name},
64+
)
65+
66+
session = await service.create_session(
67+
app_name=logical_name,
68+
user_id="user",
69+
)
70+
71+
assert session.app_name == logical_name
72+
assert (tmp_path / folder_name / ".adk" / "session.db").exists()
73+
74+
5175
def test_create_local_database_session_service_returns_sqlite(
5276
tmp_path: Path,
5377
) -> None:

tests/unittests/cli/utils/test_service_factory.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ async def test_create_session_service_defaults_to_per_agent_sqlite(
6060
assert (agent_dir / ".adk" / "session.db").exists()
6161

6262

63+
@pytest.mark.asyncio
64+
async def test_create_session_service_respects_app_name_mapping(
65+
tmp_path: Path,
66+
) -> None:
67+
agent_dir = tmp_path / "agent_folder"
68+
logical_name = "custom_app"
69+
agent_dir.mkdir()
70+
71+
service = service_factory.create_session_service_from_options(
72+
base_dir=tmp_path,
73+
app_name_to_dir={logical_name: "agent_folder"},
74+
)
75+
76+
assert isinstance(service, PerAgentDatabaseSessionService)
77+
session = await service.create_session(app_name=logical_name, user_id="user")
78+
assert session.app_name == logical_name
79+
assert (agent_dir / ".adk" / "session.db").exists()
80+
81+
6382
def test_create_session_service_fallbacks_to_database(
6483
tmp_path: Path, monkeypatch
6584
):
@@ -101,6 +120,21 @@ def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch):
101120
)
102121

103122

123+
def test_create_artifact_service_raises_on_unknown_scheme_when_strict(
124+
tmp_path: Path, monkeypatch
125+
):
126+
registry = Mock()
127+
registry.create_artifact_service.return_value = None
128+
monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)
129+
130+
with pytest.raises(ValueError):
131+
service_factory.create_artifact_service_from_options(
132+
base_dir=tmp_path,
133+
artifact_service_uri="unknown://foo",
134+
strict_uri=True,
135+
)
136+
137+
104138
def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch):
105139
registry = Mock()
106140
expected = object()

0 commit comments

Comments
 (0)