From 598965fa997fef6caeebd41f5547df0eafd93599 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 Jan 2026 09:43:23 -0800 Subject: [PATCH] feat: Adding a new `ArtifactService.saveAndReloadArtifact()` method The `saveAndReloadArtifact()` enables a save without a second i/o call just to get the full file path. PiperOrigin-RevId: 862275776 --- .../adk/artifacts/BaseArtifactService.java | 22 ++++ .../adk/artifacts/GcsArtifactService.java | 103 +++++++++++++----- .../artifacts/InMemoryArtifactService.java | 10 ++ .../adk/artifacts/GcsArtifactServiceTest.java | 37 +++++++ .../InMemoryArtifactServiceTest.java | 82 ++++++++++++++ 5 files changed, 226 insertions(+), 28 deletions(-) create mode 100644 core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index 847e88dd9..b6a3cee23 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -39,6 +39,28 @@ public interface BaseArtifactService { Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact); + /** + * Saves an artifact and returns it with fileData if available. + * + *

Implementations should override this default method for efficiency, as the default performs + * two I/O operations (save then load). + * + * @param appName the app name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the filename + * @param artifact the artifact to save + * @return the saved artifact with fileData if available. + */ + default Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMap( + version -> + loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) + .toSingle()); + } + /** * Gets an artifact. * diff --git a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java index 1bfef8cf8..b9bc49a02 100644 --- a/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java @@ -18,6 +18,7 @@ import static java.util.Collections.max; +import com.google.auto.value.AutoValue; import com.google.cloud.storage.Blob; import com.google.cloud.storage.BlobId; import com.google.cloud.storage.BlobInfo; @@ -27,6 +28,7 @@ import com.google.common.base.Splitter; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; +import com.google.genai.types.FileData; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; @@ -108,34 +110,8 @@ private String getBlobName( @Override public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { - return listVersions(appName, userId, sessionId, filename) - .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) - .map( - nextVersion -> { - String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); - BlobId blobId = BlobId.of(bucketName, blobName); - - BlobInfo blobInfo = - BlobInfo.newBuilder(blobId) - .setContentType(artifact.inlineData().get().mimeType().orElse(null)) - .build(); - - try { - byte[] dataToSave = - artifact - .inlineData() - .get() - .data() - .orElseThrow( - () -> - new IllegalArgumentException( - "Saveable artifact data must be non-empty.")); - storageClient.create(blobInfo, dataToSave); - return nextVersion; - } catch (StorageException e) { - throw new VerifyException("Failed to save artifact to GCS", e); - } - }); + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .map(SaveResult::version); } /** @@ -275,4 +251,75 @@ public Single> listVersions( return Single.just(ImmutableList.of()); } } + + @Override + public Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifactAndReturnBlob(appName, userId, sessionId, filename, artifact) + .flatMap( + blob -> { + Blob savedBlob = blob.blob(); + String resultMimeType = + Optional.ofNullable(savedBlob.getContentType()) + .or( + () -> + artifact.inlineData().flatMap(com.google.genai.types.Blob::mimeType)) + .orElse("application/octet-stream"); + return Single.just( + Part.builder() + .fileData( + FileData.builder() + .fileUri("gs://" + savedBlob.getBucket() + "/" + savedBlob.getName()) + .mimeType(resultMimeType) + .build()) + .build()); + }); + } + + @AutoValue + abstract static class SaveResult { + static SaveResult create(Blob blob, int version) { + return new AutoValue_GcsArtifactService_SaveResult(blob, version); + } + + abstract Blob blob(); + + abstract int version(); + } + + private Single saveArtifactAndReturnBlob( + String appName, String userId, String sessionId, String filename, Part artifact) { + return listVersions(appName, userId, sessionId, filename) + .map(versions -> versions.isEmpty() ? 0 : max(versions) + 1) + .map( + nextVersion -> { + if (artifact.inlineData().isEmpty()) { + throw new IllegalArgumentException("Saveable artifact must have inline data."); + } + + String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion); + BlobId blobId = BlobId.of(bucketName, blobName); + + BlobInfo blobInfo = + BlobInfo.newBuilder(blobId) + .setContentType(artifact.inlineData().get().mimeType().orElse(null)) + .build(); + + try { + byte[] dataToSave = + artifact + .inlineData() + .get() + .data() + .orElseThrow( + () -> + new IllegalArgumentException( + "Saveable artifact data must be non-empty.")); + Blob blob = storageClient.create(blobInfo, dataToSave); + return SaveResult.create(blob, nextVersion); + } catch (StorageException e) { + throw new VerifyException("Failed to save artifact to GCS", e); + } + }); + } } diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 27b85136d..5808f7083 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -125,6 +125,16 @@ public Single> listVersions( return Single.just(IntStream.range(0, size).boxed().collect(toImmutableList())); } + @Override + public Single saveAndReloadArtifact( + String appName, String userId, String sessionId, String filename, Part artifact) { + return saveArtifact(appName, userId, sessionId, filename, artifact) + .flatMap( + version -> + loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) + .toSingle()); + } + private Map> getArtifactsMap(String appName, String userId, String sessionId) { return artifacts .computeIfAbsent(appName, unused -> new HashMap<>()) diff --git a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java index 1df66c36d..40493bf3a 100644 --- a/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/GcsArtifactServiceTest.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -76,6 +77,7 @@ private Blob mockBlob(String name, String contentType, byte[] content) { when(blob.exists()).thenReturn(true); BlobId blobId = BlobId.of(BUCKET_NAME, name); when(blob.getBlobId()).thenReturn(blobId); + when(blob.getBucket()).thenReturn(BUCKET_NAME); return blob; } @@ -89,6 +91,8 @@ public void save_firstVersion_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -109,6 +113,8 @@ public void save_subsequentVersion_savesCorrectly() { Blob blobV0 = mockBlob(blobNameV0, "text/plain", new byte[] {1}); when(mockBlobPage.iterateAll()).thenReturn(Collections.singletonList(blobV0)); + Blob savedBlob = mockBlob(expectedBlobNameV1, "image/png", new byte[] {4, 5}); + when(mockStorage.create(eq(expectedBlobInfoV1), eq(new byte[] {4, 5}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); @@ -126,6 +132,8 @@ public void save_userNamespace_savesCorrectly() { BlobInfo.newBuilder(expectedBlobId).setContentType("application/json").build(); when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/json", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); int version = service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, USER_FILENAME, artifact).blockingGet(); @@ -330,7 +338,36 @@ public void listVersions_noVersions_returnsEmptyList() { assertThat(versions).isEmpty(); } + @Test + public void saveAndReloadArtifact_savesAndReturnsFileData() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "application/octet-stream"); + String expectedBlobName = + String.format("%s/%s/%s/%s/0", APP_NAME, USER_ID, SESSION_ID, FILENAME); + BlobId expectedBlobId = BlobId.of(BUCKET_NAME, expectedBlobName); + BlobInfo expectedBlobInfo = + BlobInfo.newBuilder(expectedBlobId).setContentType("application/octet-stream").build(); + + when(mockBlobPage.iterateAll()).thenReturn(ImmutableList.of()); + Blob savedBlob = mockBlob(expectedBlobName, "application/octet-stream", new byte[] {1, 2, 3}); + when(mockStorage.create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3}))).thenReturn(savedBlob); + + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + + assertThat(result).isPresent(); + assertThat(result.get().fileData()).isPresent(); + assertThat(result.get().fileData().get().fileUri()) + .hasValue("gs://" + BUCKET_NAME + "/" + expectedBlobName); + assertThat(result.get().fileData().get().mimeType()).hasValue("application/octet-stream"); + verify(mockStorage).create(eq(expectedBlobInfo), eq(new byte[] {1, 2, 3})); + } + private static Optional asOptional(Maybe maybe) { return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); } + + private static Optional asOptional(Single single) { + return Optional.of(single.blockingGet()); + } } diff --git a/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java new file mode 100644 index 000000000..4cb493277 --- /dev/null +++ b/core/src/test/java/com/google/adk/artifacts/InMemoryArtifactServiceTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.artifacts; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InMemoryArtifactService}. */ +@RunWith(JUnit4.class) +public class InMemoryArtifactServiceTest { + + private static final String APP_NAME = "test-app"; + private static final String USER_ID = "test-user"; + private static final String SESSION_ID = "test-session"; + private static final String FILENAME = "test-file.txt"; + + private InMemoryArtifactService service; + + @Before + public void setUp() { + service = new InMemoryArtifactService(); + } + + @Test + public void saveArtifact_savesAndReturnsVersion() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + int version = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact).blockingGet(); + assertThat(version).isEqualTo(0); + } + + @Test + public void loadArtifact_loadsLatest() { + Part artifact1 = Part.fromBytes(new byte[] {1}, "text/plain"); + Part artifact2 = Part.fromBytes(new byte[] {1, 2}, "text/plain"); + var unused1 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact1).blockingGet(); + var unused2 = + service.saveArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact2).blockingGet(); + Optional result = + asOptional(service.loadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, Optional.empty())); + assertThat(result).hasValue(artifact2); + } + + @Test + public void saveAndReloadArtifact_reloadsArtifact() { + Part artifact = Part.fromBytes(new byte[] {1, 2, 3}, "text/plain"); + Optional result = + asOptional( + service.saveAndReloadArtifact(APP_NAME, USER_ID, SESSION_ID, FILENAME, artifact)); + assertThat(result).hasValue(artifact); + } + + private static Optional asOptional(Maybe maybe) { + return maybe.map(Optional::of).defaultIfEmpty(Optional.empty()).blockingGet(); + } + + private static Optional asOptional(Single single) { + return Optional.of(single.blockingGet()); + } +}