From 67bdc89c5257469ad63f984175b89aa2152aa0f1 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 26 Jan 2026 14:53:25 +0100 Subject: [PATCH 01/10] Kotlin reflections support --- .../proxysupport/ByteBuddyProxyFactory.java | 31 +- client-kotlin/build.gradle.kts | 8 +- .../dev/restate/client/kotlin/ingress.kt | 308 ++++++++ .../main/java/dev/restate/client/Client.java | 18 + common-kotlin/build.gradle.kts | 12 + .../reflection/kotlin/RequestCaptureProxy.kt | 71 ++ .../common/reflection/kotlin/reflections.kt | 111 +++ .../dev/restate/common/InvocationOptions.java | 4 + .../main/java/dev/restate/common/Request.java | 5 + .../java/dev/restate/common/RequestImpl.java | 6 + .../common/reflections/MethodInfo.java | 3 + examples/build.gradle.kts | 3 - .../my/restate/sdk/examples/CounterKt.kt | 20 +- .../sdk/kotlin/gen/KElementConverter.kt | 2 +- sdk-api-kotlin/build.gradle.kts | 10 +- .../dev/restate/sdk/kotlin/HandlerRunner.kt | 2 + .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 702 +++++++++++++++++- .../MalformedRestateServiceException.kt | 22 + .../ReflectionServiceDefinitionFactory.kt | 432 +++++++++++ .../kotlin/internal/RestateContextElement.kt | 24 + ...dpoint.definition.ServiceDefinitionFactory | 1 + sdk-core/build.gradle.kts | 20 + .../sdk/core/kotlinapi/KotlinAPITests.kt | 2 + .../reflections/ReflectionDiscoveryTest.kt | 91 +++ .../kotlinapi/reflections/ReflectionTest.kt | 229 ++++++ .../core/kotlinapi/reflections/testClasses.kt | 193 +++++ .../build.gradle.kts | 13 + .../sdk/springboot/kotlin/GreeterNewApi.kt | 26 + .../kotlin/SdkTestingIntegrationTest.kt | 28 +- settings.gradle.kts | 1 + 30 files changed, 2370 insertions(+), 28 deletions(-) create mode 100644 common-kotlin/build.gradle.kts create mode 100644 common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt create mode 100644 common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/reflections.kt create mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt create mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt create mode 100644 sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/RestateContextElement.kt create mode 100644 sdk-api-kotlin/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory create mode 100644 sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt create mode 100644 sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt create mode 100644 sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt create mode 100644 sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt diff --git a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java index 049c3df97..71b38389a 100644 --- a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java +++ b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java @@ -52,7 +52,24 @@ public final class ByteBuddyProxyFactory implements ProxyFactory { public @Nullable T createProxy(Class clazz, MethodInterceptor interceptor) { // Cannot proxy final classes if (Modifier.isFinal(clazz.getModifiers())) { - throw new IllegalArgumentException("Class " + clazz + " is final, cannot be proxied."); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException( + clazz + + +""" + is not open, cannot be proxied. Suggestions: +* Extract the @Handler annotated functions in an interface +* Make the class and all its @Handler annotated functions 'open' +* Use the Kotlin allopen compiler plugin https://kotlinlang.org/docs/all-open-plugin.html with the following configuration: + +allOpen { + annotations("dev.restate.sdk.annotation.Service", "dev.restate.sdk.annotation.VirtualObject", "dev.restate.sdk.annotation.Workflow") +} +"""); + } + throw new IllegalArgumentException( + clazz + + " is final, cannot be proxied. Remove the final keyword, or refactor it extracting the restate interface out of it."); } try { @@ -88,7 +105,7 @@ public Method getMethod() { return proxyInstance; } catch (Exception e) { - throw new IllegalArgumentException("Cannot create proxy for class " + clazz, e); + throw new IllegalArgumentException("Cannot create proxy for " + clazz, e); } } @@ -109,10 +126,12 @@ private Class generateProxyClass(Class clazz) throws NoSuchFieldExcept : byteBuddy.subclass(clazz); var annotationMatcher = - isAnnotatedWith(Handler.class) - .or(isAnnotatedWith(Exclusive.class)) - .or(isAnnotatedWith(Shared.class)) - .or(isAnnotatedWith(Workflow.class)); + not(isStatic()) + .and( + isAnnotatedWith(Handler.class) + .or(isAnnotatedWith(Exclusive.class)) + .or(isAnnotatedWith(Shared.class)) + .or(isAnnotatedWith(Workflow.class))); try (var unloaded = builder // Add a field to store the interceptor diff --git a/client-kotlin/build.gradle.kts b/client-kotlin/build.gradle.kts index 96aca77f7..0594330a1 100644 --- a/client-kotlin/build.gradle.kts +++ b/client-kotlin/build.gradle.kts @@ -5,9 +5,15 @@ plugins { description = "Restate Client to interact with services from within other Kotlin applications" +configurations.all { + // Gonna conflict with sdk-serde-kotlinx + exclude(group = "dev.restate", module = "sdk-serde-jackson") +} + dependencies { - api(project(":client")) { exclude("dev.restate", "sdk-serde-jackson") } + api(project(":client")) api(project(":sdk-serde-kotlinx")) + implementation(project(":common-kotlin")) implementation(libs.kotlinx.coroutines.core) } diff --git a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt index cc0e8381f..db59e7f01 100644 --- a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt +++ b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt @@ -13,12 +13,20 @@ import dev.restate.client.RequestOptions import dev.restate.client.Response import dev.restate.client.ResponseHead import dev.restate.client.SendResponse +import dev.restate.common.InvocationOptions import dev.restate.common.Output import dev.restate.common.Request import dev.restate.common.Target import dev.restate.common.WorkflowRequest +import dev.restate.common.reflection.kotlin.RequestCaptureProxy +import dev.restate.common.reflection.kotlin.captureInvocation +import dev.restate.common.reflections.ProxySupport +import dev.restate.common.reflections.ReflectionUtils import dev.restate.serde.TypeTag import dev.restate.serde.kotlinx.typeTag +import kotlin.coroutines.Continuation +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.startCoroutine import kotlin.time.Duration import kotlin.time.toJavaDuration import kotlinx.coroutines.future.await @@ -262,3 +270,303 @@ val Response.response: Res /** @see SendResponse.sendStatus */ val SendResponse.sendStatus: SendResponse.SendStatus get() = this.sendStatus() + +/** + * Create a proxy client for a Restate service. + * + * Example usage: + * ```kotlin + * val greeter = client.service() + * val response = greeter.greet("Alice") + * ``` + * + * @param SVC the service class annotated with @Service + * @return a proxy client to invoke the service + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.service(): SVC { + return service(this, SVC::class.java) +} + +/** + * Create a proxy client for a Restate virtual object. + * + * Example usage: + * ```kotlin + * val counter = client.virtualObject("my-key") + * val value = counter.increment() + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a proxy client to invoke the virtual object + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.virtualObject(key: String): SVC { + return virtualObject(this, SVC::class.java, key) +} + +/** + * Create a proxy client for a Restate workflow. + * + * Example usage: + * ```kotlin + * val wf = client.workflow("wf-123") + * val result = wf.run("input") + * ``` + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a proxy client to invoke the workflow + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.workflow(key: String): SVC { + return workflow(this, SVC::class.java, key) +} + +/** + * Create a proxy for a service that uses the ingress client to make calls. + * + * @param client the ingress client to use for calls + * @param clazz the service class + * @return a proxy that intercepts method calls and executes them via the client + */ +@PublishedApi +internal fun service(client: Client, clazz: Class): SVC { + ReflectionUtils.mustHaveServiceAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + + val serviceName = ReflectionUtils.extractServiceName(clazz) + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, null).toRequest() + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { client.callAsync(request).await().response() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +/** + * Create a proxy for a virtual object that uses the ingress client to make calls. + * + * @param client the ingress client to use for calls + * @param clazz the virtual object class + * @param key the virtual object key + * @return a proxy that intercepts method calls and executes them via the client + */ +@PublishedApi +internal fun virtualObject(client: Client, clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + + val serviceName = ReflectionUtils.extractServiceName(clazz) + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { client.callAsync(request).await().response() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +/** + * Create a proxy for a workflow that uses the ingress client to make calls. + * + * @param client the ingress client to use for calls + * @param clazz the workflow class + * @param key the workflow key + * @return a proxy that intercepts method calls and executes them via the client + */ +@PublishedApi +internal fun workflow(client: Client, clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveWorkflowAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + + val serviceName = ReflectionUtils.extractServiceName(clazz) + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { client.callAsync(request).await().response() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +/** + * Builder for creating type-safe requests. + * + * This builder allows the response type to be inferred from the lambda passed to [request]. + * + * @param SVC the service/virtual object/workflow class + */ +@org.jetbrains.annotations.ApiStatus.Experimental +class KClientRequestBuilder +@PublishedApi +internal constructor( + private val client: Client, + private val clazz: Class, + private val key: String?, +) { + /** + * Create a request by invoking a method on the target. + * + * The response type is inferred from the return type of the invoked method. + * + * @param Res the response type (inferred from the lambda) + * @param block a suspend lambda that invokes a method on the target + * @return a [KClientRequest] with the correct response type + */ + @Suppress("UNCHECKED_CAST") + fun request(block: suspend (SVC) -> Res): KClientRequest { + return KClientRequestImpl( + client, + RequestCaptureProxy(clazz, key).capture(block as suspend (SVC) -> Any?).toRequest(), + ) + as KClientRequest + } +} + +/** + * Kotlin-idiomatic request for invoking Restate services from an ingress client. + * + * Example usage: + * ```kotlin + * client.toService() + * .request { it.add(1) } + * .withOptions { idempotencyKey = "123" } + * .call() + * ``` + * + * @param Req the request type + * @param Res the response type + */ +@org.jetbrains.annotations.ApiStatus.Experimental +interface KClientRequest : Request { + + /** + * Configure invocation options using a DSL. + * + * @param block builder block for options + * @return a new request with the configured options + */ + fun withOptions(block: InvocationOptions.Builder.() -> Unit): KClientRequest + + /** + * Call the target handler and wait for the response. + * + * @return the response + */ + suspend fun call(): Response + + /** + * Send the request without waiting for the response. + * + * @param delay optional delay before the invocation is executed + * @return the send response with invocation handle + */ + suspend fun send(delay: Duration? = null): SendResponse +} + +/** + * Create a builder for invoking a Restate service. + * + * Example usage: + * ```kotlin + * val response = client.toService() + * .request { it.greet("Alice") } + * .call() + * ``` + * + * @param SVC the service class annotated with @Service + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.toService(): KClientRequestBuilder { + ReflectionUtils.mustHaveServiceAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KClientRequestBuilder(this, SVC::class.java, null) +} + +/** + * Create a builder for invoking a Restate virtual object. + * + * Example usage: + * ```kotlin + * val response = client.toVirtualObject("my-counter") + * .request { it.add(1) } + * .call() + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.toVirtualObject(key: String): KClientRequestBuilder { + ReflectionUtils.mustHaveVirtualObjectAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KClientRequestBuilder(this, SVC::class.java, key) +} + +/** + * Create a builder for invoking a Restate workflow. + * + * Example usage: + * ```kotlin + * val response = client.toWorkflow("workflow-123") + * .request { it.run("input") } + * .call() + * ``` + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun Client.toWorkflow(key: String): KClientRequestBuilder { + ReflectionUtils.mustHaveWorkflowAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KClientRequestBuilder(this, SVC::class.java, key) +} + +/** Implementation of [KClientRequest] for ingress client. */ +private class KClientRequestImpl( + private val client: Client, + private val request: Request, +) : KClientRequest, Request by request { + + override fun withOptions(block: InvocationOptions.Builder.() -> Unit): KClientRequest { + val builder = InvocationOptions.builder() + builder.block() + return KClientRequestImpl( + client, + this.toBuilder().headers(builder.headers).idempotencyKey(builder.idempotencyKey).build(), + ) + } + + override suspend fun call(): Response { + return client.callSuspend(request) + } + + override suspend fun send(delay: Duration?): SendResponse { + return client.sendSuspend(request, delay) + } +} diff --git a/client/src/main/java/dev/restate/client/Client.java b/client/src/main/java/dev/restate/client/Client.java index 3d7b9d754..3fa5cee01 100644 --- a/client/src/main/java/dev/restate/client/Client.java +++ b/client/src/main/java/dev/restate/client/Client.java @@ -555,6 +555,9 @@ default Response> getOutput() throws IngressException { @org.jetbrains.annotations.ApiStatus.Experimental default SVC service(Class clazz) { ReflectionUtils.mustHaveServiceAnnotation(clazz); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } var serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -606,6 +609,9 @@ default SVC service(Class clazz) { @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceHandle serviceHandle(Class clazz) { ReflectionUtils.mustHaveServiceAnnotation(clazz); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } return new ClientServiceHandleImpl<>(this, clazz, null); } @@ -634,6 +640,9 @@ default ClientServiceHandle serviceHandle(Class clazz) { @org.jetbrains.annotations.ApiStatus.Experimental default SVC virtualObject(Class clazz, String key) { ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } var serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -686,6 +695,9 @@ default SVC virtualObject(Class clazz, String key) { @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceHandle virtualObjectHandle(Class clazz, String key) { ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } return new ClientServiceHandleImpl<>(this, clazz, key); } @@ -714,6 +726,9 @@ default ClientServiceHandle virtualObjectHandle(Class clazz, Str @org.jetbrains.annotations.ApiStatus.Experimental default SVC workflow(Class clazz, String key) { ReflectionUtils.mustHaveWorkflowAnnotation(clazz); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } var serviceName = ReflectionUtils.extractServiceName(clazz); return ProxySupport.createProxy( clazz, @@ -766,6 +781,9 @@ default SVC workflow(Class clazz, String key) { @org.jetbrains.annotations.ApiStatus.Experimental default ClientServiceHandle workflowHandle(Class clazz, String key) { ReflectionUtils.mustHaveWorkflowAnnotation(clazz); + if (ReflectionUtils.isKotlinClass(clazz)) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } return new ClientServiceHandleImpl<>(this, clazz, key); } diff --git a/common-kotlin/build.gradle.kts b/common-kotlin/build.gradle.kts new file mode 100644 index 000000000..6176414b0 --- /dev/null +++ b/common-kotlin/build.gradle.kts @@ -0,0 +1,12 @@ +plugins { + `kotlin-conventions` + `library-publishing-conventions` +} + +description = "Common types used by different Restate Kotlin modules" + +dependencies { + api(project(":common")) + api(project(":sdk-serde-kotlinx")) + implementation(kotlin("reflect")) +} diff --git a/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt b/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt new file mode 100644 index 000000000..e98ed10b3 --- /dev/null +++ b/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt @@ -0,0 +1,71 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflection.kotlin + +import dev.restate.common.reflections.ProxySupport +import dev.restate.common.reflections.ReflectionUtils +import kotlin.coroutines.Continuation +import kotlin.coroutines.EmptyCoroutineContext +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.startCoroutine + +/** + * Captures method invocations on a proxy to extract invocation information. + * + * This class is used to intercept calls on service proxies and extract the method metadata and + * arguments without actually executing the method. The captured information can then be used to + * build requests for remote invocation. + * + * @param SVC the service type + * @property clazz the service class + * @property serviceName the resolved service name + * @property key the virtual object/workflow key (null for stateless services) + */ +class RequestCaptureProxy(private val clazz: Class, private val key: String?) { + + private val serviceName: String = ReflectionUtils.extractServiceName(clazz) + + /** + * Capture a method invocation from the given block. + * + * @param block the suspend lambda that invokes a method on the service proxy + * @return the captured invocation information + */ + fun capture(block: suspend (SVC) -> Any?): CapturedInvocation { + var capturedInvocation: CapturedInvocation? = null + + val proxy = + ProxySupport.createProxy(clazz) { invocation -> + capturedInvocation = invocation.captureInvocation(serviceName, key) + + // Return COROUTINE_SUSPENDED to prevent actual execution + COROUTINE_SUSPENDED + } + + // Invoke the block with the proxy to capture the method call. + // Since the proxy returns COROUTINE_SUSPENDED, we use startCoroutine + // which starts but doesn't block waiting for completion. + val capturingContinuation = + object : Continuation { + override val context = EmptyCoroutineContext + + override fun resumeWith(result: Result) { + // Do nothing - we're just capturing, the coroutine suspends immediately + } + } + + val suspendBlock: suspend () -> Any? = { block(proxy) } + suspendBlock.startCoroutine(capturingContinuation) + + return capturedInvocation + ?: error( + "Method invocation was not captured. Make sure to call ONLY a method of the service proxy." + ) + } +} diff --git a/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/reflections.kt b/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/reflections.kt new file mode 100644 index 000000000..60e65bda2 --- /dev/null +++ b/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/reflections.kt @@ -0,0 +1,111 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.common.reflection.kotlin + +import dev.restate.common.Request +import dev.restate.common.Target +import dev.restate.common.reflections.ProxyFactory +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.sdk.annotation.Raw +import dev.restate.serde.Serde +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import kotlin.reflect.KClass +import kotlin.reflect.KType +import kotlin.reflect.full.findAnnotation +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.kotlinFunction +import kotlin.reflect.typeOf + +/** + * Captured information from a method invocation on a proxy. + * + * @property target the target service/handler + * @property inputTypeTag type tag for serializing the input + * @property outputTypeTag type tag for deserializing the output + * @property input the input value (may be null for no-arg methods) + */ +data class CapturedInvocation( + val target: Target, + val inputTypeTag: TypeTag<*>, + val outputTypeTag: TypeTag<*>, + val input: Any?, +) { + @Suppress("UNCHECKED_CAST") + fun toRequest(): Request<*, *> { + return Request.of(target, inputTypeTag as TypeTag, outputTypeTag as TypeTag, input) + } +} + +fun ProxyFactory.MethodInvocation.captureInvocation( + serviceName: String, + key: String?, +): CapturedInvocation { + val handlerInfo = ReflectionUtils.mustHaveHandlerAnnotation(method) + val handlerName = handlerInfo.name + val kFunction = method.kotlinFunction + require(kFunction != null && kFunction.isSuspend) { + "Method '${method.name}' is not a suspend function, this is not supported." + } + + val parameters = kFunction.valueParameters + val inputTypeTag = + if (parameters.isEmpty()) { + resolveKotlinTypeTag(typeOf(), null) + } else { + parameters[0].let { inputParam -> + resolveKotlinTypeTag( + inputParam.type, + inputParam.findAnnotation(), + ) + } + } + + val outputTypeTag = + resolveKotlinTypeTag( + kFunction.returnType, + kFunction.findAnnotation(), + ) + + val target = + if (key != null) { + Target.virtualObject(serviceName, key, handlerName) + } else { + Target.service(serviceName, handlerName) + } + + // For suspend functions, arguments are: [input?, continuation] + // Extract the input (first argument, excluding continuation) + val input = + if (this.arguments.size > 1) { + this.arguments[0] + } else { + null + } + + return CapturedInvocation(target, inputTypeTag, outputTypeTag, input) +} + +private fun resolveKotlinTypeTag(kType: KType, rawAnnotation: Raw?): TypeTag<*> { + if (kType.classifier == Unit::class) { + return KotlinSerializationSerdeFactory.UNIT + } + + if (rawAnnotation != null && rawAnnotation.contentType != "application/octet-stream") { + return Serde.withContentType(rawAnnotation.contentType, Serde.RAW) + } else if (rawAnnotation != null) { + return Serde.RAW + } + + @Suppress("UNCHECKED_CAST") + return KotlinSerializationSerdeFactory.KtTypeTag( + kType.classifier as KClass<*>, + kType, + ) +} diff --git a/common/src/main/java/dev/restate/common/InvocationOptions.java b/common/src/main/java/dev/restate/common/InvocationOptions.java index fd989caa3..ab4d58094 100644 --- a/common/src/main/java/dev/restate/common/InvocationOptions.java +++ b/common/src/main/java/dev/restate/common/InvocationOptions.java @@ -57,6 +57,10 @@ public String toString() { + '}'; } + public static Builder builder() { + return new Builder(null, null); + } + public static Builder idempotencyKey(String idempotencyKey) { return new Builder(null, null).idempotencyKey(idempotencyKey); } diff --git a/common/src/main/java/dev/restate/common/Request.java b/common/src/main/java/dev/restate/common/Request.java index 97d487afb..4a503e8bf 100644 --- a/common/src/main/java/dev/restate/common/Request.java +++ b/common/src/main/java/dev/restate/common/Request.java @@ -77,4 +77,9 @@ static RequestBuilder of(Target target, byte[] request) { * @return the request headers */ @Nullable Map getHeaders(); + + /** + * @return a builder filled with this request + */ + RequestBuilder toBuilder(); } diff --git a/common/src/main/java/dev/restate/common/RequestImpl.java b/common/src/main/java/dev/restate/common/RequestImpl.java index cebfa8722..6664bcdcb 100644 --- a/common/src/main/java/dev/restate/common/RequestImpl.java +++ b/common/src/main/java/dev/restate/common/RequestImpl.java @@ -180,6 +180,11 @@ public Builder setIdempotencyKey(@Nullable String idempotencyKey) { return headers; } + @Override + public RequestBuilder toBuilder() { + return this; + } + /** * @param headers headers to send together with the request. This will overwrite the already * configured headers @@ -203,6 +208,7 @@ public RequestImpl build() { } } + @Override public Builder toBuilder() { return new Builder<>( this.target, diff --git a/common/src/main/java/dev/restate/common/reflections/MethodInfo.java b/common/src/main/java/dev/restate/common/reflections/MethodInfo.java index 1fb86f029..4c9d2b29c 100644 --- a/common/src/main/java/dev/restate/common/reflections/MethodInfo.java +++ b/common/src/main/java/dev/restate/common/reflections/MethodInfo.java @@ -40,6 +40,9 @@ public TypeTag getOutputType() { } public static MethodInfo fromMethod(Method method) { + if (ReflectionUtils.isKotlinClass(method.getDeclaringClass())) { + throw new IllegalArgumentException("Using Kotlin classes with Java's API is not supported"); + } var handlerInfo = ReflectionUtils.mustHaveHandlerAnnotation(method); var genericParameters = method.getGenericParameterTypes(); var handlerName = handlerInfo.name(); diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 86530b6f0..08640c16e 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -3,14 +3,11 @@ import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar plugins { `java-conventions` `kotlin-conventions` - alias(libs.plugins.ksp) application alias(libs.plugins.shadow) } dependencies { - ksp(project(":sdk-api-kotlin-gen")) - implementation(project(":client")) implementation(project(":client-kotlin")) implementation(project(":sdk-api")) diff --git a/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt b/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt index 294644724..6a5eeb67c 100644 --- a/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt +++ b/examples/src/main/kotlin/my/restate/sdk/examples/CounterKt.kt @@ -29,29 +29,29 @@ class CounterKt { @Serializable data class CounterUpdate(var oldValue: Long, val newValue: Long) @Handler - suspend fun reset(ctx: ObjectContext) { - ctx.clear(TOTAL) + suspend fun reset() { + state().clear(TOTAL) } @Handler - suspend fun add(ctx: ObjectContext, value: Long) { - val currentValue = ctx.get(TOTAL) ?: 0L + suspend fun add(value: Long) { + val currentValue = state().get(TOTAL) ?: 0L val newValue = currentValue + value - ctx.set(TOTAL, newValue) + state().set(TOTAL, newValue) } @Handler @Shared - suspend fun get(ctx: SharedObjectContext): Long? { - return ctx.get(TOTAL) + suspend fun get(): Long? { + return state().get(TOTAL) } @Handler - suspend fun getAndAdd(ctx: ObjectContext, value: Long): CounterUpdate { + suspend fun getAndAdd(value: Long): CounterUpdate { LOG.info("Invoked get and add with $value") - val currentValue = ctx.get(TOTAL) ?: 0L + val currentValue = state().get(TOTAL) ?: 0L val newValue = currentValue + value - ctx.set(TOTAL, newValue) + state().set(TOTAL, newValue) return CounterUpdate(currentValue, newValue) } } diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt index 64c20a94e..06a1bff0d 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt @@ -312,7 +312,7 @@ class KElementConverter( clazz.qualifiedName ) { logger.error( - "The method signature must have ${clazz.qualifiedName} as first parameter, was ${function.parameters[0].type.resolve().declaration.qualifiedName!!.asString()}", + "The method ${function.qualifiedName?.asString()} signature must have ${clazz.qualifiedName} as first parameter, was ${function.parameters[0].type.resolve().declaration.qualifiedName!!.asString()}", function, ) } diff --git a/sdk-api-kotlin/build.gradle.kts b/sdk-api-kotlin/build.gradle.kts index ab2d515a0..f2c2aeb83 100644 --- a/sdk-api-kotlin/build.gradle.kts +++ b/sdk-api-kotlin/build.gradle.kts @@ -6,13 +6,17 @@ plugins { description = "Restate SDK Kotlin APIs" dependencies { - implementation(libs.kotlinx.coroutines.core) - implementation(libs.kotlinx.serialization.core) api(libs.kotlinx.serialization.json) - api(project(":sdk-common")) api(project(":sdk-serde-kotlinx")) + // For concrete class proxying in service-to-service calls + runtimeOnly(project(":bytebuddy-proxy-support")) + + implementation(project(":common-kotlin")) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.kotlinx.serialization.core) + implementation(kotlin("reflect")) implementation(libs.log4j.api) implementation(libs.opentelemetry.kotlin) } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt index 75f1ac56a..d2f997667 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt @@ -11,6 +11,7 @@ package dev.restate.sdk.kotlin import dev.restate.common.Slice import dev.restate.sdk.common.TerminalException import dev.restate.sdk.endpoint.definition.HandlerContext +import dev.restate.sdk.kotlin.internal.RestateContextElement import dev.restate.serde.Serde import dev.restate.serde.SerdeFactory import io.opentelemetry.extension.kotlin.asContextElement @@ -108,6 +109,7 @@ internal constructor( val scope = CoroutineScope( options.coroutineContext + + RestateContextElement(ctx) + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL .asContextElement(handlerContext) + handlerContext.request().openTelemetryContext()!!.asContextElement() diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 78fbc7316..471609d6e 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -8,9 +8,14 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin +import dev.restate.common.InvocationOptions import dev.restate.common.Output import dev.restate.common.Request import dev.restate.common.Slice +import dev.restate.common.reflection.kotlin.RequestCaptureProxy +import dev.restate.common.reflection.kotlin.captureInvocation +import dev.restate.common.reflections.ProxySupport +import dev.restate.common.reflections.ReflectionUtils import dev.restate.sdk.common.DurablePromiseKey import dev.restate.sdk.common.HandlerRequest import dev.restate.sdk.common.InvocationId @@ -20,8 +25,12 @@ import dev.restate.serde.TypeTag import dev.restate.serde.kotlinx.* import java.nio.ByteBuffer import java.util.* +import kotlin.coroutines.Continuation +import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED +import kotlin.coroutines.startCoroutine import kotlin.random.Random import kotlin.time.Duration +import kotlinx.coroutines.currentCoroutineContext /** * This interface exposes the Restate functionalities to Restate services. It can be used to @@ -209,7 +218,6 @@ sealed interface Context { * running invocation, for example to cancel it or retrieve its result. * * @param invocationId The invocation to interact with. - * @param responseClazz The response class. */ inline fun Context.invocationHandle( invocationId: String @@ -733,3 +741,695 @@ val HandlerRequest.bodyAsByteBuffer: ByteBuffer get() = this.bodyAsBodyBuffer() val HandlerRequest.headers: Map get() = this.headers() + +// ============================================================================= +// Free-floating API functions for the reflection-based API +// ============================================================================= + +/** + * Get the current Restate [Context] from within a handler. + * + * This function must be called from within a Restate handler's suspend function. It retrieves the + * context from the coroutine context. + * + * Example usage: + * ```kotlin + * @Service + * class MyService { + * @Handler + * suspend fun myHandler(input: String): String { + * val ctx = context() + * // Use ctx for Restate operations + * return "processed: $input" + * } + * } + * ``` + * + * @throws IllegalStateException if called outside of a Restate handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun context(): Context { + val element = + currentCoroutineContext()[dev.restate.sdk.kotlin.internal.RestateContextElement] + ?: error("context() must be called from within a Restate handler") + return element.ctx +} + +/** + * Get the current request information. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.request + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun request(): HandlerRequest { + return context().request() +} + +/** + * Get the deterministic random instance. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.random + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun random(): RestateRandom { + return context().random() +} + +/** + * Causes the current execution of the function invocation to sleep for the given duration. + * + * @param duration for which to sleep. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.sleep + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun sleep(duration: Duration) { + context().sleep(duration) +} + +/** + * Causes the start of a timer for the given duration. + * + * @param duration for which to sleep. + * @param name name to be used for the timer + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.timer + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun timer(name: String = "", duration: Duration): DurableFuture { + return context().timer(duration, name) +} + +/** + * Execute a closure, recording the result value in the journal. + * + * @param name the name of the side effect. + * @param retryPolicy optional retry policy. + * @param block closure to execute. + * @return value of the run operation. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.runBlock + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun runBlock( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T, +): T { + return context().runBlock(typeTag(), name, retryPolicy, block) +} + +/** + * Execute a closure asynchronously. + * + * @param name the name of the side effect. + * @param retryPolicy optional retry policy. + * @param block closure to execute. + * @return a [DurableFuture] that you can combine and select. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.runAsync + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun runAsync( + name: String = "", + retryPolicy: RetryPolicy? = null, + noinline block: suspend () -> T, +): DurableFuture { + return context().runAsync(typeTag(), name, retryPolicy, block) +} + +/** + * Create an [Awakeable], addressable through [Awakeable.id]. + * + * @return the [Awakeable] to await on. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.awakeable + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun awakeable(): Awakeable { + return context().awakeable(typeTag()) +} + +/** + * Create an [Awakeable], addressable through [Awakeable.id]. + * + * You can use this feature to implement external asynchronous systems interactions, for example you + * can send a Kafka record including the [Awakeable.id], and then let another service consume from + * Kafka the responses of given external system interaction by using [awakeableHandle]. + * + * @param typeTag the type tag for deserializing the [Awakeable] result. + * @return the [Awakeable] to await on. + * @see Awakeable + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun awakeable(typeTag: TypeTag): Awakeable { + return context().awakeable(typeTag) +} + +/** + * Create a new [AwakeableHandle] for the provided identifier. + * + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.awakeableHandle + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun awakeableHandle(id: String): AwakeableHandle { + return context().awakeableHandle(id) +} + +/** + * Get an [InvocationHandle] for an already existing invocation. + * + * @param invocationId The invocation to interact with. + * @throws IllegalStateException if called outside of a Restate handler + * @see Context.invocationHandle + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun invocationHandle(invocationId: String): InvocationHandle { + return context().invocationHandle(invocationId, typeTag()) +} + +/** + * Get the key of this Virtual Object or Workflow. + * + * @return the key of this object + * @throws IllegalStateException if called from a regular Service handler or outside of a Restate + * handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun key(): String { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("key() must be called from within a Restate handler") + + if (!handlerContext.canReadState()) { + error( + "key() can be used only within Virtual Object or Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/state for more details." + ) + } + + return (ctx as SharedObjectContext).key() +} + +/** + * Access to this Virtual Object/Workflow state. + * + * @return [KotlinState] for this Virtual Object/Workflow + * @throws IllegalStateException if called from a regular Service handler or outside of a Restate + * handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun state(): KotlinState { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("state() must be called from within a Restate handler") + + if (!handlerContext.canReadState()) { + error( + "state() can be used only within Virtual Object or Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/state for more details." + ) + } + + return KotlinStateImpl(ctx as SharedObjectContext, handlerContext) +} + +/** + * Create a [DurablePromise] for the given key. + * + * @throws IllegalStateException if called from a non-Workflow handler or outside of a Restate + * handler + * @see SharedWorkflowContext.promise + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun promise(key: DurablePromiseKey): DurablePromise { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("promise() must be called from within a Restate handler") + + if (!handlerContext.canReadPromises() || !handlerContext.canWritePromises()) { + error( + "promise(key) can be used only within Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/external-events#durable-promises for more details." + ) + } + + return (ctx as SharedWorkflowContext).promise(key) +} + +/** + * Create a new [DurablePromiseHandle] for the provided key. + * + * @throws IllegalStateException if called from a non-Workflow handler or outside of a Restate + * handler + * @see SharedWorkflowContext.promiseHandle + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun promiseHandle(key: DurablePromiseKey): DurablePromiseHandle { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("promiseHandle() must be called from within a Restate handler") + + if (!handlerContext.canReadPromises() || !handlerContext.canWritePromises()) { + error( + "promiseHandle(key) can be used only within Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/external-events#durable-promises for more details." + ) + } + + return (ctx as SharedWorkflowContext).promiseHandle(key) +} + +/** + * Interface for accessing Virtual Object/Workflow state in the reflection-based API. + * + * This interface provides suspend-friendly state operations that can be used from within Restate + * handlers using the free-floating `state()` function. + * + * Example usage: + * ```kotlin + * @VirtualObject + * class Counter { + * companion object { + * private val COUNT = stateKey("count") + * } + * + * @Handler + * suspend fun increment(): Long { + * val current = state().get(COUNT) ?: 0L + * val next = current + 1 + * state().set(COUNT, next) + * return next + * } + * } + * ``` + */ +@org.jetbrains.annotations.ApiStatus.Experimental +interface KotlinState { + /** + * Gets the state stored under key, deserializing the raw value using the [StateKey.serdeInfo]. + * + * @param key identifying the state to get and its type. + * @return the value containing the stored state deserialized, or null if not set. + * @throws RuntimeException when the state cannot be deserialized. + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun get(key: StateKey): T? + + /** + * Sets the given value under the given key, serializing the value using the [StateKey.serdeInfo]. + * + * @param key identifying the value to store and its type. + * @param value to store under the given key. + * @throws IllegalStateException if called from a Shared handler + */ + @org.jetbrains.annotations.ApiStatus.Experimental + suspend fun set(key: StateKey, value: T) + + /** + * Clears the state stored under key. + * + * @param key identifying the state to clear. + * @throws IllegalStateException if called from a Shared handler + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun clear(key: StateKey<*>) + + /** + * Clears all the state of this virtual object instance key-value state storage. + * + * @throws IllegalStateException if called from a Shared handler + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun clearAll() + + /** + * Gets all the known state keys for this virtual object instance. + * + * @return the immutable collection of known state keys. + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun keys(): Collection +} + +/** + * Gets the state stored under key. + * + * @param key the name of the state key. + * @return the value containing the stored state deserialized, or null if not set. + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun KotlinState.get(key: String): T? { + return this.get(StateKey.of(key, typeTag())) +} + +/** + * Sets the given value under the given key. + * + * @param key the name of the state key. + * @param value to store under the given key. + * @throws IllegalStateException if called from a Shared handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun KotlinState.set(key: String, value: T) { + this.set(StateKey.of(key, typeTag()), value) +} + +// Internal implementation of KotlinState +private class KotlinStateImpl( + private val ctx: SharedObjectContext, + private val handlerContext: dev.restate.sdk.endpoint.definition.HandlerContext, +) : KotlinState { + override suspend fun get(key: StateKey): T? { + return ctx.get(key) + } + + override suspend fun set(key: StateKey, value: T) { + checkCanWriteState("set") + (ctx as ObjectContext).set(key, value) + } + + override suspend fun clear(key: StateKey<*>) { + checkCanWriteState("clear") + (ctx as ObjectContext).clear(key) + } + + override suspend fun clearAll() { + checkCanWriteState("clearAll") + (ctx as ObjectContext).clearAll() + } + + override suspend fun keys(): Collection { + return ctx.stateKeys() + } + + private fun checkCanWriteState(opName: String) { + if (!handlerContext.canWriteState()) { + error( + "state().$opName() cannot be used in shared handlers. " + + "Check https://docs.restate.dev/develop/java/state for more details." + ) + } + } +} + +/** + * Kotlin-idiomatic request for invoking Restate services from within a handler. + * + * Example usage: + * ```kotlin + * toService() + * .request { it.add(1) } + * .withOptions { idempotencyKey = "123" } + * .call() + * ``` + * + * @param Req the request type + * @param Res the response type + */ +@org.jetbrains.annotations.ApiStatus.Experimental +interface KRequest : Request { + + /** + * Configure invocation options using a DSL. + * + * @param block builder block for options + * @return a new request with the configured options + */ + @org.jetbrains.annotations.ApiStatus.Experimental + fun withOptions(block: InvocationOptions.Builder.() -> Unit): KRequest + + /** + * Call the target handler and return a [CallDurableFuture] for the result. + * + * @return a [CallDurableFuture] that will contain the response + */ + @org.jetbrains.annotations.ApiStatus.Experimental suspend fun call(): CallDurableFuture + + /** + * Send the request without waiting for the response. + * + * @param delay optional delay before the invocation is executed + * @return an [InvocationHandle] to interact with the sent request + */ + @org.jetbrains.annotations.ApiStatus.Experimental + suspend fun send(delay: Duration? = null): InvocationHandle +} + +/** + * Builder for creating type-safe requests from within a handler. + * + * This builder allows the response type to be inferred from the lambda passed to [request]. + * + * @param SVC the service/virtual object/workflow class + */ +@org.jetbrains.annotations.ApiStatus.Experimental +class KRequestBuilder +@PublishedApi +internal constructor( + private val clazz: Class, + private val key: String?, +) { + /** + * Create a request by invoking a method on the target. + * + * The response type is inferred from the return type of the invoked method. + * + * @param Res the response type (inferred from the lambda) + * @param block a suspend lambda that invokes a method on the target + * @return a [KRequest] with the correct response type + */ + @Suppress("UNCHECKED_CAST") + fun request(block: suspend (SVC) -> Res): KRequest { + return KRequestImpl( + RequestCaptureProxy(clazz, key).capture(block as suspend (SVC) -> Any?).toRequest() + ) + as KRequest + } +} + +/** + * Create a builder for invoking a Restate service from within a handler. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): String { + * val result = toService() + * .request { it.greet("Alice") } + * .call() + * .await() + * return result + * } + * ``` + * + * @param SVC the service class annotated with @Service + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun toService(): KRequestBuilder { + ReflectionUtils.mustHaveServiceAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KRequestBuilder(SVC::class.java, null) +} + +/** + * Create a builder for invoking a Restate virtual object from within a handler. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): Long { + * val result = toVirtualObject("my-counter") + * .request { it.add(1) } + * .call() + * .await() + * return result + * } + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun toVirtualObject(key: String): KRequestBuilder { + ReflectionUtils.mustHaveVirtualObjectAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KRequestBuilder(SVC::class.java, key) +} + +/** + * Create a builder for invoking a Restate workflow from within a handler. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): String { + * val result = toWorkflow("workflow-123") + * .request { it.run("input") } + * .call() + * .await() + * return result + * } + * ``` + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a builder for creating typed requests + */ +@org.jetbrains.annotations.ApiStatus.Experimental +inline fun toWorkflow(key: String): KRequestBuilder { + ReflectionUtils.mustHaveWorkflowAnnotation(SVC::class.java) + require(ReflectionUtils.isKotlinClass(SVC::class.java)) { + "Using Java classes with Kotlin's API is not supported" + } + return KRequestBuilder(SVC::class.java, key) +} + +/** Implementation of [KRequest] for SDK context. */ +private class KRequestImpl(private val request: Request) : + KRequest, Request by request { + override fun withOptions(block: InvocationOptions.Builder.() -> Unit): KRequest { + val builder = InvocationOptions.builder() + builder.block() + return KRequestImpl( + this.toBuilder().headers(builder.headers).idempotencyKey(builder.idempotencyKey).build() + ) + } + + override suspend fun call(): CallDurableFuture { + return context().call(request) + } + + override suspend fun send(delay: Duration?): InvocationHandle { + return context().send(request, delay) + } +} + +/** + * Create a proxy client for a Restate service. + * + * This creates a proxy that allows calling service methods directly. The proxy intercepts method + * calls, converts them to Restate requests, and awaits the result. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): String { + * val greeter = service() + * val response = greeter.greet("Alice") + * return "Got: $response" + * } + * ``` + * + * @param SVC the service class annotated with @Service + * @return a proxy client to invoke the service + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun service(): SVC { + return service(SVC::class.java) +} + +/** + * Create a proxy client for a Restate virtual object. + * + * Example usage: + * ```kotlin + * @Handler + * suspend fun myHandler(): Long { + * val counter = virtualObject("my-counter") + * return counter.increment() + * } + * ``` + * + * @param SVC the virtual object class annotated with @VirtualObject + * @param key the key identifying the specific virtual object instance + * @return a proxy client to invoke the virtual object + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun virtualObject(key: String): SVC { + return virtualObject(SVC::class.java, key) +} + +/** + * Create a proxy client for a Restate workflow. + * + * @param SVC the workflow class annotated with @Workflow + * @param key the key identifying the specific workflow instance + * @return a proxy client to invoke the workflow + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend inline fun workflow(key: String): SVC { + return workflow(SVC::class.java, key) +} + +@PublishedApi +internal fun service(clazz: Class): SVC { + ReflectionUtils.mustHaveServiceAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + val serviceName = ReflectionUtils.extractServiceName(clazz) + + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, null).toRequest() + + // Last argument is the continuation for suspend functions + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { context().call(request).await() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +@PublishedApi +internal fun virtualObject(clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + val serviceName = ReflectionUtils.extractServiceName(clazz) + + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + + // Last argument is the continuation for suspend functions + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { context().call(request).await() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} + +@PublishedApi +internal fun workflow(clazz: Class, key: String): SVC { + ReflectionUtils.mustHaveWorkflowAnnotation(clazz) + require(ReflectionUtils.isKotlinClass(clazz)) { + "Using Java classes with Kotlin's API is not supported" + } + val serviceName = ReflectionUtils.extractServiceName(clazz) + + return ProxySupport.createProxy(clazz) { invocation -> + val request = invocation.captureInvocation(serviceName, key).toRequest() + + // Last argument is the continuation for suspend functions + @Suppress("UNCHECKED_CAST") val continuation = invocation.arguments.last() as Continuation + + // Start a coroutine that calls the client and resumes the continuation + val suspendBlock: suspend () -> Any? = { context().call(request).await() } + suspendBlock.startCoroutine(continuation) + COROUTINE_SUSPENDED + } +} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt new file mode 100644 index 000000000..12587722d --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/MalformedRestateServiceException.kt @@ -0,0 +1,22 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +internal class MalformedRestateServiceException : Exception { + constructor( + serviceName: String, + message: String, + ) : super("Failed to instantiate Restate service '$serviceName'.\nReason: $message") + + constructor( + serviceName: String, + message: String, + cause: Throwable, + ) : super("Failed to instantiate Restate service '$serviceName'.\nReason: $message", cause) +} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt new file mode 100644 index 000000000..c63582fee --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt @@ -0,0 +1,432 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +import dev.restate.common.reflections.ReflectionUtils +import dev.restate.sdk.annotation.Accept +import dev.restate.sdk.annotation.CustomSerdeFactory +import dev.restate.sdk.annotation.Exclusive +import dev.restate.sdk.annotation.Handler +import dev.restate.sdk.annotation.Json +import dev.restate.sdk.annotation.Raw +import dev.restate.sdk.annotation.Shared +import dev.restate.sdk.annotation.Workflow +import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.endpoint.definition.HandlerRunner +import dev.restate.sdk.endpoint.definition.HandlerType +import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory +import dev.restate.sdk.endpoint.definition.ServiceType +import dev.restate.sdk.kotlin.Context +import dev.restate.sdk.kotlin.ObjectContext +import dev.restate.sdk.kotlin.SharedObjectContext +import dev.restate.sdk.kotlin.SharedWorkflowContext +import dev.restate.sdk.kotlin.WorkflowContext +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory.KtTypeTag +import dev.restate.serde.provider.DefaultSerdeFactoryProvider +import java.lang.reflect.Modifier +import java.util.* +import kotlin.reflect.KFunction +import kotlin.reflect.KVisibility +import kotlin.reflect.full.callSuspend +import kotlin.reflect.full.findAnnotation +import kotlin.reflect.full.memberFunctions +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.javaMethod +import kotlin.reflect.jvm.jvmErasure + +internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory { + @Volatile private var cachedDefaultSerdeFactory: SerdeFactory? = null + + override fun create( + serviceInstance: Any, + overrideHandlerOptions: HandlerRunner.Options?, + ): ServiceDefinition { + val handlerRunnerOptions: dev.restate.sdk.kotlin.HandlerRunner.Options? + if ( + overrideHandlerOptions == null || + overrideHandlerOptions is dev.restate.sdk.kotlin.HandlerRunner.Options + ) { + handlerRunnerOptions = overrideHandlerOptions + } else { + throw IllegalArgumentException( + "The provided options class MUST be instance of dev.restate.sdk.kotlin.HandlerRunner.Options, but was " + + overrideHandlerOptions.javaClass + ) + } + + val serviceClazz: Class<*> = serviceInstance.javaClass + + val hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(serviceClazz) + val hasVirtualObjectAnnotation = ReflectionUtils.hasVirtualObjectAnnotation(serviceClazz) + val hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(serviceClazz) + + val hasAnyAnnotation = + hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation + if (!hasAnyAnnotation) { + throw MalformedRestateServiceException( + serviceClazz.simpleName, + "A restate component MUST be annotated with " + + "exactly one annotation between @Service/@VirtualObject/@Workflow, no annotation was found", + ) + } + val hasExactlyOneAnnotation = + hasServiceAnnotation xor (hasVirtualObjectAnnotation xor hasWorkflowAnnotation) + + if (!hasExactlyOneAnnotation) { + throw MalformedRestateServiceException( + serviceClazz.simpleName, + "A restate component MUST be annotated with " + + "exactly one annotation between @Service/@VirtualObject/@Workflow, more than one annotation found", + ) + } + + val serviceName = ReflectionUtils.extractServiceName(serviceClazz) + val serviceType = + if (hasServiceAnnotation) ServiceType.SERVICE + else if (hasVirtualObjectAnnotation) ServiceType.VIRTUAL_OBJECT else ServiceType.WORKFLOW + val serdeFactory: SerdeFactory = resolveSerdeFactory(serviceClazz) + + val kFunctions = + serviceClazz.kotlin.memberFunctions.filter { + it.javaMethod?.let { method -> + // Can't use findAnnotations because that won't walk the stack! + ReflectionUtils.findAnnotation(method, Handler::class.java) != null || + ReflectionUtils.findAnnotation(method, Shared::class.java) != null || + ReflectionUtils.findAnnotation(method, Workflow::class.java) != null || + ReflectionUtils.findAnnotation(method, Exclusive::class.java) != null + } ?: false + } + + if (kFunctions.isEmpty()) { + throw MalformedRestateServiceException(serviceName, "No @Handler method found") + } + return ServiceDefinition.of( + serviceName, + serviceType, + kFunctions + .map { + this.createHandlerDefinition( + serviceInstance, + it, + serviceName, + serviceType, + serdeFactory, + handlerRunnerOptions, + ) + } + .toList(), + ) + } + + private fun createHandlerDefinition( + serviceInstance: Any, + kFunction: KFunction<*>, + serviceName: String, + serviceType: ServiceType, + serdeFactory: SerdeFactory, + overrideHandlerOptions: dev.restate.sdk.kotlin.HandlerRunner.Options?, + ): HandlerDefinition<*, *> { + val handlerInfo: ReflectionUtils.HandlerInfo = + ReflectionUtils.mustHaveHandlerAnnotation(kFunction.javaMethod!!) + val handlerName: String? = handlerInfo.name + + // Check if this is a Kotlin suspend function + validateKFunction(kFunction, serviceName) + + val parameters = kFunction.valueParameters + + // Check for old-style context parameter + if ( + (parameters.size == 1 || parameters.size == 2) && + (parameters[0] == Context::class.java || + parameters[0] == SharedObjectContext::class.java || + parameters[0] == ObjectContext::class.java || + parameters[0] == WorkflowContext::class.java || + parameters[0] == SharedWorkflowContext::class.java) + ) { + val ctxTypeName = parameters[0].type.toString() + val returnTypeName = kFunction.returnType.toString() + val actualSignature = + if (parameters.size == 1) "ctx: $ctxTypeName" + else "ctx: $ctxTypeName, input: ${parameters[1].type}" + val expectedSignature = if (parameters.isEmpty()) "" else "input: ${parameters[1].type}" + throw MalformedRestateServiceException( + serviceName, + """ + The service is being loaded with the new Reflection based API, but handler '${handlerName}' contains $ctxTypeName as first parameter. Suggestions: + * If you want to use the new Reflection based API, remove $ctxTypeName from the method definition and use the functions from dev.restate.sdk.kotlin inside the handler: + - suspend fun ${handlerName}(${actualSignature}): $returnTypeName { + - // code + - } + Replace with: + + suspend fun ${handlerName}(${expectedSignature}): $returnTypeName { + + // Use functions from dev.restate.sdk.kotlin.* + + // code + + } + * If you''re still using the KSP based API, make sure the ServiceDefinitionFactory class was correctly generated. + """ + .trimIndent(), + ) + } + + if (parameters.size > 1) { + throw MalformedRestateServiceException( + serviceName, + "More than one parameter found in method ${kFunction.name}. Only zero or one parameter is supported.", + ) + } + + if (serviceType == ServiceType.SERVICE && handlerInfo.shared) { + throw MalformedRestateServiceException( + serviceName, + "@Shared is only supported on virtual objects and workflow handlers", + ) + } + val handlerType = + if (handlerInfo.shared) HandlerType.SHARED + else if (serviceType == ServiceType.VIRTUAL_OBJECT) HandlerType.EXCLUSIVE + else if (serviceType == ServiceType.WORKFLOW) HandlerType.WORKFLOW else null + + val inputSerde = + resolveInputSerde( + kFunction, + serdeFactory, + serviceName, + ) + val outputSerde = resolveOutputSerde(kFunction, serdeFactory, serviceName) + + val runner = + createSuspendHandlerRunner( + serviceInstance, + kFunction, + parameters.size, + serdeFactory, + overrideHandlerOptions, + ) + + var handlerDefinition: HandlerDefinition = + HandlerDefinition.of(handlerName, handlerType, inputSerde, outputSerde, runner) + + // Look for the accept annotation + if (parameters.isNotEmpty()) { + val acceptAnnotation: Accept? = parameters[0].findAnnotation() + if (acceptAnnotation != null) { + handlerDefinition = handlerDefinition.withAcceptContentType(acceptAnnotation.value) + } + } + + return handlerDefinition + } + + private fun createSuspendHandlerRunner( + serviceInstance: Any, + kFunction: KFunction<*>, + parameterCount: Int, + serdeFactory: SerdeFactory, + overrideHandlerOptions: dev.restate.sdk.kotlin.HandlerRunner.Options?, + ): dev.restate.sdk.kotlin.HandlerRunner { + return dev.restate.sdk.kotlin.HandlerRunner.of( + serdeFactory, + overrideHandlerOptions ?: dev.restate.sdk.kotlin.HandlerRunner.Options.DEFAULT, + ) { _, input -> + if (parameterCount == 0) { + kFunction.callSuspend(serviceInstance) + } else { + kFunction.callSuspend(serviceInstance, input) + } + } + } + + @Suppress("UNCHECKED_CAST") + private fun resolveInputSerde( + kFunction: KFunction<*>, + serdeFactory: SerdeFactory, + serviceName: String, + ): Serde { + if (kFunction.valueParameters.isEmpty()) { + return KotlinSerializationSerdeFactory.UNIT as Serde + } + + val parameter = kFunction.valueParameters[0] + + val rawAnnotation = parameter.findAnnotation() + val jsonAnnotation = parameter.findAnnotation() + + // Validate annotations + if (rawAnnotation != null && jsonAnnotation != null) { + throw MalformedRestateServiceException( + serviceName, + "Parameter in method ${kFunction.name} cannot be annotated with both @Raw and @Json", + ) + } + + if (rawAnnotation != null) { + // Validate parameter type is byte[] + if (parameter.type.jvmErasure != ByteArray::class) { + throw MalformedRestateServiceException( + serviceName, + "Parameter annotated with @Raw in method ${kFunction.name} MUST be of type ByteArray, was ${parameter.type}", + ) + } + var serde: Serde = Serde.RAW as Serde + // Apply content type if not default + if (rawAnnotation.contentType != "application/octet-stream") { + serde = Serde.withContentType(rawAnnotation.contentType, serde) + } + return serde + } + + // Use serdeFactory to create serde + var serde = + serdeFactory.create(KtTypeTag(parameter.type.jvmErasure, parameter.type)) + as Serde + + // Apply custom content-type from @Json if present + if (jsonAnnotation != null && jsonAnnotation.contentType != "application/json") { + serde = Serde.withContentType(jsonAnnotation.contentType, serde) + } + + return serde + } + + @Suppress("UNCHECKED_CAST") + private fun resolveOutputSerde( + kFunction: KFunction<*>, + serdeFactory: SerdeFactory, + serviceName: String, + ): Serde { + val outputType = kFunction.returnType + + // Handle Unit type (Kotlin void equivalent) + if (outputType == Void.TYPE || outputType.jvmErasure == Unit::class) { + return KotlinSerializationSerdeFactory.UNIT as Serde + } + + val rawAnnotation = kFunction.findAnnotation() + val jsonAnnotation = kFunction.findAnnotation() + + // Validate annotations + if (rawAnnotation != null && jsonAnnotation != null) { + throw MalformedRestateServiceException( + serviceName, + "Method ${kFunction.name} cannot be annotated with both @Raw and @Json", + ) + } + + if (rawAnnotation != null) { + // Validate return type is byte[] + if (outputType.jvmErasure != ByteArray::class) { + throw MalformedRestateServiceException( + serviceName, + "Method ${kFunction.name} annotated with @Raw MUST return byte[], was $outputType", + ) + } + var serde: Serde = Serde.RAW as Serde + // Apply content type if not default + if (rawAnnotation.contentType != "application/octet-stream") { + serde = Serde.withContentType(rawAnnotation.contentType, serde) + } + return serde + } + + // Use serdeFactory to create serde + var serde = + serdeFactory.create(KtTypeTag(outputType.jvmErasure, outputType)) as Serde + + // Apply custom content-type from @Json if present + if (jsonAnnotation != null && jsonAnnotation.contentType != "application/json") { + serde = Serde.withContentType(jsonAnnotation.contentType, serde) + } + + return serde + } + + private fun resolveSerdeFactory(serviceClazz: Class<*>): SerdeFactory { + // Check for CustomSerdeFactory annotation + val customSerdeFactoryAnnotation: CustomSerdeFactory? = + ReflectionUtils.findAnnotation( + serviceClazz, + CustomSerdeFactory::class.java, + ) + + if (customSerdeFactoryAnnotation != null) { + try { + return customSerdeFactoryAnnotation.value.java.getDeclaredConstructor().newInstance() + } catch (e: Exception) { + throw MalformedRestateServiceException( + serviceClazz.simpleName, + "Failed to instantiate custom SerdeFactory: ${customSerdeFactoryAnnotation.value.java.name}", + e, + ) + } + } + + // Try DefaultSerdeFactoryProvider -> if there's one, it's an easy pick! + if (this.cachedDefaultSerdeFactory != null) { + return this.cachedDefaultSerdeFactory!! + } + + val loadedFactories: MutableList?> = + ServiceLoader.load(DefaultSerdeFactoryProvider::class.java).stream().toList() + if (loadedFactories.size == 1) { + this.cachedDefaultSerdeFactory = loadedFactories[0]!!.get()!!.create() + return this.cachedDefaultSerdeFactory!! + } + + // Load kotlinx serde factory + try { + val jacksonSerdeFactoryClass = + Class.forName("dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory") + val defaultInstance = jacksonSerdeFactoryClass.getConstructor().newInstance() + this.cachedDefaultSerdeFactory = defaultInstance as SerdeFactory? + return this.cachedDefaultSerdeFactory!! + } catch (e: Exception) { + throw MalformedRestateServiceException( + serviceClazz.simpleName, + "Failed to load KotlinSerializationSerdeFactory for Kotlin service. " + + "Make sure sdk-serde-kotlinx is on the classpath.", + e, + ) + } + } + + override fun supports(serviceObject: Any?): Boolean { + return serviceObject?.javaClass?.let { ReflectionUtils.isKotlinClass(it) } ?: false + } + + override fun priority(): Int { + // Run before last - after code-generated factories, before java + return ServiceDefinitionFactory.LOWEST_PRIORITY - 1 + } + + private fun validateKFunction(kFunction: KFunction<*>, serviceName: String) { + if (!kFunction.isSuspend) { + throw MalformedRestateServiceException( + serviceName, + "Method '${kFunction.name}' is not a suspend function, this is not supported.", + ) + } + if (kFunction.visibility != KVisibility.PUBLIC) { + throw MalformedRestateServiceException( + serviceName, + "Method '${kFunction.name}' is not public.", + ) + } + if (Modifier.isStatic(kFunction.javaMethod!!.modifiers)) { + throw MalformedRestateServiceException( + serviceName, + "Method '" + kFunction.name + "' is static, cannot be used as Restate handler", + ) + } + } +} diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/RestateContextElement.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/RestateContextElement.kt new file mode 100644 index 000000000..e421c6a22 --- /dev/null +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/RestateContextElement.kt @@ -0,0 +1,24 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.kotlin.internal + +import dev.restate.sdk.kotlin.Context +import kotlin.coroutines.AbstractCoroutineContextElement +import kotlin.coroutines.CoroutineContext + +/** + * Coroutine context element that holds the Restate [Context]. + * + * This element is added to the coroutine context when a handler is invoked, allowing free-floating + * API functions like `context()`, `run()`, etc. to access the current context from within suspend + * functions. + */ +internal class RestateContextElement(val ctx: Context) : AbstractCoroutineContextElement(Key) { + companion object Key : CoroutineContext.Key +} diff --git a/sdk-api-kotlin/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory b/sdk-api-kotlin/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory new file mode 100644 index 000000000..cba8a5774 --- /dev/null +++ b/sdk-api-kotlin/src/main/resources/META-INF/services/dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory @@ -0,0 +1 @@ +dev.restate.sdk.kotlin.internal.ReflectionServiceDefinitionFactory diff --git a/sdk-core/build.gradle.kts b/sdk-core/build.gradle.kts index 3fd1c34ba..87ce62e49 100644 --- a/sdk-core/build.gradle.kts +++ b/sdk-core/build.gradle.kts @@ -150,6 +150,26 @@ tasks { } } +ksp { + val disabledClassesCodegen = + listOf( + "dev.restate.sdk.core.kotlinapi.reflections.CheckedException", + "dev.restate.sdk.core.kotlinapi.reflections.CustomSerdeService", + "dev.restate.sdk.core.kotlinapi.reflections.Empty", + "dev.restate.sdk.core.kotlinapi.reflections.GreeterInterface", + "dev.restate.sdk.core.kotlinapi.reflections.NestedDataClass", + "dev.restate.sdk.core.kotlinapi.reflections.CornerCases", + "dev.restate.sdk.core.kotlinapi.reflections.GreeterWithExplicitName", + "dev.restate.sdk.core.kotlinapi.reflections.MyWorkflow", + "dev.restate.sdk.core.kotlinapi.reflections.ObjectGreeter", + "dev.restate.sdk.core.kotlinapi.reflections.ObjectGreeterImplementedFromInterface", + "dev.restate.sdk.core.kotlinapi.reflections.PrimitiveTypes", + "dev.restate.sdk.core.kotlinapi.reflections.RawInputOutput", + "dev.restate.sdk.core.kotlinapi.reflections.ServiceGreeter", + ) + arg("dev.restate.codegen.disabledClasses", disabledClassesCodegen.joinToString(",")) +} + // spotless configuration for protobuf configure { diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt index ea132e9dc..e50305c56 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/KotlinAPITests.kt @@ -12,6 +12,7 @@ import dev.restate.common.Request import dev.restate.sdk.core.* import dev.restate.sdk.core.TestDefinitions.TestExecutor import dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder +import dev.restate.sdk.core.kotlinapi.reflections.ReflectionTest import dev.restate.sdk.core.statemachine.ProtoUtils import dev.restate.sdk.endpoint.definition.HandlerDefinition import dev.restate.sdk.endpoint.definition.HandlerType @@ -43,6 +44,7 @@ class KotlinAPITests : TestRunner() { UserFailuresTest(), RandomTest(), CodegenTest(), + ReflectionTest(), ) } diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt new file mode 100644 index 000000000..be615e446 --- /dev/null +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt @@ -0,0 +1,91 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi.reflections + +import dev.restate.sdk.core.AssertUtils.assertThatDiscovery +import dev.restate.sdk.core.generated.manifest.Handler +import dev.restate.sdk.core.generated.manifest.Input +import dev.restate.sdk.core.generated.manifest.Output +import dev.restate.sdk.core.generated.manifest.Service +import dev.restate.sdk.kotlin.endpoint.* +import org.assertj.core.api.InstanceOfAssertFactories.type +import org.junit.jupiter.api.Test + +class ReflectionDiscoveryTest { + + @Test + fun checkCustomInputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomCt") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun checkCustomInputAcceptContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomAccept") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/*") + } + + @Test + fun checkCustomOutputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutputWithCustomCT") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun explicitNames() { + assertThatDiscovery( + object : GreeterWithExplicitName { + override suspend fun greet(request: String): String { + TODO("Not yet implemented") + } + } + ) + .extractingService("MyExplicitName") + .extractingHandler("my_greeter") + } + + @Test + fun workflowType() { + assertThatDiscovery(MyWorkflow()) + .extractingService("MyWorkflow") + .returns(Service.Ty.WORKFLOW) { obj -> obj.ty } + .extractingHandler("run") + .returns(Handler.Ty.WORKFLOW) { obj -> obj.ty } + } + + @Test + fun usingTransformer() { + assertThatDiscovery( + endpoint { + bind(RawInputOutput()) { + it.documentation = "My service documentation" + it.configureHandler("rawInputWithCustomCt") { + it.documentation = "My handler documentation" + } + } + } + ) + .extractingService("RawInputOutput") + .returns("My service documentation", Service::getDocumentation) + .extractingHandler("rawInputWithCustomCt") + .returns("My handler documentation", Handler::getDocumentation) + } +} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt new file mode 100644 index 000000000..7946a8643 --- /dev/null +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt @@ -0,0 +1,229 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi.reflections + +import dev.restate.common.Slice +import dev.restate.common.Target +import dev.restate.sdk.core.TestDefinitions +import dev.restate.sdk.core.TestDefinitions.TestDefinition +import dev.restate.sdk.core.TestDefinitions.testInvocation +import dev.restate.sdk.core.TestSerdes +import dev.restate.sdk.core.statemachine.ProtoUtils.* +import dev.restate.serde.Serde +import dev.restate.serde.kotlinx.* +import java.util.stream.Stream + +class ReflectionTest : TestDefinitions.TestSuite { + + override fun definitions(): Stream { + return Stream.of( + testInvocation({ ServiceGreeter() }, "greet") + .withInput(startMessage(1), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ ObjectGreeter() }, "greet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ ObjectGreeter() }, "sharedGreet") + .withInput(startMessage(1, "slinkydeveloper"), inputCmd("Francesco")) + .onlyBidiStream() + .expectingOutput(outputCmd("Francesco"), END_MESSAGE), + testInvocation({ NestedDataClass() }, "greet") + .withInput( + startMessage(1, "slinkydeveloper"), + inputCmd(jsonSerde(), NestedDataClass.Input("123")), + ) + .onlyBidiStream() + .expectingOutput( + outputCmd(jsonSerde(), NestedDataClass.Output("123")), + END_MESSAGE, + ), + testInvocation({ ObjectGreeterImplementedFromInterface() }, "greet") + .withInput( + startMessage(1, "slinkydeveloper"), + inputCmd("Francesco"), + callCompletion(2, "Francesco"), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.virtualObject("GreeterInterface", "slinkydeveloper", "greet"), + "Francesco", + ), + outputCmd("Francesco"), + END_MESSAGE, + ), + testInvocation({ Empty() }, "emptyInput") + .withInput(startMessage(1), inputCmd(), callCompletion(2, "Till")) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyInput")), + outputCmd("Till"), + END_MESSAGE, + ) + .named("empty output"), + testInvocation({ Empty() }, "emptyOutput") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyOutput"), "Francesco"), + outputCmd(), + END_MESSAGE, + ) + .named("empty output"), + testInvocation({ Empty() }, "emptyInputOutput") + .withInput(startMessage(1), inputCmd("Francesco"), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("Empty", "emptyInputOutput")), + outputCmd(), + END_MESSAGE, + ) + .named("empty input and empty output"), + testInvocation({ PrimitiveTypes() }, "primitiveOutput") + .withInput(startMessage(1), inputCmd(), callCompletion(2, TestSerdes.INT, 10)) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("PrimitiveTypes", "primitiveOutput"), + Serde.VOID, + null, + ), + outputCmd(TestSerdes.INT, 10), + END_MESSAGE, + ) + .named("primitive output"), + testInvocation({ PrimitiveTypes() }, "primitiveInput") + .withInput(startMessage(1), inputCmd(10), callCompletion(2, Serde.VOID, null)) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("PrimitiveTypes", "primitiveInput"), + TestSerdes.INT, + 10, + ), + outputCmd(), + END_MESSAGE, + ) + .named("primitive input"), + testInvocation({ RawInputOutput() }, "rawInput") + .withInput( + startMessage(1), + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), + ) + .onlyBidiStream() + .expectingOutput( + callCmd(1, 2, Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") + .withInput( + startMessage(1), + inputCmd("{{".toByteArray()), + callCompletion(2, KotlinSerializationSerdeFactory.UNIT, Unit), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawInputWithCustomCt"), + "{{".toByteArray(), + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawOutput") + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, Serde.RAW, "{{".toByteArray()), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutput"), + KotlinSerializationSerdeFactory.UNIT, + Unit, + ), + outputCmd("{{".toByteArray()), + END_MESSAGE, + ), + testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") + .withInput( + startMessage(1), + inputCmd(), + callCompletion(2, Serde.RAW, "{{".toByteArray()), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.service("RawInputOutput", "rawOutputWithCustomCT"), + KotlinSerializationSerdeFactory.UNIT, + Unit, + ), + outputCmd("{{".toByteArray()), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "returnNull") + .withInput( + startMessage(1, "mykey"), + inputCmd(jsonSerde(), null), + callCompletion(2, jsonSerde(), null), + ) + .onlyBidiStream() + .expectingOutput( + callCmd( + 1, + 2, + Target.virtualObject("CornerCases", "mykey", "returnNull"), + jsonSerde(), + null, + ), + outputCmd(jsonSerde(), null), + END_MESSAGE, + ), + testInvocation({ CornerCases() }, "badReturnTypeInferred") + .withInput(startMessage(1, "mykey"), inputCmd()) + .onlyBidiStream() + .expectingOutput( + oneWayCallCmd( + 1, + Target.virtualObject( + "CornerCases", + "mykey", + "badReturnTypeInferred", + ), + null, + null, + Slice.EMPTY, + ), + outputCmd(), + END_MESSAGE, + ), + testInvocation({ CustomSerdeService() }, "echo") + .withInput(startMessage(1), inputCmd(byteArrayOf(1))) + .onlyBidiStream() + .expectingOutput(outputCmd(byteArrayOf(1)), END_MESSAGE), + ) + } +} diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt new file mode 100644 index 000000000..1a0243283 --- /dev/null +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt @@ -0,0 +1,193 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.kotlinapi.reflections + +import dev.restate.sdk.annotation.* +import dev.restate.sdk.kotlin.* +import dev.restate.serde.Serde +import dev.restate.serde.SerdeFactory +import dev.restate.serde.TypeRef +import dev.restate.serde.TypeTag +import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory +import kotlinx.serialization.Serializable + +@Service +class ServiceGreeter { + @Handler + suspend fun greet(request: String): String { + return request + } +} + +@VirtualObject +class ObjectGreeter { + @Exclusive + suspend fun greet(request: String): String { + return request + } + + @Handler + @Shared + suspend fun sharedGreet(request: String): String { + return request + } +} + +@VirtualObject +class NestedDataClass { + @Serializable data class Input(val a: String) + + @Serializable data class Output(val a: String) + + @Exclusive + suspend fun greet(request: Input): Output { + return Output(request.a) + } + + @Exclusive + suspend fun complexType(request: Map>): Map> { + return mapOf() + } +} + +@VirtualObject +interface GreeterInterface { + @Exclusive suspend fun greet(request: String): String +} + +class ObjectGreeterImplementedFromInterface : GreeterInterface { + override suspend fun greet(request: String): String { + return virtualObject(key()).greet(request) + } +} + +@Service +@Name("Empty") +open class Empty { + @Handler + open suspend fun emptyInput(): String { + return service().emptyInput() + } + + @Handler + open suspend fun emptyOutput(request: String) { + service().emptyOutput(request) + } + + @Handler + open suspend fun emptyInputOutput() { + service().emptyInputOutput() + } +} + +@Service +@Name("PrimitiveTypes") +open class PrimitiveTypes { + @Handler + open suspend fun primitiveOutput(): Int { + return service().primitiveOutput() + } + + @Handler + open suspend fun primitiveInput(input: Int) { + service().primitiveInput(input) + } +} + +@VirtualObject +open class CornerCases { + + @Exclusive + open suspend fun returnNull(request: String?): String? { + return virtualObject(key()).returnNull(request) + } + + @Exclusive + open suspend fun badReturnTypeInferred(): Unit { + toVirtualObject(key()).request { it.badReturnTypeInferred() }.send() + } +} + +@Service +@Name("RawInputOutput") +open class RawInputOutput { + @Handler @Raw open suspend fun rawOutput(): ByteArray = service().rawOutput() + + @Handler + @Raw(contentType = "application/vnd.my.custom") + open suspend fun rawOutputWithCustomCT(): ByteArray = + service().rawOutputWithCustomCT() + + @Handler + open suspend fun rawInput(@Raw input: ByteArray) { + service().rawInput(input) + } + + @Handler + open suspend fun rawInputWithCustomCt( + @Raw(contentType = "application/vnd.my.custom") input: ByteArray + ) { + service().rawInputWithCustomCt(input) + } + + @Handler + open suspend fun rawInputWithCustomAccept( + @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") input: ByteArray + ) { + service().rawInputWithCustomAccept(input) + } +} + +@Workflow +@Name("MyWorkflow") +open class MyWorkflow { + @Workflow + open suspend fun run(myInput: String) { + toWorkflow(key()).request { it.sharedHandler(myInput) }.send() + } + + @Handler + open suspend fun sharedHandler(myInput: String): String = + workflow(key()).sharedHandler(myInput) +} + +@Suppress("UNCHECKED_CAST") +class MyCustomSerdeFactory : SerdeFactory { + override fun create(typeTag: TypeTag): Serde { + check(typeTag is KotlinSerializationSerdeFactory.KtTypeTag) + check(typeTag.type == Byte::class) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + + override fun create(typeRef: TypeRef): Serde { + check(typeRef.type == Byte::class) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } + + override fun create(clazz: Class?): Serde { + check(clazz == Byte::class.java) + return Serde.using({ b -> byteArrayOf(b) }, { it[0] }) as Serde + } +} + +@CustomSerdeFactory(MyCustomSerdeFactory::class) +@Service +@Name("CustomSerdeService") +class CustomSerdeService { + @Handler + suspend fun echo(input: Byte): Byte { + return input + } +} + +@Service +@Name("MyExplicitName") +interface GreeterWithExplicitName { + @Handler @Name("my_greeter") suspend fun greet(request: String): String +} diff --git a/sdk-spring-boot-kotlin-starter/build.gradle.kts b/sdk-spring-boot-kotlin-starter/build.gradle.kts index 2b7d6a125..4224c3b3f 100644 --- a/sdk-spring-boot-kotlin-starter/build.gradle.kts +++ b/sdk-spring-boot-kotlin-starter/build.gradle.kts @@ -7,6 +7,11 @@ plugins { description = "Restate SDK Spring Boot Kotlin starter" +configurations.all { + // Gonna conflict with sdk-serde-kotlinx + exclude(group = "dev.restate", module = "sdk-serde-jackson") +} + dependencies { compileOnly(libs.jspecify) @@ -30,3 +35,11 @@ dependencies { testImplementation(libs.jackson.databind) testRuntimeOnly(libs.junit.platform.launcher) } + +ksp { + val disabledClassesCodegen = + listOf( + "dev.restate.sdk.springboot.kotlin.GreeterNewApi", + ) + arg("dev.restate.codegen.disabledClasses", disabledClassesCodegen.joinToString(",")) +} diff --git a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt new file mode 100644 index 000000000..840978445 --- /dev/null +++ b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/GreeterNewApi.kt @@ -0,0 +1,26 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.springboot.kotlin + +import dev.restate.sdk.annotation.Handler +import dev.restate.sdk.annotation.Name +import dev.restate.sdk.kotlin.runBlock +import dev.restate.sdk.springboot.RestateService +import org.springframework.beans.factory.annotation.Value + +@RestateService +@Name("greeterNewApi") +open class GreeterNewApi { + @Value($$"${greetingPrefix}") internal lateinit var greetingPrefix: String + + @Handler + open suspend fun greet(person: String): String { + return runBlock { greetingPrefix } + person + } +} diff --git a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt index c03889688..65fdf38e8 100644 --- a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt +++ b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt @@ -9,26 +9,48 @@ package dev.restate.sdk.springboot.kotlin import dev.restate.client.Client +import dev.restate.client.kotlin.service +import dev.restate.client.kotlin.toService import dev.restate.sdk.testing.BindService import dev.restate.sdk.testing.RestateClient import dev.restate.sdk.testing.RestateTest import kotlinx.coroutines.test.runTest -import org.assertj.core.api.Assertions +import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.junit.jupiter.api.Timeout import org.springframework.beans.factory.annotation.Autowired import org.springframework.boot.test.context.SpringBootTest -@SpringBootTest(classes = [Greeter::class], properties = ["greetingPrefix=Something something "]) +@SpringBootTest( + classes = [Greeter::class, GreeterNewApi::class], + properties = ["greetingPrefix=Something something "], +) @RestateTest(containerImage = "ghcr.io/restatedev/restate:main") class SdkTestingIntegrationTest { @Autowired @BindService lateinit var greeter: Greeter + @Autowired @BindService lateinit var greeterNewApi: GreeterNewApi @Test @Timeout(value = 10) fun greet(@RestateClient ingressClient: Client) = runTest { val client = GreeterClient.fromClient(ingressClient) - Assertions.assertThat(client.greet("Francesco")).isEqualTo("Something something Francesco") + assertThat(client.greet("Francesco")).isEqualTo("Something something Francesco") + } + + @Test + @Timeout(value = 10) + fun greetNewApi(@RestateClient ingressClient: Client) = runTest { + assertThat(ingressClient.service().greet("Francesco")) + .isEqualTo("Something something Francesco") + } + + @Test + @Timeout(value = 10) + fun greetNewApiWithRequestTo(@RestateClient ingressClient: Client) = runTest { + val response: String = + ingressClient.toService().request { it.greet("Francesco") }.call().response() + + assertThat(response).isEqualTo("Something something Francesco") } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 7513cfa09..339d9b2ed 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -15,6 +15,7 @@ include( "admin-client", "bytebuddy-proxy-support", "common", + "common-kotlin", "client", "client-kotlin", "sdk-common", From b1679bae7a95feeda9fe23bb31f92d71b6527d04 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 30 Jan 2026 17:04:50 +0100 Subject: [PATCH 02/10] Migrate test-services to new reflection api --- test-services/build.gradle.kts | 3 - .../sdk/testservices/AwakeableHolderImpl.kt | 15 +- .../testservices/BlockAndWaitWorkflowImpl.kt | 16 +- .../sdk/testservices/CancelTestImpl.kt | 36 ++--- .../restate/sdk/testservices/CounterImpl.kt | 25 ++- .../restate/sdk/testservices/FailingImpl.kt | 24 +-- .../restate/sdk/testservices/KillTestImpl.kt | 20 ++- .../sdk/testservices/ListObjectImpl.kt | 16 +- .../dev/restate/sdk/testservices/Main.kt | 31 ++-- .../restate/sdk/testservices/MapObjectImpl.kt | 19 ++- .../sdk/testservices/NonDeterministicImpl.kt | 58 +++---- .../dev/restate/sdk/testservices/ProxyImpl.kt | 90 ++++++----- .../sdk/testservices/TestUtilsServiceImpl.kt | 24 +-- .../VirtualObjectCommandInterpreterImpl.kt | 55 +++---- .../testservices/contracts/AwakeableHolder.kt | 6 +- .../contracts/BlockAndWaitWorkflow.kt | 6 +- .../sdk/testservices/contracts/CancelTest.kt | 22 +-- .../sdk/testservices/contracts/Counter.kt | 12 +- .../sdk/testservices/contracts/Failing.kt | 11 +- .../sdk/testservices/contracts/KillTest.kt | 6 +- .../sdk/testservices/contracts/ListObject.kt | 6 +- .../sdk/testservices/contracts/MapObject.kt | 11 +- .../contracts/NonDeterministic.kt | 8 +- .../sdk/testservices/contracts/Proxy.kt | 52 +++--- .../{TestUtils.kt => TestUtilsService.kt} | 15 +- .../VirtualObjectCommandInterpreter.kt | 12 +- .../sdk/testservices/contracts/interpreter.kt | 21 ++- .../restate/sdk/testservices/interpreter.kt | 148 ++++++++++-------- 28 files changed, 378 insertions(+), 390 deletions(-) rename test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/{TestUtils.kt => TestUtilsService.kt} (62%) diff --git a/test-services/build.gradle.kts b/test-services/build.gradle.kts index e52b9faf8..2abecd89f 100644 --- a/test-services/build.gradle.kts +++ b/test-services/build.gradle.kts @@ -3,14 +3,11 @@ import org.gradle.nativeplatform.platform.internal.DefaultNativePlatform.getCurr plugins { `java-conventions` `kotlin-conventions` - alias(libs.plugins.ksp) application alias(libs.plugins.jib) } dependencies { - ksp(project(":sdk-api-kotlin-gen")) - implementation(project(":sdk-kotlin-http")) implementation(project(":sdk-request-identity")) diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt index b273a58a7..c373689a0 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/AwakeableHolderImpl.kt @@ -18,17 +18,16 @@ class AwakeableHolderImpl : AwakeableHolder { private val ID_KEY: StateKey = stateKey("id") } - override suspend fun hold(context: ObjectContext, id: String) { - context.set(ID_KEY, id) + override suspend fun hold(id: String) { + state().set(ID_KEY, id) } - override suspend fun hasAwakeable(context: ObjectContext): Boolean { - return context.get(ID_KEY) != null + override suspend fun hasAwakeable(): Boolean { + return state().get(ID_KEY) != null } - override suspend fun unlock(context: ObjectContext, payload: String) { - val awakeableId: String = - context.get(ID_KEY) ?: throw TerminalException("No awakeable registered") - context.awakeableHandle(awakeableId).resolve(payload) + override suspend fun unlock(payload: String) { + val awakeableId = state().get(ID_KEY) ?: throw TerminalException("No awakeable registered") + awakeableHandle(awakeableId).resolve(payload) } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt index 386c4ccef..5d181ec83 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/BlockAndWaitWorkflowImpl.kt @@ -20,24 +20,24 @@ class BlockAndWaitWorkflowImpl : BlockAndWaitWorkflow { private val MY_STATE: StateKey = stateKey("my-state") } - override suspend fun run(context: WorkflowContext, input: String): String { - context.set(MY_STATE, input) + override suspend fun run(input: String): String { + state().set(MY_STATE, input) // Wait on unblock - val output: String = context.promise(MY_DURABLE_PROMISE).future().await() + val output: String = promise(MY_DURABLE_PROMISE).future().await() - if (!context.promise(MY_DURABLE_PROMISE).peek().isReady) { + if (!promise(MY_DURABLE_PROMISE).peek().isReady) { throw TerminalException("Durable promise should be completed") } return output } - override suspend fun unblock(context: SharedWorkflowContext, output: String) { - context.promiseHandle(MY_DURABLE_PROMISE).resolve(output) + override suspend fun unblock(output: String) { + promiseHandle(MY_DURABLE_PROMISE).resolve(output) } - override suspend fun getState(context: SharedWorkflowContext): String? { - return context.get(MY_STATE) + override suspend fun getState(): String? { + return state().get(MY_STATE) } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt index dbeec0fb2..dffae4ed5 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt @@ -11,10 +11,8 @@ package dev.restate.sdk.testservices import dev.restate.sdk.common.StateKey import dev.restate.sdk.common.TerminalException import dev.restate.sdk.kotlin.* -import dev.restate.sdk.testservices.contracts.AwakeableHolderClient -import dev.restate.sdk.testservices.contracts.BlockingOperation +import dev.restate.sdk.testservices.contracts.AwakeableHolder import dev.restate.sdk.testservices.contracts.CancelTest -import dev.restate.sdk.testservices.contracts.CancelTestBlockingServiceClient import kotlin.time.Duration.Companion.days class CancelTestImpl { @@ -23,45 +21,43 @@ class CancelTestImpl { private val CANCELED_STATE: StateKey = stateKey("canceled") } - override suspend fun startTest(context: ObjectContext, operation: BlockingOperation) { - val client = CancelTestBlockingServiceClient.fromContext(context, context.key()) - + override suspend fun startTest(operation: CancelTest.BlockingOperation) { try { - client.block(operation).await() + virtualObject(key()).block(operation) } catch (e: TerminalException) { if (e.code == TerminalException.CANCELLED_CODE) { - context.set(CANCELED_STATE, true) + state().set(CANCELED_STATE, true) } else { throw e } } } - override suspend fun verifyTest(context: ObjectContext): Boolean { - return context.get(CANCELED_STATE) ?: false + override suspend fun verifyTest(): Boolean { + return state().get(CANCELED_STATE) ?: false } } class BlockingService : CancelTest.BlockingService { - override suspend fun block(context: ObjectContext, operation: BlockingOperation) { - val self = CancelTestBlockingServiceClient.fromContext(context, context.key()) - val client = AwakeableHolderClient.fromContext(context, context.key()) + override suspend fun block(operation: CancelTest.BlockingOperation) { + val self = virtualObject(key()) + val awakeableHolder = virtualObject(key()) - val awakeable = context.awakeable() - client.hold(awakeable.id).await() + val awakeable = awakeable() + awakeableHolder.hold(awakeable.id) awakeable.await() when (operation) { - BlockingOperation.CALL -> self.block(operation).await() - BlockingOperation.SLEEP -> context.sleep(1024.days) - BlockingOperation.AWAKEABLE -> { - val uncompletable: Awakeable = context.awakeable() + CancelTest.BlockingOperation.CALL -> self.block(operation) + CancelTest.BlockingOperation.SLEEP -> sleep(1024.days) + CancelTest.BlockingOperation.AWAKEABLE -> { + val uncompletable: Awakeable = awakeable() uncompletable.await() } } } - override suspend fun isUnlocked(context: ObjectContext) { + override suspend fun isUnlocked() { // no-op } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt index ab97ba4f3..a1cb419b0 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt @@ -12,7 +12,6 @@ import dev.restate.sdk.common.StateKey import dev.restate.sdk.common.TerminalException import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.Counter -import dev.restate.sdk.testservices.contracts.CounterUpdateResponse import org.apache.logging.log4j.LogManager import org.apache.logging.log4j.Logger @@ -24,37 +23,37 @@ class CounterImpl : Counter { private val COUNTER_KEY: StateKey = stateKey("counter") } - override suspend fun reset(context: ObjectContext) { + override suspend fun reset() { logger.info("Counter cleaned up") - context.clear(COUNTER_KEY) + state().clear(COUNTER_KEY) } - override suspend fun addThenFail(context: ObjectContext, value: Long) { - var counter: Long = context.get(COUNTER_KEY) ?: 0L + override suspend fun addThenFail(value: Long) { + var counter: Long = state().get(COUNTER_KEY) ?: 0L logger.info("Old counter value: {}", counter) counter += value - context.set(COUNTER_KEY, counter) + state().set(COUNTER_KEY, counter) logger.info("New counter value: {}", counter) - throw TerminalException(context.key()) + throw TerminalException(key()) } - override suspend fun get(context: SharedObjectContext): Long { - val counter: Long = context.get(COUNTER_KEY) ?: 0L + override suspend fun get(): Long { + val counter: Long = state().get(COUNTER_KEY) ?: 0L logger.info("Get counter value: {}", counter) return counter } - override suspend fun add(context: ObjectContext, value: Long): CounterUpdateResponse { - val oldCount: Long = context.get(COUNTER_KEY) ?: 0L + override suspend fun add(value: Long): Counter.CounterUpdateResponse { + val oldCount: Long = state().get(COUNTER_KEY) ?: 0L val newCount = oldCount + value - context.set(COUNTER_KEY, newCount) + state().set(COUNTER_KEY, newCount) logger.info("Old counter value: {}", oldCount) logger.info("New counter value: {}", newCount) - return CounterUpdateResponse(oldCount, newCount) + return Counter.CounterUpdateResponse(oldCount, newCount) } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt index e3edf4d73..0aa1224d9 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/FailingImpl.kt @@ -9,11 +9,8 @@ package dev.restate.sdk.testservices import dev.restate.sdk.common.TerminalException -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.kotlin.retryPolicy -import dev.restate.sdk.kotlin.runBlock +import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.Failing -import dev.restate.sdk.testservices.contracts.FailingClient import java.util.concurrent.atomic.AtomicInteger import kotlin.time.Duration.Companion.milliseconds import org.apache.logging.log4j.LogManager @@ -28,26 +25,23 @@ class FailingImpl : Failing { private val eventualSuccessSideEffectCalls = AtomicInteger(0) private val eventualFailureSideEffectCalls = AtomicInteger(0) - override suspend fun terminallyFailingCall(context: ObjectContext, errorMessage: String) { + override suspend fun terminallyFailingCall(errorMessage: String) { LOG.info("Invoked fail") throw TerminalException(errorMessage) } override suspend fun callTerminallyFailingCall( - context: ObjectContext, errorMessage: String, ): String { LOG.info("Invoked failAndHandle") - FailingClient.fromContext(context, context.random().nextUUID().toString()) - .terminallyFailingCall(errorMessage) - .await() + virtualObject(random().nextUUID().toString()).terminallyFailingCall(errorMessage) throw IllegalStateException("This should be unreachable") } - override suspend fun failingCallWithEventualSuccess(context: ObjectContext): Int { + override suspend fun failingCallWithEventualSuccess(): Int { val currentAttempt = eventualSuccessCalls.incrementAndGet() if (currentAttempt >= 4) { @@ -58,17 +52,16 @@ class FailingImpl : Failing { } } - override suspend fun terminallyFailingSideEffect(context: ObjectContext, errorMessage: String) { - context.runBlock { throw TerminalException(errorMessage) } + override suspend fun terminallyFailingSideEffect(errorMessage: String) { + runBlock { throw TerminalException(errorMessage) } throw IllegalStateException("Should not be reached.") } override suspend fun sideEffectSucceedsAfterGivenAttempts( - context: ObjectContext, minimumAttempts: Int, ): Int = - context.runBlock( + runBlock( name = "failing_side_effect", retryPolicy = retryPolicy { @@ -86,11 +79,10 @@ class FailingImpl : Failing { } override suspend fun sideEffectFailsAfterGivenAttempts( - context: ObjectContext, retryPolicyMaxRetryCount: Int, ): Int { try { - context.runBlock( + runBlock( name = "failing_side_effect", retryPolicy = retryPolicy { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt index 42ed6d362..93341255c 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt @@ -8,10 +8,9 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.testservices.contracts.AwakeableHolderClient +import dev.restate.sdk.kotlin.* +import dev.restate.sdk.testservices.contracts.AwakeableHolder import dev.restate.sdk.testservices.contracts.KillTest -import dev.restate.sdk.testservices.contracts.KillTestSingletonClient import dev.restate.serde.Serde class KillTestImpl { @@ -21,22 +20,21 @@ class KillTestImpl { // This will ensure that we have a call tree that is two calls deep and has a pending invocation // in the inbox: // startCallTree --> recursiveCall --> recursiveCall:inboxed - override suspend fun startCallTree(context: ObjectContext) { - KillTestSingletonClient.fromContext(context, context.key()).recursiveCall().await() + override suspend fun startCallTree() { + virtualObject(key()).recursiveCall() } } class SingletonImpl : KillTest.Singleton { - override suspend fun recursiveCall(context: ObjectContext) { - val awakeable = context.awakeable(Serde.RAW) - AwakeableHolderClient.fromContext(context, context.key()).send().hold(awakeable.id) - + override suspend fun recursiveCall() { + val awakeable = awakeable(Serde.RAW) + toVirtualObject(key()).request { it.hold(awakeable.id) }.send() awakeable.await() - KillTestSingletonClient.fromContext(context, context.key()).recursiveCall().await() + virtualObject(key()).recursiveCall() } - override suspend fun isUnlocked(context: ObjectContext) { + override suspend fun isUnlocked() { // no-op } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt index 884d1ab93..54440a2d5 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ListObjectImpl.kt @@ -20,18 +20,18 @@ class ListObjectImpl : ListObject { ) } - override suspend fun append(context: ObjectContext, value: String) { - val list = context.get(LIST_KEY) ?: emptyList() - context.set(LIST_KEY, list + value) + override suspend fun append(value: String) { + val list = state().get(LIST_KEY) ?: emptyList() + state().set(LIST_KEY, list + value) } - override suspend fun get(context: ObjectContext): List { - return context.get(LIST_KEY) ?: emptyList() + override suspend fun get(): List { + return state().get(LIST_KEY) ?: emptyList() } - override suspend fun clear(context: ObjectContext): List { - val result = context.get(LIST_KEY) ?: emptyList() - context.clear(LIST_KEY) + override suspend fun clear(): List { + val result = state().get(LIST_KEY) ?: emptyList() + state().clear(LIST_KEY) return result } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt index 6d556b3d1..978cac115 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/Main.kt @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.testservices +import dev.restate.common.reflections.ReflectionUtils.extractServiceName import dev.restate.sdk.auth.signing.RestateRequestIdentityVerifier import dev.restate.sdk.http.vertx.RestateHttpServer import dev.restate.sdk.kotlin.endpoint.endpoint @@ -15,30 +16,30 @@ import dev.restate.sdk.testservices.contracts.* val KNOWN_SERVICES_FACTORIES: Map Any> = mapOf( - AwakeableHolderHandlers.Metadata.SERVICE_NAME to { AwakeableHolderImpl() }, - BlockAndWaitWorkflowHandlers.Metadata.SERVICE_NAME to { BlockAndWaitWorkflowImpl() }, - CancelTestBlockingServiceHandlers.Metadata.SERVICE_NAME to + extractServiceName(AwakeableHolder::class.java) to { AwakeableHolderImpl() }, + extractServiceName(BlockAndWaitWorkflow::class.java) to { BlockAndWaitWorkflowImpl() }, + extractServiceName(CancelTest.BlockingService::class.java) to { CancelTestImpl.BlockingService() }, - CancelTestRunnerHandlers.Metadata.SERVICE_NAME to { CancelTestImpl.RunnerImpl() }, - CounterHandlers.Metadata.SERVICE_NAME to { CounterImpl() }, - FailingHandlers.Metadata.SERVICE_NAME to { FailingImpl() }, - KillTestRunnerHandlers.Metadata.SERVICE_NAME to { KillTestImpl.RunnerImpl() }, - KillTestSingletonHandlers.Metadata.SERVICE_NAME to { KillTestImpl.SingletonImpl() }, - ListObjectHandlers.Metadata.SERVICE_NAME to { ListObjectImpl() }, - MapObjectHandlers.Metadata.SERVICE_NAME to { MapObjectImpl() }, - NonDeterministicHandlers.Metadata.SERVICE_NAME to { NonDeterministicImpl() }, - ProxyHandlers.Metadata.SERVICE_NAME to { ProxyImpl() }, - TestUtilsServiceHandlers.Metadata.SERVICE_NAME to { TestUtilsServiceImpl() }, - VirtualObjectCommandInterpreterHandlers.Metadata.SERVICE_NAME to + extractServiceName(CancelTest.Runner::class.java) to { CancelTestImpl.RunnerImpl() }, + extractServiceName(Counter::class.java) to { CounterImpl() }, + extractServiceName(Failing::class.java) to { FailingImpl() }, + extractServiceName(KillTest.Runner::class.java) to { KillTestImpl.RunnerImpl() }, + extractServiceName(KillTest.Singleton::class.java) to { KillTestImpl.SingletonImpl() }, + extractServiceName(ListObject::class.java) to { ListObjectImpl() }, + extractServiceName(MapObject::class.java) to { MapObjectImpl() }, + extractServiceName(NonDeterministic::class.java) to { NonDeterministicImpl() }, + extractServiceName(Proxy::class.java) to { ProxyImpl() }, + extractServiceName(TestUtilsService::class.java) to { TestUtilsServiceImpl() }, + extractServiceName(VirtualObjectCommandInterpreter::class.java) to { VirtualObjectCommandInterpreterImpl() }, interpreterName(0) to { ObjectInterpreterImpl.getInterpreterDefinition(0) }, interpreterName(1) to { ObjectInterpreterImpl.getInterpreterDefinition(1) }, interpreterName(2) to { ObjectInterpreterImpl.getInterpreterDefinition(2) }, - ServiceInterpreterHelperHandlers.Metadata.SERVICE_NAME to + extractServiceName(ServiceInterpreterHelper::class.java) to { ServiceInterpreterHelperImpl() }, diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt index 71745347e..69263d6f2 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/MapObjectImpl.kt @@ -9,26 +9,25 @@ package dev.restate.sdk.testservices import dev.restate.sdk.kotlin.* -import dev.restate.sdk.testservices.contracts.Entry import dev.restate.sdk.testservices.contracts.MapObject class MapObjectImpl : MapObject { - override suspend fun set(context: ObjectContext, entry: Entry) { - context.set(stateKey(entry.key), entry.value) + override suspend fun set(entry: MapObject.Entry) { + state().set(stateKey(entry.key), entry.value) } - override suspend fun get(context: ObjectContext, key: String): String { - return context.get(stateKey(key)) ?: "" + override suspend fun get(key: String): String { + return state().get(stateKey(key)) ?: "" } - override suspend fun clearAll(context: ObjectContext): List { - val keys = context.stateKeys() + override suspend fun clearAll(): List { + val keys = state().keys() // AH AH AH and here I wanna see if you really respect determinism!!! - val result = mutableListOf() + val result = mutableListOf() for (k in keys) { - result.add(Entry(k, context.get(stateKey(k))!!)) + result.add(MapObject.Entry(k, state().get(stateKey(k))!!)) } - context.clearAll() + state().clearAll() return result } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt index 7509c541a..e1c56ab72 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt @@ -10,7 +10,7 @@ package dev.restate.sdk.testservices import dev.restate.sdk.common.StateKey import dev.restate.sdk.kotlin.* -import dev.restate.sdk.testservices.contracts.CounterClient +import dev.restate.sdk.testservices.contracts.Counter import dev.restate.sdk.testservices.contracts.NonDeterministic import java.util.concurrent.ConcurrentHashMap import kotlin.time.Duration.Companion.milliseconds @@ -20,57 +20,57 @@ class NonDeterministicImpl : NonDeterministic { private val STATE_A: StateKey = stateKey("a") private val STATE_B: StateKey = stateKey("b") - override suspend fun eitherSleepOrCall(context: ObjectContext) { - if (doLeftAction(context)) { - context.sleep(100.milliseconds) + override suspend fun eitherSleepOrCall() { + if (doLeftAction()) { + sleep(100.milliseconds) } else { - CounterClient.fromContext(context, "abc").get().await() + virtualObject("abc").get() } // This is required to cause a suspension after the non-deterministic operation - context.sleep(100.milliseconds) - incrementCounter(context) + sleep(100.milliseconds) + incrementCounter() } - override suspend fun callDifferentMethod(context: ObjectContext) { - if (doLeftAction(context)) { - CounterClient.fromContext(context, "abc").get().await() + override suspend fun callDifferentMethod() { + if (doLeftAction()) { + virtualObject("abc").get() } else { - CounterClient.fromContext(context, "abc").reset().await() + virtualObject("abc").reset() } // This is required to cause a suspension after the non-deterministic operation - context.sleep(100.milliseconds) - incrementCounter(context) + sleep(100.milliseconds) + incrementCounter() } - override suspend fun backgroundInvokeWithDifferentTargets(context: ObjectContext) { - if (doLeftAction(context)) { - CounterClient.fromContext(context, "abc").send().get() + override suspend fun backgroundInvokeWithDifferentTargets() { + if (doLeftAction()) { + toVirtualObject("abc").request { it.get() }.send() } else { - CounterClient.fromContext(context, "abc").send().reset() + toVirtualObject("abc").request { it.reset() }.send() } // This is required to cause a suspension after the non-deterministic operation - context.sleep(100.milliseconds) - incrementCounter(context) + sleep(100.milliseconds) + incrementCounter() } - override suspend fun setDifferentKey(context: ObjectContext) { - if (doLeftAction(context)) { - context.set(STATE_A, "my-state") + override suspend fun setDifferentKey() { + if (doLeftAction()) { + state().set(STATE_A, "my-state") } else { - context.set(STATE_B, "my-state") + state().set(STATE_B, "my-state") } // This is required to cause a suspension after the non-deterministic operation - context.sleep(100.milliseconds) - incrementCounter(context) + sleep(100.milliseconds) + incrementCounter() } - private suspend fun incrementCounter(context: ObjectContext) { - CounterClient.fromContext(context, context.key()).send().add(1) + private suspend fun incrementCounter() { + toVirtualObject("abc").request { it.add(1) }.send() } - private fun doLeftAction(context: ObjectContext): Boolean { + private suspend fun doLeftAction(): Boolean { // Test runner sets an appropriate key here - val countKey = context.key() + val countKey = key() return invocationCounts.merge(countKey, 1) { a: Int, b: Int -> a + b }!! % 2 == 1 } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt index 21fba1b2f..ba2795c6f 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/ProxyImpl.kt @@ -11,15 +11,13 @@ package dev.restate.sdk.testservices import dev.restate.common.Request import dev.restate.common.Target import dev.restate.sdk.kotlin.* -import dev.restate.sdk.testservices.contracts.ManyCallRequest import dev.restate.sdk.testservices.contracts.Proxy -import dev.restate.sdk.testservices.contracts.ProxyRequest import dev.restate.serde.Serde import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds class ProxyImpl : Proxy { - private fun ProxyRequest.toTarget(): Target { + private fun Proxy.ProxyRequest.toTarget(): Target { return if (this.virtualObjectKey == null) { Target.service(this.serviceName, this.handlerName) } else { @@ -27,58 +25,66 @@ class ProxyImpl : Proxy { } } - override suspend fun call(context: Context, request: ProxyRequest): ByteArray { - return Request.of(request.toTarget(), Serde.RAW, Serde.RAW, request.message) - .also { - if (request.idempotencyKey != null) { - it.idempotencyKey = request.idempotencyKey - } - } - .call(context) + override suspend fun call(request: Proxy.ProxyRequest): ByteArray { + return context() + .call( + Request.of(request.toTarget(), Serde.RAW, Serde.RAW, request.message).also { + if (request.idempotencyKey != null) { + it.idempotencyKey = request.idempotencyKey + } + } + ) .await() } - override suspend fun oneWayCall(context: Context, request: ProxyRequest): String = - Request.of(request.toTarget(), Serde.RAW, Serde.SLICE, request.message) - .also { - if (request.idempotencyKey != null) { - it.idempotencyKey = request.idempotencyKey - } - } - .send(context, request.delayMillis?.milliseconds ?: Duration.ZERO) + override suspend fun oneWayCall(request: Proxy.ProxyRequest): String = + context() + .send( + Request.of(request.toTarget(), Serde.RAW, Serde.SLICE, request.message).also { + if (request.idempotencyKey != null) { + it.idempotencyKey = request.idempotencyKey + } + }, + request.delayMillis?.milliseconds ?: Duration.ZERO, + ) .invocationId() - override suspend fun manyCalls(context: Context, requests: List) { + override suspend fun manyCalls(requests: List) { val toAwait = mutableListOf>() for (request in requests) { if (request.oneWayCall) { - Request.of( - request.proxyRequest.toTarget(), - Serde.RAW, - Serde.SLICE, - request.proxyRequest.message, + context() + .send( + Request.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.SLICE, + request.proxyRequest.message, + ) + .also { + if (request.proxyRequest.idempotencyKey != null) { + it.idempotencyKey = request.proxyRequest.idempotencyKey + } + }, + request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO, ) - .also { - if (request.proxyRequest.idempotencyKey != null) { - it.idempotencyKey = request.proxyRequest.idempotencyKey - } - } - .send(context, request.proxyRequest.delayMillis?.milliseconds ?: Duration.ZERO) } else { val fut = - Request.of( - request.proxyRequest.toTarget(), - Serde.RAW, - Serde.RAW, - request.proxyRequest.message, + context() + .call( + Request.of( + request.proxyRequest.toTarget(), + Serde.RAW, + Serde.RAW, + request.proxyRequest.message, + ) + .also { + if (request.proxyRequest.idempotencyKey != null) { + it.idempotencyKey = request.proxyRequest.idempotencyKey + } + } ) - .also { - if (request.proxyRequest.idempotencyKey != null) { - it.idempotencyKey = request.proxyRequest.idempotencyKey - } - } - .call(context) if (request.awaitAtTheEnd) { toAwait.add(fut) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt index d040e1dec..398cdcacb 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/TestUtilsServiceImpl.kt @@ -15,40 +15,40 @@ import java.util.concurrent.atomic.AtomicInteger import kotlin.time.Duration.Companion.milliseconds class TestUtilsServiceImpl : TestUtilsService { - override suspend fun echo(context: Context, input: String): String { + override suspend fun echo(input: String): String { return input } - override suspend fun uppercaseEcho(context: Context, input: String): String { + override suspend fun uppercaseEcho(input: String): String { return input.uppercase(Locale.getDefault()) } - override suspend fun echoHeaders(context: Context): Map { - return context.request().headers + override suspend fun echoHeaders(): Map { + return request().headers } - override suspend fun rawEcho(context: Context, input: ByteArray): ByteArray { - check(input.contentEquals(context.request().bodyAsByteArray)) + override suspend fun rawEcho(input: ByteArray): ByteArray { + check(input.contentEquals(request().bodyAsByteArray)) return input } - override suspend fun sleepConcurrently(context: Context, millisDuration: List) { - val timers = millisDuration.map { context.timer(it.milliseconds) }.toList() + override suspend fun sleepConcurrently(millisDuration: List) { + val timers = millisDuration.map { timer("${it.milliseconds}ms", it.milliseconds) }.toList() timers.awaitAll() } - override suspend fun countExecutedSideEffects(context: Context, increments: Int): Int { + override suspend fun countExecutedSideEffects(increments: Int): Int { val invokedSideEffects = AtomicInteger(0) for (i in 0..(invocationId).cancel() + override suspend fun cancelInvocation(invocationId: String) { + invocationHandle(invocationId).cancel() } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt index 81414fdba..1ebdede17 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/VirtualObjectCommandInterpreterImpl.kt @@ -24,7 +24,6 @@ class VirtualObjectCommandInterpreterImpl : VirtualObjectCommandInterpreter { } override suspend fun interpretCommands( - context: ObjectContext, req: VirtualObjectCommandInterpreter.InterpretRequest, ): String { LOG.info("Interpreting commands {}", req) @@ -35,7 +34,7 @@ class VirtualObjectCommandInterpreterImpl : VirtualObjectCommandInterpreter { LOG.info("Start interpreting command {}", it) when (it) { is VirtualObjectCommandInterpreter.AwaitAny -> { - val cmds = it.commands.map { it.toAwaitable(context) } + val cmds = it.commands.map { it.toAwaitable() } result = select { for (cmd in cmds) { @@ -45,7 +44,7 @@ class VirtualObjectCommandInterpreterImpl : VirtualObjectCommandInterpreter { .await() } is VirtualObjectCommandInterpreter.AwaitAnySuccessful -> { - val cmds = it.commands.map { it.toAwaitable(context) }.toMutableList() + val cmds = it.commands.map { it.toAwaitable() }.toMutableList() while (true) { @Suppress("UNCHECKED_CAST") @@ -61,22 +60,22 @@ class VirtualObjectCommandInterpreterImpl : VirtualObjectCommandInterpreter { } } is VirtualObjectCommandInterpreter.AwaitOne -> { - result = it.command.toAwaitable(context).await() + result = it.command.toAwaitable().await() } is VirtualObjectCommandInterpreter.GetEnvVariable -> { - result = context.runBlock { System.getenv(it.envName) ?: "" } + result = runBlock { System.getenv(it.envName) ?: "" } } is VirtualObjectCommandInterpreter.ResolveAwakeable -> { - resolveAwakeable(context, it) + resolveAwakeable(it) result = "" } is VirtualObjectCommandInterpreter.RejectAwakeable -> { - rejectAwakeable(context, it) + rejectAwakeable(it) result = "" } is VirtualObjectCommandInterpreter.AwaitAwakeableOrTimeout -> { - val awk = context.awakeable() - context.set("awk-${it.awakeableKey}", awk.id) + val awk = awakeable() + state().set("awk-${it.awakeableKey}", awk.id) try { result = awk.await(it.timeoutMillis.milliseconds) } catch (_: TimeoutException) { @@ -85,60 +84,54 @@ class VirtualObjectCommandInterpreterImpl : VirtualObjectCommandInterpreter { } } LOG.info("Command result {}", result) - appendResult(context, result) + appendResult(result) } return result } override suspend fun resolveAwakeable( - context: SharedObjectContext, resolveAwakeable: VirtualObjectCommandInterpreter.ResolveAwakeable, ) { - context - .awakeableHandle( - context.get("awk-${resolveAwakeable.awakeableKey}") + awakeableHandle( + state().get("awk-${resolveAwakeable.awakeableKey}") ?: throw TerminalException("awakeable is not registerd yet") ) .resolve(resolveAwakeable.value) } override suspend fun rejectAwakeable( - context: SharedObjectContext, rejectAwakeable: VirtualObjectCommandInterpreter.RejectAwakeable, ) { - context - .awakeableHandle( - context.get("awk-${rejectAwakeable.awakeableKey}") + awakeableHandle( + state().get("awk-${rejectAwakeable.awakeableKey}") ?: throw TerminalException("awakeable is not registerd yet") ) .reject(rejectAwakeable.reason) } - override suspend fun hasAwakeable(context: SharedObjectContext, awakeableKey: String): Boolean = - !context.get("awk-$awakeableKey").isNullOrBlank() + override suspend fun hasAwakeable(awakeableKey: String): Boolean = + !state().get("awk-$awakeableKey").isNullOrBlank() - override suspend fun getResults(context: SharedObjectContext): List = - context.get("results") ?: listOf() + override suspend fun getResults(): List = state().get("results") ?: listOf() - private suspend fun VirtualObjectCommandInterpreter.AwaitableCommand.toAwaitable( - ctx: ObjectContext - ): DurableFuture { + private suspend fun VirtualObjectCommandInterpreter.AwaitableCommand.toAwaitable(): + DurableFuture { return when (this) { is VirtualObjectCommandInterpreter.CreateAwakeable -> { - val awk = ctx.awakeable() - ctx.set("awk-${this.awakeableKey}", awk.id) + val awk = awakeable() + state().set("awk-${this.awakeableKey}", awk.id) awk } is VirtualObjectCommandInterpreter.RunThrowTerminalException -> - ctx.runAsync("should-fail-with-${this.reason}") { + runAsync("should-fail-with-${this.reason}") { throw TerminalException(this.reason) } is VirtualObjectCommandInterpreter.Sleep -> - ctx.timer(this.timeoutMillis.milliseconds).map { "sleep" } + timer("command-timer", this.timeoutMillis.milliseconds).map { "sleep" } } } - private suspend fun appendResult(ctx: ObjectContext, newResult: String) = - ctx.set("results", (ctx.get("results") ?: listOf()) + listOf(newResult)) + private suspend fun appendResult(newResult: String) = + state().set("results", (state().get("results") ?: listOf()) + listOf(newResult)) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt index 524bd29b4..ca165b0f0 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/AwakeableHolder.kt @@ -16,9 +16,9 @@ import dev.restate.sdk.kotlin.* @VirtualObject @Name("AwakeableHolder") interface AwakeableHolder { - @Exclusive suspend fun hold(context: ObjectContext, id: String) + @Exclusive suspend fun hold(id: String) - @Exclusive suspend fun hasAwakeable(context: ObjectContext): Boolean + @Exclusive suspend fun hasAwakeable(): Boolean - @Exclusive suspend fun unlock(context: ObjectContext, payload: String) + @Exclusive suspend fun unlock(payload: String) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt index ea6126bff..d5cde2da6 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/BlockAndWaitWorkflow.kt @@ -14,9 +14,9 @@ import dev.restate.sdk.kotlin.* @Workflow @Name("BlockAndWaitWorkflow") interface BlockAndWaitWorkflow { - @Workflow suspend fun run(context: WorkflowContext, input: String): String + @Workflow suspend fun run(input: String): String - @Shared suspend fun unblock(context: SharedWorkflowContext, output: String) + @Shared suspend fun unblock(output: String) - @Shared suspend fun getState(context: SharedWorkflowContext): String? + @Shared suspend fun getState(): String? } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/CancelTest.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/CancelTest.kt index fbc524baa..d7613bd21 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/CancelTest.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/CancelTest.kt @@ -12,28 +12,28 @@ import dev.restate.sdk.annotation.* import dev.restate.sdk.kotlin.* import kotlinx.serialization.Serializable -@Serializable -enum class BlockingOperation { - CALL, - SLEEP, - AWAKEABLE, -} - interface CancelTest { + @Serializable + enum class BlockingOperation { + CALL, + SLEEP, + AWAKEABLE, + } + @VirtualObject @Name("CancelTestRunner") interface Runner { - @Exclusive suspend fun startTest(context: ObjectContext, operation: BlockingOperation) + @Handler suspend fun startTest(operation: BlockingOperation) - @Exclusive suspend fun verifyTest(context: ObjectContext): Boolean + @Handler suspend fun verifyTest(): Boolean } @VirtualObject @Name("CancelTestBlockingService") interface BlockingService { - @Exclusive suspend fun block(context: ObjectContext, operation: BlockingOperation) + @Handler suspend fun block(operation: BlockingOperation) - @Exclusive suspend fun isUnlocked(context: ObjectContext) + @Handler suspend fun isUnlocked() } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Counter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Counter.kt index d7a7e2d67..34024297e 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Counter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Counter.kt @@ -12,20 +12,20 @@ import dev.restate.sdk.annotation.* import dev.restate.sdk.kotlin.* import kotlinx.serialization.Serializable -@Serializable data class CounterUpdateResponse(val oldValue: Long, val newValue: Long) - @VirtualObject @Name("Counter") interface Counter { + @Serializable data class CounterUpdateResponse(val oldValue: Long, val newValue: Long) + /** Add value to counter */ - @Handler suspend fun add(context: ObjectContext, value: Long): CounterUpdateResponse + @Handler suspend fun add(value: Long): CounterUpdateResponse /** Add value to counter, then fail with a Terminal error */ - @Handler suspend fun addThenFail(context: ObjectContext, value: Long) + @Handler suspend fun addThenFail(value: Long) /** Get count */ - @Shared suspend fun get(context: SharedObjectContext): Long + @Shared suspend fun get(): Long /** Reset count */ - @Handler suspend fun reset(context: ObjectContext) + @Handler suspend fun reset() } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Failing.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Failing.kt index b2e68240a..4d5d1a921 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Failing.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Failing.kt @@ -14,14 +14,13 @@ import dev.restate.sdk.kotlin.* @VirtualObject @Name("Failing") interface Failing { - @Handler suspend fun terminallyFailingCall(context: ObjectContext, errorMessage: String) + @Handler suspend fun terminallyFailingCall(errorMessage: String) - @Handler - suspend fun callTerminallyFailingCall(context: ObjectContext, errorMessage: String): String + @Handler suspend fun callTerminallyFailingCall(errorMessage: String): String - @Handler suspend fun failingCallWithEventualSuccess(context: ObjectContext): Int + @Handler suspend fun failingCallWithEventualSuccess(): Int - @Handler suspend fun terminallyFailingSideEffect(context: ObjectContext, errorMessage: String) + @Handler suspend fun terminallyFailingSideEffect(errorMessage: String) /** * `minimumAttempts` should be used to check when to succeed. The retry policy should be @@ -32,7 +31,6 @@ interface Failing { */ @Handler suspend fun sideEffectSucceedsAfterGivenAttempts( - context: ObjectContext, minimumAttempts: Int, ): Int @@ -44,7 +42,6 @@ interface Failing { */ @Handler suspend fun sideEffectFailsAfterGivenAttempts( - context: ObjectContext, retryPolicyMaxRetryCount: Int, ): Int } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/KillTest.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/KillTest.kt index c88275a35..13916eb05 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/KillTest.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/KillTest.kt @@ -15,14 +15,14 @@ interface KillTest { @VirtualObject @Name("KillTestRunner") interface Runner { - @Handler suspend fun startCallTree(context: ObjectContext) + @Handler suspend fun startCallTree() } @VirtualObject @Name("KillTestSingleton") interface Singleton { - @Handler suspend fun recursiveCall(context: ObjectContext) + @Handler suspend fun recursiveCall() - @Handler suspend fun isUnlocked(context: ObjectContext) + @Handler suspend fun isUnlocked() } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/ListObject.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/ListObject.kt index 933bd94a2..5398b4692 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/ListObject.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/ListObject.kt @@ -15,11 +15,11 @@ import dev.restate.sdk.kotlin.* @Name("ListObject") interface ListObject { /** Append a value to the list object */ - @Handler suspend fun append(context: ObjectContext, value: String) + @Handler suspend fun append(value: String) /** Get current list */ - @Handler suspend fun get(context: ObjectContext): List + @Handler suspend fun get(): List /** Clear list */ - @Handler suspend fun clear(context: ObjectContext): List + @Handler suspend fun clear(): List } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/MapObject.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/MapObject.kt index a17c0b0d5..ca4024c4f 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/MapObject.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/MapObject.kt @@ -12,22 +12,23 @@ import dev.restate.sdk.annotation.* import dev.restate.sdk.kotlin.* import kotlinx.serialization.Serializable -@Serializable data class Entry(val key: String, val value: String) - @VirtualObject @Name("MapObject") interface MapObject { + + @Serializable data class Entry(val key: String, val value: String) + /** * Set value in map. * * The individual entries should be stored as separate Restate state keys, and not in a single * state key */ - @Handler suspend fun set(context: ObjectContext, entry: Entry) + @Handler suspend fun set(entry: Entry) /** Get value from map. */ - @Handler suspend fun get(context: ObjectContext, key: String): String + @Handler suspend fun get(key: String): String /** Clear all entries */ - @Handler suspend fun clearAll(context: ObjectContext): List + @Handler suspend fun clearAll(): List } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/NonDeterministic.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/NonDeterministic.kt index 67f594890..ea83f4ecc 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/NonDeterministic.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/NonDeterministic.kt @@ -15,11 +15,11 @@ import dev.restate.sdk.kotlin.* @Name("NonDeterministic") interface NonDeterministic { /** On first invocation sleeps, on second invocation calls */ - @Handler suspend fun eitherSleepOrCall(context: ObjectContext) + @Handler suspend fun eitherSleepOrCall() - @Handler suspend fun callDifferentMethod(context: ObjectContext) + @Handler suspend fun callDifferentMethod() - @Handler suspend fun backgroundInvokeWithDifferentTargets(context: ObjectContext) + @Handler suspend fun backgroundInvokeWithDifferentTargets() - @Handler suspend fun setDifferentKey(context: ObjectContext) + @Handler suspend fun setDifferentKey() } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Proxy.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Proxy.kt index 82b51a548..c7392cb55 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Proxy.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/Proxy.kt @@ -12,37 +12,37 @@ import dev.restate.sdk.annotation.* import dev.restate.sdk.kotlin.* import kotlinx.serialization.Serializable -@Serializable -data class ProxyRequest( - val serviceName: String, - val virtualObjectKey: String? = null, // If null, the request is to a service - val handlerName: String, - // Bytes are encoded as array of numbers - val message: ByteArray, - val delayMillis: Int? = null, - val idempotencyKey: String? = null, -) - -@Serializable -data class ManyCallRequest( - val proxyRequest: ProxyRequest, - /** If true, perform a one way call instead of a regular call */ - val oneWayCall: Boolean, - /** - * If await at the end, then perform the call as regular call, and collect all the futures to - * wait at the end, before returning, instead of awaiting them immediately. - */ - val awaitAtTheEnd: Boolean, -) - @Service @Name("Proxy") interface Proxy { + @Serializable + data class ProxyRequest( + val serviceName: String, + val virtualObjectKey: String? = null, // If null, the request is to a service + val handlerName: String, + // Bytes are encoded as array of numbers + val message: ByteArray, + val delayMillis: Int? = null, + val idempotencyKey: String? = null, + ) + + @Serializable + data class ManyCallRequest( + val proxyRequest: ProxyRequest, + /** If true, perform a one way call instead of a regular call */ + val oneWayCall: Boolean, + /** + * If await at the end, then perform the call as regular call, and collect all the futures to + * wait at the end, before returning, instead of awaiting them immediately. + */ + val awaitAtTheEnd: Boolean, + ) + // Bytes are encoded as array of numbers - @Handler suspend fun call(context: Context, request: ProxyRequest): ByteArray + @Handler suspend fun call(request: ProxyRequest): ByteArray // Returns the invocation id of the call - @Handler suspend fun oneWayCall(context: Context, request: ProxyRequest): String + @Handler suspend fun oneWayCall(request: ProxyRequest): String - @Handler suspend fun manyCalls(context: Context, requests: List) + @Handler suspend fun manyCalls(requests: List) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/TestUtils.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/TestUtilsService.kt similarity index 62% rename from test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/TestUtils.kt rename to test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/TestUtilsService.kt index b03e01178..ee966f810 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/TestUtils.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/TestUtilsService.kt @@ -9,26 +9,25 @@ package dev.restate.sdk.testservices.contracts import dev.restate.sdk.annotation.* -import dev.restate.sdk.kotlin.* /** Collection of various utilities/corner cases scenarios used by tests */ @Service @Name("TestUtilsService") interface TestUtilsService { /** Just echo */ - @Handler suspend fun echo(context: Context, input: String): String + @Handler suspend fun echo(input: String): String /** Just echo but with uppercase */ - @Handler suspend fun uppercaseEcho(context: Context, input: String): String + @Handler suspend fun uppercaseEcho(input: String): String /** Echo ingress headers */ - @Handler suspend fun echoHeaders(context: Context): Map + @Handler suspend fun echoHeaders(): Map /** Just echo */ - @Handler @Raw suspend fun rawEcho(context: Context, @Raw input: ByteArray): ByteArray + @Handler @Raw suspend fun rawEcho(@Raw input: ByteArray): ByteArray /** Create timers and await them all. Durations in milliseconds */ - @Handler suspend fun sleepConcurrently(context: Context, millisDuration: List) + @Handler suspend fun sleepConcurrently(millisDuration: List) /** * Invoke `ctx.run` incrementing a local variable counter (not a restate state key!). @@ -37,8 +36,8 @@ interface TestUtilsService { * * This is used to verify acks will suspend when using the always suspend test-suite */ - @Handler suspend fun countExecutedSideEffects(context: Context, increments: Int): Int + @Handler suspend fun countExecutedSideEffects(increments: Int): Int /** Cancel invocation using the context. */ - @Handler suspend fun cancelInvocation(context: Context, invocationId: String) + @Handler suspend fun cancelInvocation(invocationId: String) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt index fcd2ae31c..ec962eb08 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/VirtualObjectCommandInterpreter.kt @@ -75,15 +75,13 @@ interface VirtualObjectCommandInterpreter { * For each command, the output should be appended to the given list name. Returns the result of * the last command, or empty string otherwise. */ - @Handler suspend fun interpretCommands(context: ObjectContext, req: InterpretRequest): String + @Handler suspend fun interpretCommands(req: InterpretRequest): String - @Shared - suspend fun resolveAwakeable(context: SharedObjectContext, resolveAwakeable: ResolveAwakeable) + @Shared suspend fun resolveAwakeable(resolveAwakeable: ResolveAwakeable) - @Shared - suspend fun rejectAwakeable(context: SharedObjectContext, rejectAwakeable: RejectAwakeable) + @Shared suspend fun rejectAwakeable(rejectAwakeable: RejectAwakeable) - @Shared suspend fun hasAwakeable(context: SharedObjectContext, awakeableKey: String): Boolean + @Shared suspend fun hasAwakeable(awakeableKey: String): Boolean - @Shared suspend fun getResults(context: SharedObjectContext): List + @Shared suspend fun getResults(): List } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/interpreter.kt index 41090d344..de9e5cb2c 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/contracts/interpreter.kt @@ -231,9 +231,9 @@ object CommandSerializer : @Name("ObjectInterpreter") interface ObjectInterpreter { - @Shared suspend fun counter(ctx: SharedObjectContext): Int + @Shared suspend fun counter(): Int - @Handler suspend fun interpret(ctx: ObjectContext, program: Program) + @Handler suspend fun interpret(program: Program) } @Serializable data class EchoLaterRequest(val sleep: Int, val parameter: String) @@ -249,20 +249,19 @@ data class IncrementViaAwakeableDanceRequest( @Service @Name("ServiceInterpreterHelper") interface ServiceInterpreterHelper { - @Handler suspend fun ping(ctx: Context) + @Handler suspend fun ping() - @Handler suspend fun echo(ctx: Context, param: String): String + @Handler suspend fun echo(param: String): String - @Handler suspend fun echoLater(ctx: Context, req: EchoLaterRequest): String + @Handler suspend fun echoLater(req: EchoLaterRequest): String - @Handler suspend fun terminalFailure(ctx: Context) + @Handler suspend fun terminalFailure() - @Handler suspend fun incrementIndirectly(ctx: Context, id: InterpreterId) + @Handler suspend fun incrementIndirectly(id: InterpreterId) - @Handler suspend fun resolveAwakeable(ctx: Context, id: String) + @Handler suspend fun resolveAwakeable(id: String) - @Handler suspend fun rejectAwakeable(ctx: Context, id: String) + @Handler suspend fun rejectAwakeable(id: String) - @Handler - suspend fun incrementViaAwakeableDance(ctx: Context, req: IncrementViaAwakeableDanceRequest) + @Handler suspend fun incrementViaAwakeableDance(req: IncrementViaAwakeableDanceRequest) } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt index 1541bd263..86da8cd09 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt @@ -10,18 +10,21 @@ package dev.restate.sdk.testservices import dev.restate.common.Request import dev.restate.common.Target +import dev.restate.common.reflections.ReflectionUtils import dev.restate.sdk.common.StateKey import dev.restate.sdk.common.TerminalException import dev.restate.sdk.endpoint.definition.ServiceDefinition +import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactories import dev.restate.sdk.kotlin.* import dev.restate.sdk.testservices.contracts.* import dev.restate.sdk.testservices.contracts.Program import dev.restate.serde.Serde +import dev.restate.serde.kotlinx.typeTag import kotlin.random.Random import kotlin.time.Duration.Companion.milliseconds fun interpreterName(layer: Int): String { - return "${ObjectInterpreterHandlers.Metadata.SERVICE_NAME}L$layer" + return "${ReflectionUtils.extractServiceName(ObjectInterpreter::class.java)}L$layer" } fun interpretTarget(layer: Int, key: String): Target { @@ -66,8 +69,9 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { private val COUNTER: StateKey = stateKey("counter") fun getInterpreterDefinition(layer: Int): ServiceDefinition { + val serviceImpl = ObjectInterpreterImpl(layer) val originalDefinition = - ObjectInterpreterServiceDefinitionFactory().create(ObjectInterpreterImpl(layer), null) + ServiceDefinitionFactories.discover(serviceImpl).create(serviceImpl, null) return ServiceDefinition.of( interpreterName(layer), originalDefinition.serviceType, @@ -76,15 +80,15 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } } - private fun interpreterId(ctx: SharedObjectContext): InterpreterId { - return InterpreterId(layer, ctx.key()) + private suspend fun interpreterId(): InterpreterId { + return InterpreterId(layer, key()) } - override suspend fun counter(ctx: SharedObjectContext): Int { - return ctx.get(COUNTER) ?: 0 + override suspend fun counter(): Int { + return state().get(COUNTER) ?: 0 } - override suspend fun interpret(ctx: ObjectContext, program: Program) { + override suspend fun interpret(program: Program) { val promises: MutableMap Unit> = mutableMapOf() for ((i, cmd) in program.commands.withIndex()) { when (cmd) { @@ -99,58 +103,66 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is CallObject -> { val awaitable = - ctx.call( - Request.of( - interpretTarget(layer + 1, cmd.key.toString()), - ObjectInterpreterHandlers.Metadata.Serde.INTERPRET_INPUT, - ObjectInterpreterHandlers.Metadata.Serde.INTERPRET_OUTPUT, - cmd.program, + context() + .call( + Request.of( + interpretTarget(layer + 1, cmd.key.toString()), + typeTag(), + typeTag(), + cmd.program, + ) ) - ) promises[i] = { awaitable.await() } } is CallService -> { val expected = "hello-$i" - val awaitable = ServiceInterpreterHelperHandlers.echo(expected).call(ctx) + val awaitable = toService().request { it.echo(expected) }.call() promises[i] = { checkAwaitable(awaitable, expected, i, cmd) } } is CallSlowService -> { val expected = "hello-$i" val awaitable = - ServiceInterpreterHelperHandlers.echoLater(EchoLaterRequest(cmd.sleep, expected)) - .call(ctx) + toService() + .request { it.echoLater(EchoLaterRequest(cmd.sleep, expected)) } + .call() promises[i] = { checkAwaitable(awaitable, expected, i, cmd) } } is ClearState -> { - ctx.clear(cmdStateKey(cmd.key)) + state().clear(cmdStateKey(cmd.key)) } is GetState -> { - ctx.get(cmdStateKey(cmd.key)) + state().get(cmdStateKey(cmd.key)) } is IncrementStateCounter -> { - ctx.set(COUNTER, (ctx.get(COUNTER) ?: 0) + 1) + state().set(COUNTER, (state().get(COUNTER) ?: 0) + 1) } is IncrementStateCounterIndirectly -> { - ServiceInterpreterHelperHandlers.incrementIndirectly(interpreterId(ctx)).send(ctx) + toService() + .request { it.incrementIndirectly(interpreterId()) } + .send() } is IncrementStateCounterViaAwakeable -> { // Dancing in the mooonlight! - val awakeable = ctx.awakeable() - ServiceInterpreterHelperHandlers.incrementViaAwakeableDance( - IncrementViaAwakeableDanceRequest(interpreterId(ctx), awakeable.id) - ) - .send(ctx) + val awakeable = awakeable() + toService() + .request { + it.incrementViaAwakeableDance( + IncrementViaAwakeableDanceRequest(interpreterId(), awakeable.id) + ) + } + .send() val theirPromiseIdForUsToResolve = awakeable.await() - ctx.awakeableHandle(theirPromiseIdForUsToResolve).resolve("ok") + awakeableHandle(theirPromiseIdForUsToResolve).resolve("ok") } is IncrementViaDelayedCall -> { - ServiceInterpreterHelperHandlers.incrementIndirectly(interpreterId(ctx)) - .send(ctx, delay = cmd.duration.milliseconds) + toService() + .request { it.incrementIndirectly(interpreterId()) } + .send(delay = cmd.duration.milliseconds) } is RecoverTerminalCall -> { var caught = false try { - ServiceInterpreterHelperHandlers.terminalFailure().call(ctx).await() + service().terminalFailure() } catch (e: TerminalException) { caught = true } @@ -161,37 +173,38 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } } is RecoverTerminalCallMaybeUnAwaited -> { - val awaitable = ServiceInterpreterHelperHandlers.terminalFailure().call(ctx) + val awaitable = + toService().request { it.terminalFailure() }.call() promises[i] = { checkAwaitableFails(awaitable, i, cmd) } } is RejectAwakeable -> { - val awakeable = ctx.awakeable() + val awakeable = awakeable() promises[i] = { checkAwaitableFails(awakeable, i, cmd) } - ServiceInterpreterHelperHandlers.rejectAwakeable(awakeable.id).send(ctx) + toService().request { it.rejectAwakeable(awakeable.id) }.send() } is ResolveAwakeable -> { - val awakeable = ctx.awakeable() + val awakeable = awakeable() promises[i] = { checkAwaitable(awakeable, "ok", i, cmd) } - ServiceInterpreterHelperHandlers.resolveAwakeable(awakeable.id).send(ctx) + toService().request { it.resolveAwakeable(awakeable.id) }.send() } is SetState -> { - ctx.set(cmdStateKey(cmd.key), "value-${cmd.key}") + state().set(cmdStateKey(cmd.key), "value-${cmd.key}") } is SideEffect -> { val expected = "hello-$i" - val result = ctx.runBlock { expected } + val result = runBlock { expected } if (result != expected) { throw TerminalException("Side effect result don't match: $result != $expected") } } is Sleep -> { - ctx.sleep(cmd.duration.milliseconds) + sleep(cmd.duration.milliseconds) } is SlowSideEffect -> { - ctx.runBlock { kotlinx.coroutines.delay(1.milliseconds) } + runBlock { kotlinx.coroutines.delay(1.milliseconds) } } is ThrowingSideEffect -> { - ctx.runBlock { + runBlock { check(Random.nextBoolean()) { "Random failure caused by a very cool language." } } } @@ -201,53 +214,53 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { - override suspend fun ping(ctx: Context) {} + override suspend fun ping() {} - override suspend fun echo(ctx: Context, param: String): String { + override suspend fun echo(param: String): String { return param } - override suspend fun echoLater(ctx: Context, req: EchoLaterRequest): String { - ctx.sleep(req.sleep.milliseconds) + override suspend fun echoLater(req: EchoLaterRequest): String { + sleep(req.sleep.milliseconds) return req.parameter } - override suspend fun terminalFailure(ctx: Context) { + override suspend fun terminalFailure() { throw TerminalException("bye") } - override suspend fun incrementIndirectly(ctx: Context, id: InterpreterId) { + override suspend fun incrementIndirectly(id: InterpreterId) { val ignored = - ctx.send( - Request.of( - interpretTarget(id.layer, id.key), - ObjectInterpreterHandlers.Metadata.Serde.INTERPRET_INPUT, - Serde.SLICE, - Program(listOf(IncrementStateCounter())), + context() + .send( + Request.of( + interpretTarget(id.layer, id.key), + typeTag(), + Serde.SLICE, + Program(listOf(IncrementStateCounter())), + ) ) - ) } - override suspend fun resolveAwakeable(ctx: Context, id: String) { - ctx.awakeableHandle(id).resolve("ok") + override suspend fun resolveAwakeable(id: String) { + awakeableHandle(id).resolve("ok") } - override suspend fun rejectAwakeable(ctx: Context, id: String) { - ctx.awakeableHandle(id).resolve("error") + override suspend fun rejectAwakeable(id: String) { + awakeableHandle(id).resolve("error") } override suspend fun incrementViaAwakeableDance( - ctx: Context, req: IncrementViaAwakeableDanceRequest, ) { // // 1. create an awakeable that we will be blocked on // - val awakeable = ctx.awakeable() + val awakeable = awakeable() // // 2. send our awakeable id to the interpreter via txPromise. // - ctx.awakeableHandle(req.txPromiseId).resolve(awakeable.id) + awakeableHandle(req.txPromiseId).resolve(awakeable.id) // // 3. wait for the interpreter resolve us // @@ -256,13 +269,14 @@ class ServiceInterpreterHelperImpl : ServiceInterpreterHelper { // 4. to thank our interpret, let us ask it to inc its state. // val ignored = - ctx.send( - Request.of( - interpretTarget(req.interpreter.layer, req.interpreter.key), - ObjectInterpreterHandlers.Metadata.Serde.INTERPRET_INPUT, - Serde.SLICE, - Program(listOf(IncrementStateCounter())), + context() + .send( + Request.of( + interpretTarget(req.interpreter.layer, req.interpreter.key), + typeTag(), + Serde.SLICE, + Program(listOf(IncrementStateCounter())), + ) ) - ) } } From 42a80ce2bdd307bccd89d06f01db66060d40c1d2 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 30 Jan 2026 20:46:02 +0100 Subject: [PATCH 03/10] Change a bit the `request` api --- .../dev/restate/client/kotlin/ingress.kt | 12 +++++------ .../reflection/kotlin/RequestCaptureProxy.kt | 4 ++-- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 12 +++++------ .../reflections/ReflectionDiscoveryTest.kt | 21 +++++++++++++++++++ .../core/kotlinapi/reflections/testClasses.kt | 4 ++-- .../kotlin/SdkTestingIntegrationTest.kt | 2 +- .../restate/sdk/testservices/KillTestImpl.kt | 2 +- .../sdk/testservices/NonDeterministicImpl.kt | 6 +++--- .../restate/sdk/testservices/interpreter.kt | 16 +++++++------- 9 files changed, 50 insertions(+), 29 deletions(-) diff --git a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt index db59e7f01..34a5ed1a8 100644 --- a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt +++ b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt @@ -429,10 +429,10 @@ internal constructor( * @return a [KClientRequest] with the correct response type */ @Suppress("UNCHECKED_CAST") - fun request(block: suspend (SVC) -> Res): KClientRequest { + fun request(block: suspend SVC.() -> Res): KClientRequest { return KClientRequestImpl( client, - RequestCaptureProxy(clazz, key).capture(block as suspend (SVC) -> Any?).toRequest(), + RequestCaptureProxy(clazz, key).capture(block as suspend SVC.() -> Any?).toRequest(), ) as KClientRequest } @@ -444,7 +444,7 @@ internal constructor( * Example usage: * ```kotlin * client.toService() - * .request { it.add(1) } + * .request { add(1) } * .withOptions { idempotencyKey = "123" } * .call() * ``` @@ -485,7 +485,7 @@ interface KClientRequest : Request { * Example usage: * ```kotlin * val response = client.toService() - * .request { it.greet("Alice") } + * .request { greet("Alice") } * .call() * ``` * @@ -507,7 +507,7 @@ inline fun Client.toService(): KClientRequestBuilder { * Example usage: * ```kotlin * val response = client.toVirtualObject("my-counter") - * .request { it.add(1) } + * .request { add(1) } * .call() * ``` * @@ -530,7 +530,7 @@ inline fun Client.toVirtualObject(key: String): KClientReque * Example usage: * ```kotlin * val response = client.toWorkflow("workflow-123") - * .request { it.run("input") } + * .request { run("input") } * .call() * ``` * diff --git a/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt b/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt index e98ed10b3..d72295faa 100644 --- a/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt +++ b/common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt @@ -37,7 +37,7 @@ class RequestCaptureProxy(private val clazz: Class, private val * @param block the suspend lambda that invokes a method on the service proxy * @return the captured invocation information */ - fun capture(block: suspend (SVC) -> Any?): CapturedInvocation { + fun capture(block: suspend SVC.() -> Any?): CapturedInvocation { var capturedInvocation: CapturedInvocation? = null val proxy = @@ -60,7 +60,7 @@ class RequestCaptureProxy(private val clazz: Class, private val } } - val suspendBlock: suspend () -> Any? = { block(proxy) } + val suspendBlock: suspend () -> Any? = { proxy.block() } suspendBlock.startCoroutine(capturingContinuation) return capturedInvocation diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 471609d6e..6d51ac5c1 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -1142,7 +1142,7 @@ private class KotlinStateImpl( * Example usage: * ```kotlin * toService() - * .request { it.add(1) } + * .request { add(1) } * .withOptions { idempotencyKey = "123" } * .call() * ``` @@ -1203,9 +1203,9 @@ internal constructor( * @return a [KRequest] with the correct response type */ @Suppress("UNCHECKED_CAST") - fun request(block: suspend (SVC) -> Res): KRequest { + fun request(block: suspend SVC.() -> Res): KRequest { return KRequestImpl( - RequestCaptureProxy(clazz, key).capture(block as suspend (SVC) -> Any?).toRequest() + RequestCaptureProxy(clazz, key).capture(block as suspend SVC.() -> Any?).toRequest() ) as KRequest } @@ -1219,7 +1219,7 @@ internal constructor( * @Handler * suspend fun myHandler(): String { * val result = toService() - * .request { it.greet("Alice") } + * .request { greet("Alice") } * .call() * .await() * return result @@ -1246,7 +1246,7 @@ inline fun toService(): KRequestBuilder { * @Handler * suspend fun myHandler(): Long { * val result = toVirtualObject("my-counter") - * .request { it.add(1) } + * .request { add(1) } * .call() * .await() * return result @@ -1274,7 +1274,7 @@ inline fun toVirtualObject(key: String): KRequestBuilder("workflow-123") - * .request { it.run("input") } + * .request { run("input") } * .call() * .await() * return result diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt index be615e446..5ca02abfa 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt @@ -14,6 +14,7 @@ import dev.restate.sdk.core.generated.manifest.Input import dev.restate.sdk.core.generated.manifest.Output import dev.restate.sdk.core.generated.manifest.Service import dev.restate.sdk.kotlin.endpoint.* +import dev.restate.serde.Serde import org.assertj.core.api.InstanceOfAssertFactories.type import org.junit.jupiter.api.Test @@ -49,6 +50,26 @@ class ReflectionDiscoveryTest { .isEqualTo("application/vnd.my.custom") } + @Test + fun checkRawInputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInput") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo(Serde.RAW.contentType()) + } + + @Test + fun checkRawOutputContentType() { + assertThatDiscovery(RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutput") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo(Serde.RAW.contentType()) + } + @Test fun explicitNames() { assertThatDiscovery( diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt index 1a0243283..5c87fc83c 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt @@ -110,7 +110,7 @@ open class CornerCases { @Exclusive open suspend fun badReturnTypeInferred(): Unit { - toVirtualObject(key()).request { it.badReturnTypeInferred() }.send() + toVirtualObject(key()).request { badReturnTypeInferred() }.send() } } @@ -149,7 +149,7 @@ open class RawInputOutput { open class MyWorkflow { @Workflow open suspend fun run(myInput: String) { - toWorkflow(key()).request { it.sharedHandler(myInput) }.send() + toWorkflow(key()).request { sharedHandler(myInput) }.send() } @Handler diff --git a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt index 65fdf38e8..2015f253d 100644 --- a/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt +++ b/sdk-spring-boot-kotlin-starter/src/test/kotlin/dev/restate/sdk/springboot/kotlin/SdkTestingIntegrationTest.kt @@ -49,7 +49,7 @@ class SdkTestingIntegrationTest { @Timeout(value = 10) fun greetNewApiWithRequestTo(@RestateClient ingressClient: Client) = runTest { val response: String = - ingressClient.toService().request { it.greet("Francesco") }.call().response() + ingressClient.toService().request { greet("Francesco") }.call().response() assertThat(response).isEqualTo("Something something Francesco") } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt index 93341255c..1a8bf7cb3 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt @@ -28,7 +28,7 @@ class KillTestImpl { class SingletonImpl : KillTest.Singleton { override suspend fun recursiveCall() { val awakeable = awakeable(Serde.RAW) - toVirtualObject(key()).request { it.hold(awakeable.id) }.send() + toVirtualObject(key()).request { hold(awakeable.id) }.send() awakeable.await() virtualObject(key()).recursiveCall() diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt index e1c56ab72..3907e3fa9 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt @@ -44,9 +44,9 @@ class NonDeterministicImpl : NonDeterministic { override suspend fun backgroundInvokeWithDifferentTargets() { if (doLeftAction()) { - toVirtualObject("abc").request { it.get() }.send() + toVirtualObject("abc").request { get() }.send() } else { - toVirtualObject("abc").request { it.reset() }.send() + toVirtualObject("abc").request { reset() }.send() } // This is required to cause a suspension after the non-deterministic operation sleep(100.milliseconds) @@ -65,7 +65,7 @@ class NonDeterministicImpl : NonDeterministic { } private suspend fun incrementCounter() { - toVirtualObject("abc").request { it.add(1) }.send() + toVirtualObject("abc").request { add(1) }.send() } private suspend fun doLeftAction(): Boolean { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt index 86da8cd09..ed4ee74e4 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt @@ -116,14 +116,14 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is CallService -> { val expected = "hello-$i" - val awaitable = toService().request { it.echo(expected) }.call() + val awaitable = toService().request { echo(expected) }.call() promises[i] = { checkAwaitable(awaitable, expected, i, cmd) } } is CallSlowService -> { val expected = "hello-$i" val awaitable = toService() - .request { it.echoLater(EchoLaterRequest(cmd.sleep, expected)) } + .request { echoLater(EchoLaterRequest(cmd.sleep, expected)) } .call() promises[i] = { checkAwaitable(awaitable, expected, i, cmd) } } @@ -138,7 +138,7 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is IncrementStateCounterIndirectly -> { toService() - .request { it.incrementIndirectly(interpreterId()) } + .request { incrementIndirectly(interpreterId()) } .send() } is IncrementStateCounterViaAwakeable -> { @@ -146,7 +146,7 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { val awakeable = awakeable() toService() .request { - it.incrementViaAwakeableDance( + incrementViaAwakeableDance( IncrementViaAwakeableDanceRequest(interpreterId(), awakeable.id) ) } @@ -156,7 +156,7 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is IncrementViaDelayedCall -> { toService() - .request { it.incrementIndirectly(interpreterId()) } + .request { incrementIndirectly(interpreterId()) } .send(delay = cmd.duration.milliseconds) } is RecoverTerminalCall -> { @@ -174,18 +174,18 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } is RecoverTerminalCallMaybeUnAwaited -> { val awaitable = - toService().request { it.terminalFailure() }.call() + toService().request { terminalFailure() }.call() promises[i] = { checkAwaitableFails(awaitable, i, cmd) } } is RejectAwakeable -> { val awakeable = awakeable() promises[i] = { checkAwaitableFails(awakeable, i, cmd) } - toService().request { it.rejectAwakeable(awakeable.id) }.send() + toService().request { rejectAwakeable(awakeable.id) }.send() } is ResolveAwakeable -> { val awakeable = awakeable() promises[i] = { checkAwaitable(awakeable, "ok", i, cmd) } - toService().request { it.resolveAwakeable(awakeable.id) }.send() + toService().request { resolveAwakeable(awakeable.id) }.send() } is SetState -> { state().set(cmdStateKey(cmd.key), "value-${cmd.key}") From 0fb331cb364f0413f2dd30127b2cbb65392b98ac Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 30 Jan 2026 20:54:33 +0100 Subject: [PATCH 04/10] Fix the RAW content type --- .../reflections/ReflectionDiscoveryTest.kt | 20 +++++++++---------- .../restate/sdk/testservices/interpreter.kt | 3 +-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt index 5ca02abfa..b8098ba95 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionDiscoveryTest.kt @@ -53,21 +53,21 @@ class ReflectionDiscoveryTest { @Test fun checkRawInputContentType() { assertThatDiscovery(RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawInput") - .extracting({ it.input }, type(Input::class.java)) - .extracting { it.contentType } - .isEqualTo(Serde.RAW.contentType()) + .extractingService("RawInputOutput") + .extractingHandler("rawInput") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo(Serde.RAW.contentType()) } @Test fun checkRawOutputContentType() { assertThatDiscovery(RawInputOutput()) - .extractingService("RawInputOutput") - .extractingHandler("rawOutput") - .extracting({ it.output }, type(Output::class.java)) - .extracting { it.contentType } - .isEqualTo(Serde.RAW.contentType()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutput") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo(Serde.RAW.contentType()) } @Test diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt index ed4ee74e4..44eb387cd 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt @@ -173,8 +173,7 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } } is RecoverTerminalCallMaybeUnAwaited -> { - val awaitable = - toService().request { terminalFailure() }.call() + val awaitable = toService().request { terminalFailure() }.call() promises[i] = { checkAwaitableFails(awaitable, i, cmd) } } is RejectAwakeable -> { From ac4951f3be475ea229f31d335c071170dc7a5c53 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 2 Feb 2026 09:15:09 +0100 Subject: [PATCH 05/10] Fixed problem with handling of parameter annotations --- .../common/reflections/ReflectionUtils.java | 9 +-- .../ReflectionServiceDefinitionFactory.kt | 70 ++++++++----------- 2 files changed, 34 insertions(+), 45 deletions(-) diff --git a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java index 15fd0e33f..da5e27e6b 100644 --- a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java +++ b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java @@ -35,10 +35,11 @@ public record HandlerInfo(String name, boolean shared) {} return getUniqueDeclaredMethods( restateAnnotatedClazz, method -> - method.getDeclaredAnnotation(Handler.class) != null - || method.getDeclaredAnnotation(Shared.class) != null - || method.getDeclaredAnnotation(Workflow.class) != null - || method.getDeclaredAnnotation(Exclusive.class) != null); + !Modifier.isStatic(method.getModifiers()) + && (method.getDeclaredAnnotation(Handler.class) != null + || method.getDeclaredAnnotation(Shared.class) != null + || method.getDeclaredAnnotation(Workflow.class) != null + || method.getDeclaredAnnotation(Exclusive.class) != null)); } /** Find the class where the Restate annotations are declared. */ diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt index c63582fee..b3b3cf9ed 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt @@ -9,25 +9,10 @@ package dev.restate.sdk.kotlin.internal import dev.restate.common.reflections.ReflectionUtils -import dev.restate.sdk.annotation.Accept -import dev.restate.sdk.annotation.CustomSerdeFactory -import dev.restate.sdk.annotation.Exclusive -import dev.restate.sdk.annotation.Handler -import dev.restate.sdk.annotation.Json -import dev.restate.sdk.annotation.Raw -import dev.restate.sdk.annotation.Shared -import dev.restate.sdk.annotation.Workflow -import dev.restate.sdk.endpoint.definition.HandlerDefinition +import dev.restate.sdk.annotation.* +import dev.restate.sdk.endpoint.definition.* import dev.restate.sdk.endpoint.definition.HandlerRunner -import dev.restate.sdk.endpoint.definition.HandlerType -import dev.restate.sdk.endpoint.definition.ServiceDefinition -import dev.restate.sdk.endpoint.definition.ServiceDefinitionFactory -import dev.restate.sdk.endpoint.definition.ServiceType -import dev.restate.sdk.kotlin.Context -import dev.restate.sdk.kotlin.ObjectContext -import dev.restate.sdk.kotlin.SharedObjectContext -import dev.restate.sdk.kotlin.SharedWorkflowContext -import dev.restate.sdk.kotlin.WorkflowContext +import dev.restate.sdk.kotlin.* import dev.restate.serde.Serde import dev.restate.serde.SerdeFactory import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory @@ -35,10 +20,12 @@ import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory.KtTypeTag import dev.restate.serde.provider.DefaultSerdeFactoryProvider import java.lang.reflect.Modifier import java.util.* +import kotlin.reflect.KClass import kotlin.reflect.KFunction import kotlin.reflect.KVisibility import kotlin.reflect.full.callSuspend import kotlin.reflect.full.findAnnotation +import kotlin.reflect.full.hasAnnotation import kotlin.reflect.full.memberFunctions import kotlin.reflect.full.valueParameters import kotlin.reflect.jvm.javaMethod @@ -66,15 +53,23 @@ internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory = serviceInstance.javaClass - val hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(serviceClazz) - val hasVirtualObjectAnnotation = ReflectionUtils.hasVirtualObjectAnnotation(serviceClazz) - val hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(serviceClazz) + // The behavior of the reflections work as follows: + // * There is one class that has all the restate annotations. That being either the serviceClazz + // itself (concrete class) or some interface in the hierarchy. + // * Then there is the serviceInstance, which is where we call the methods themselves. + val restateAnnotatedClazz = ReflectionUtils.findRestateAnnotatedClass(serviceClazz) + val restateAnnotatedKotlinClazz = restateAnnotatedClazz.kotlin + + val hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(restateAnnotatedClazz) + val hasVirtualObjectAnnotation = + ReflectionUtils.hasVirtualObjectAnnotation(restateAnnotatedClazz) + val hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(restateAnnotatedClazz) val hasAnyAnnotation = hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation if (!hasAnyAnnotation) { throw MalformedRestateServiceException( - serviceClazz.simpleName, + restateAnnotatedClazz.simpleName, "A restate component MUST be annotated with " + "exactly one annotation between @Service/@VirtualObject/@Workflow, no annotation was found", ) @@ -84,27 +79,24 @@ internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory - // Can't use findAnnotations because that won't walk the stack! - ReflectionUtils.findAnnotation(method, Handler::class.java) != null || - ReflectionUtils.findAnnotation(method, Shared::class.java) != null || - ReflectionUtils.findAnnotation(method, Workflow::class.java) != null || - ReflectionUtils.findAnnotation(method, Exclusive::class.java) != null - } ?: false + restateAnnotatedKotlinClazz.memberFunctions.filter { + it.hasAnnotation() || + it.hasAnnotation() || + it.hasAnnotation() || + it.hasAnnotation() } if (kFunctions.isEmpty()) { @@ -351,20 +343,16 @@ internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory): SerdeFactory { + private fun resolveSerdeFactory(serviceClazz: KClass<*>): SerdeFactory { // Check for CustomSerdeFactory annotation - val customSerdeFactoryAnnotation: CustomSerdeFactory? = - ReflectionUtils.findAnnotation( - serviceClazz, - CustomSerdeFactory::class.java, - ) + val customSerdeFactoryAnnotation = serviceClazz.findAnnotation() if (customSerdeFactoryAnnotation != null) { try { return customSerdeFactoryAnnotation.value.java.getDeclaredConstructor().newInstance() } catch (e: Exception) { throw MalformedRestateServiceException( - serviceClazz.simpleName, + serviceClazz.simpleName!!, "Failed to instantiate custom SerdeFactory: ${customSerdeFactoryAnnotation.value.java.name}", e, ) @@ -392,7 +380,7 @@ internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory Date: Mon, 2 Feb 2026 11:44:54 +0100 Subject: [PATCH 06/10] Fix exception handling --- .../internal/ReflectionServiceDefinitionFactory.kt | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt index b3b3cf9ed..4c9815ec2 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/internal/ReflectionServiceDefinitionFactory.kt @@ -18,6 +18,7 @@ import dev.restate.serde.SerdeFactory import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory.KtTypeTag import dev.restate.serde.provider.DefaultSerdeFactoryProvider +import java.lang.reflect.InvocationTargetException import java.lang.reflect.Modifier import java.util.* import kotlin.reflect.KClass @@ -231,10 +232,14 @@ internal class ReflectionServiceDefinitionFactory : ServiceDefinitionFactory - if (parameterCount == 0) { - kFunction.callSuspend(serviceInstance) - } else { - kFunction.callSuspend(serviceInstance, input) + try { + if (parameterCount == 0) { + kFunction.callSuspend(serviceInstance) + } else { + kFunction.callSuspend(serviceInstance, input) + } + } catch (t: InvocationTargetException) { + throw t.cause!! } } } From 628d23b8cd1a57cd8271c8a71550a582ba7a8737 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 2 Feb 2026 13:54:46 +0100 Subject: [PATCH 07/10] Better error message for SerdeFactory --- common/src/main/java/dev/restate/serde/SerdeFactory.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/dev/restate/serde/SerdeFactory.java b/common/src/main/java/dev/restate/serde/SerdeFactory.java index 913774c62..4ee8c08e7 100644 --- a/common/src/main/java/dev/restate/serde/SerdeFactory.java +++ b/common/src/main/java/dev/restate/serde/SerdeFactory.java @@ -41,8 +41,10 @@ default Serde create(TypeTag typeTag) { return this.create(tClass.type()); } else if (typeTag instanceof TypeRef tTypeRef) { return this.create(tTypeRef); - } else { + } else if (typeTag instanceof Serde) { return ((Serde) typeTag); + } else { + throw new IllegalArgumentException("TypeTag not supported by this SerdeFactory: " + typeTag); } } From 26e56879e061f431e6bf1536ca7d04d528f0d74a Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 2 Feb 2026 14:12:35 +0100 Subject: [PATCH 08/10] Change naming --- .../src/main/kotlin/dev/restate/client/kotlin/ingress.kt | 6 +++--- .../src/main/kotlin/dev/restate/sdk/kotlin/api.kt | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt index 34a5ed1a8..22275b590 100644 --- a/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt +++ b/client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt @@ -445,7 +445,7 @@ internal constructor( * ```kotlin * client.toService() * .request { add(1) } - * .withOptions { idempotencyKey = "123" } + * .options { idempotencyKey = "123" } * .call() * ``` * @@ -461,7 +461,7 @@ interface KClientRequest : Request { * @param block builder block for options * @return a new request with the configured options */ - fun withOptions(block: InvocationOptions.Builder.() -> Unit): KClientRequest + fun options(block: InvocationOptions.Builder.() -> Unit): KClientRequest /** * Call the target handler and wait for the response. @@ -553,7 +553,7 @@ private class KClientRequestImpl( private val request: Request, ) : KClientRequest, Request by request { - override fun withOptions(block: InvocationOptions.Builder.() -> Unit): KClientRequest { + override fun options(block: InvocationOptions.Builder.() -> Unit): KClientRequest { val builder = InvocationOptions.builder() builder.block() return KClientRequestImpl( diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index 6d51ac5c1..c53196253 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -1143,7 +1143,7 @@ private class KotlinStateImpl( * ```kotlin * toService() * .request { add(1) } - * .withOptions { idempotencyKey = "123" } + * .options { idempotencyKey = "123" } * .call() * ``` * @@ -1160,7 +1160,7 @@ interface KRequest : Request { * @return a new request with the configured options */ @org.jetbrains.annotations.ApiStatus.Experimental - fun withOptions(block: InvocationOptions.Builder.() -> Unit): KRequest + fun options(block: InvocationOptions.Builder.() -> Unit): KRequest /** * Call the target handler and return a [CallDurableFuture] for the result. @@ -1297,7 +1297,7 @@ inline fun toWorkflow(key: String): KRequestBuilder { /** Implementation of [KRequest] for SDK context. */ private class KRequestImpl(private val request: Request) : KRequest, Request by request { - override fun withOptions(block: InvocationOptions.Builder.() -> Unit): KRequest { + override fun options(block: InvocationOptions.Builder.() -> Unit): KRequest { val builder = InvocationOptions.builder() builder.block() return KRequestImpl( From 81f7b745670b754851285a5d9c355cb4d6365951 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 2 Feb 2026 14:30:49 +0100 Subject: [PATCH 09/10] Fix issue with null headers --- .../src/main/java/dev/restate/common/RequestBuilder.java | 4 ++-- common/src/main/java/dev/restate/common/RequestImpl.java | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/dev/restate/common/RequestBuilder.java b/common/src/main/java/dev/restate/common/RequestBuilder.java index 15efea398..ae9643d41 100644 --- a/common/src/main/java/dev/restate/common/RequestBuilder.java +++ b/common/src/main/java/dev/restate/common/RequestBuilder.java @@ -18,7 +18,7 @@ public interface RequestBuilder extends Request { * @param idempotencyKey Idempotency key to attach in the request. * @return this instance, so the builder can be used fluently. */ - RequestBuilder idempotencyKey(String idempotencyKey); + RequestBuilder idempotencyKey(@Nullable String idempotencyKey); /** * @param idempotencyKey Idempotency key to attach in the request. @@ -40,7 +40,7 @@ public interface RequestBuilder extends Request { * @param newHeaders headers to send together with the request. * @return this instance, so the builder can be used fluently. */ - RequestBuilder headers(Map newHeaders); + RequestBuilder headers(@Nullable Map newHeaders); /** * @param headers headers to send together with the request. This will overwrite the already diff --git a/common/src/main/java/dev/restate/common/RequestImpl.java b/common/src/main/java/dev/restate/common/RequestImpl.java index 6664bcdcb..41faf3b4a 100644 --- a/common/src/main/java/dev/restate/common/RequestImpl.java +++ b/common/src/main/java/dev/restate/common/RequestImpl.java @@ -155,10 +155,12 @@ public Builder header(String key, String value) { */ @Override public Builder headers(Map newHeaders) { - if (this.headers == null) { - this.headers = new LinkedHashMap<>(); + if (newHeaders != null) { + if (this.headers == null) { + this.headers = new LinkedHashMap<>(); + } + this.headers.putAll(newHeaders); } - this.headers.putAll(newHeaders); return this; } From 3369a44b0f3be8e2e5b84b8420758bf3a0fd10bb Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Mon, 2 Feb 2026 14:48:09 +0100 Subject: [PATCH 10/10] Split key in objectKey/workflowKey --- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 34 ++++++++++++++++--- .../core/kotlinapi/reflections/testClasses.kt | 10 +++--- .../sdk/testservices/CancelTestImpl.kt | 6 ++-- .../restate/sdk/testservices/CounterImpl.kt | 2 +- .../restate/sdk/testservices/KillTestImpl.kt | 6 ++-- .../sdk/testservices/NonDeterministicImpl.kt | 2 +- .../restate/sdk/testservices/interpreter.kt | 2 +- 7 files changed, 43 insertions(+), 19 deletions(-) diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index c53196253..663825ce0 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -912,23 +912,47 @@ suspend inline fun invocationHandle(invocationId: String): In } /** - * Get the key of this Virtual Object or Workflow. + * Get the key of this Virtual Object. * * @return the key of this object * @throws IllegalStateException if called from a regular Service handler or outside of a Restate * handler */ @org.jetbrains.annotations.ApiStatus.Experimental -suspend fun key(): String { +suspend fun objectKey(): String { val ctx = context() val handlerContext = dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() - ?: error("key() must be called from within a Restate handler") + ?: error("objectKey() must be called from within a Restate handler") if (!handlerContext.canReadState()) { error( - "key() can be used only within Virtual Object or Workflow handlers. " + - "Check https://docs.restate.dev/develop/java/state for more details." + "objectKey() can be used only within Virtual Object handlers. " + + "Check https://docs.restate.dev/develop/java/services#virtual-objects for more details." + ) + } + + return (ctx as SharedObjectContext).key() +} + +/** + * Get the key of this Workflow. + * + * @return the key of this workflow + * @throws IllegalStateException if called from a regular Service handler, or from a virtual object + * handler, or outside of a Restate handler + */ +@org.jetbrains.annotations.ApiStatus.Experimental +suspend fun workflowKey(): String { + val ctx = context() + val handlerContext = + dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get() + ?: error("workflowKey() must be called from within a Restate handler") + + if (!handlerContext.canReadPromises()) { + error( + "workflowKey() can be used only within Workflow handlers. " + + "Check https://docs.restate.dev/develop/java/services#workflows for more details." ) } diff --git a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt index 5c87fc83c..8fa6b59fa 100644 --- a/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt +++ b/sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt @@ -63,7 +63,7 @@ interface GreeterInterface { class ObjectGreeterImplementedFromInterface : GreeterInterface { override suspend fun greet(request: String): String { - return virtualObject(key()).greet(request) + return virtualObject(objectKey()).greet(request) } } @@ -105,12 +105,12 @@ open class CornerCases { @Exclusive open suspend fun returnNull(request: String?): String? { - return virtualObject(key()).returnNull(request) + return virtualObject(objectKey()).returnNull(request) } @Exclusive open suspend fun badReturnTypeInferred(): Unit { - toVirtualObject(key()).request { badReturnTypeInferred() }.send() + toVirtualObject(objectKey()).request { badReturnTypeInferred() }.send() } } @@ -149,12 +149,12 @@ open class RawInputOutput { open class MyWorkflow { @Workflow open suspend fun run(myInput: String) { - toWorkflow(key()).request { sharedHandler(myInput) }.send() + toWorkflow(workflowKey()).request { sharedHandler(myInput) }.send() } @Handler open suspend fun sharedHandler(myInput: String): String = - workflow(key()).sharedHandler(myInput) + workflow(workflowKey()).sharedHandler(myInput) } @Suppress("UNCHECKED_CAST") diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt index dffae4ed5..628009863 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CancelTestImpl.kt @@ -23,7 +23,7 @@ class CancelTestImpl { override suspend fun startTest(operation: CancelTest.BlockingOperation) { try { - virtualObject(key()).block(operation) + virtualObject(objectKey()).block(operation) } catch (e: TerminalException) { if (e.code == TerminalException.CANCELLED_CODE) { state().set(CANCELED_STATE, true) @@ -40,8 +40,8 @@ class CancelTestImpl { class BlockingService : CancelTest.BlockingService { override suspend fun block(operation: CancelTest.BlockingOperation) { - val self = virtualObject(key()) - val awakeableHolder = virtualObject(key()) + val self = virtualObject(objectKey()) + val awakeableHolder = virtualObject(objectKey()) val awakeable = awakeable() awakeableHolder.hold(awakeable.id) diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt index a1cb419b0..ce61ed24c 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/CounterImpl.kt @@ -37,7 +37,7 @@ class CounterImpl : Counter { logger.info("New counter value: {}", counter) - throw TerminalException(key()) + throw TerminalException(objectKey()) } override suspend fun get(): Long { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt index 1a8bf7cb3..2bfc4ec42 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/KillTestImpl.kt @@ -21,17 +21,17 @@ class KillTestImpl { // in the inbox: // startCallTree --> recursiveCall --> recursiveCall:inboxed override suspend fun startCallTree() { - virtualObject(key()).recursiveCall() + virtualObject(objectKey()).recursiveCall() } } class SingletonImpl : KillTest.Singleton { override suspend fun recursiveCall() { val awakeable = awakeable(Serde.RAW) - toVirtualObject(key()).request { hold(awakeable.id) }.send() + toVirtualObject(objectKey()).request { hold(awakeable.id) }.send() awakeable.await() - virtualObject(key()).recursiveCall() + virtualObject(objectKey()).recursiveCall() } override suspend fun isUnlocked() { diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt index 3907e3fa9..20802b3ab 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/NonDeterministicImpl.kt @@ -70,7 +70,7 @@ class NonDeterministicImpl : NonDeterministic { private suspend fun doLeftAction(): Boolean { // Test runner sets an appropriate key here - val countKey = key() + val countKey = objectKey() return invocationCounts.merge(countKey, 1) { a: Int, b: Int -> a + b }!! % 2 == 1 } } diff --git a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt index 44eb387cd..0d8add183 100644 --- a/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt +++ b/test-services/src/main/kotlin/dev/restate/sdk/testservices/interpreter.kt @@ -81,7 +81,7 @@ class ObjectInterpreterImpl(private val layer: Int) : ObjectInterpreter { } private suspend fun interpreterId(): InterpreterId { - return InterpreterId(layer, key()) + return InterpreterId(layer, objectKey()) } override suspend fun counter(): Int {