diff --git a/src/client.rs b/src/client.rs index ad7a23d..705e65f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -342,17 +342,7 @@ where params: JsonValue, options: SpawnOptions, ) -> DurableResult { - // Validate that the task is registered - { - let registry = self.registry.read().await; - if !registry.contains_key(task_name) { - return Err(DurableError::TaskNotRegistered { - task_name: task_name.to_string(), - }); - } - } - - self.spawn_by_name_internal(&self.pool, task_name, params, options) + self.spawn_by_name_with(&self.pool, task_name, params, options) .await } @@ -432,11 +422,16 @@ where // Validate that the task is registered { let registry = self.registry.read().await; - if !registry.contains_key(task_name) { + let Some(task) = registry.get(task_name) else { return Err(DurableError::TaskNotRegistered { task_name: task_name.to_string(), }); - } + }; + task.validate_params(params.clone()) + .map_err(|e| DurableError::InvalidTaskParams { + task_name: task_name.to_string(), + message: e.to_string(), + })?; } self.spawn_by_name_internal(executor, task_name, params, options) diff --git a/src/error.rs b/src/error.rs index f55fcfc..8972ad5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -311,6 +311,18 @@ pub enum DurableError { task_name: String, }, + //// Task params validation failed. + /// + /// Returned when the task definition in the registry fails to validate the params + /// (before we attempt to spawn the task in Postgres). + #[error("invalid task parameters for '{task_name}': {message}")] + InvalidTaskParams { + /// The name of the task being spawned + task_name: String, + /// The error message from the task. + message: String, + }, + /// Header key uses a reserved prefix. /// /// User-provided headers cannot start with "durable::" as this prefix diff --git a/src/task.rs b/src/task.rs index 54f6019..c857c6d 100644 --- a/src/task.rs +++ b/src/task.rs @@ -109,6 +109,8 @@ where State: Clone + Send + Sync + 'static, { fn name(&self) -> Cow<'static, str>; + /// Called before spawning, to check that the `params` are valid for this task. + fn validate_params(&self, params: JsonValue) -> Result<(), TaskError>; async fn execute( &self, params: JsonValue, @@ -127,6 +129,12 @@ where T::name() } + fn validate_params(&self, params: JsonValue) -> Result<(), TaskError> { + // For now, just deserialize + let _typed_params: T::Params = serde_json::from_value(params)?; + Ok(()) + } + async fn execute( &self, params: JsonValue, diff --git a/tests/spawn_test.rs b/tests/spawn_test.rs index 5a79dd2..247f840 100644 --- a/tests/spawn_test.rs +++ b/tests/spawn_test.rs @@ -3,7 +3,7 @@ mod common; use common::tasks::{EchoParams, EchoTask, FailingParams, FailingTask}; -use durable::{CancellationPolicy, Durable, MIGRATOR, RetryStrategy, SpawnOptions}; +use durable::{CancellationPolicy, Durable, DurableError, MIGRATOR, RetryStrategy, SpawnOptions}; use sqlx::PgPool; use std::collections::HashMap; use std::time::Duration; @@ -270,6 +270,33 @@ async fn test_spawn_by_name(pool: PgPool) -> sqlx::Result<()> { Ok(()) } +#[sqlx::test(migrator = "MIGRATOR")] +async fn test_spawn_by_name_invalid_params(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "spawn_by_name").await; + client.create_queue(None).await.unwrap(); + client.register::().await.unwrap(); + + let params = serde_json::json!({ + "message": 12345 + }); + + let result = client + .spawn_by_name("echo", params, SpawnOptions::default()) + .await + .expect_err("Spawning task by name with invalid params should fail"); + + let DurableError::InvalidTaskParams { task_name, message } = result else { + panic!("Unexpected error: {}", result); + }; + assert_eq!(task_name, "echo"); + assert_eq!( + message, + "serialization error: invalid type: integer `12345`, expected a string" + ); + + Ok(()) +} + #[sqlx::test(migrator = "MIGRATOR")] async fn test_spawn_by_name_with_options(pool: PgPool) -> sqlx::Result<()> { let client = create_client(pool.clone(), "spawn_by_name_opts").await; @@ -308,9 +335,10 @@ async fn test_spawn_with_empty_params(pool: PgPool) -> sqlx::Result<()> { client.create_queue(None).await.unwrap(); client.register::().await.unwrap(); - // Empty object is valid JSON params for EchoTask (message will be missing but that's ok for this test) + // Empty object is not valid JSON params for EchoTask, + // but spawn_by_name_unchecked does not validate the JSON let result = client - .spawn_by_name("echo", serde_json::json!({}), SpawnOptions::default()) + .spawn_by_name_unchecked("echo", serde_json::json!({}), SpawnOptions::default()) .await .expect("Failed to spawn task with empty params"); @@ -326,7 +354,8 @@ async fn test_spawn_with_complex_params(pool: PgPool) -> sqlx::Result<()> { client.register::().await.unwrap(); // Complex nested JSON structure - the params don't need to match the task's Params type - // because spawn_by_name accepts arbitrary JSON + // because spawn_by_name_unchecked does not validate the JSON + // (unlike `spawn_by_name`) let params = serde_json::json!({ "nested": { "array": [1, 2, 3], @@ -341,7 +370,7 @@ async fn test_spawn_with_complex_params(pool: PgPool) -> sqlx::Result<()> { }); let result = client - .spawn_by_name("echo", params, SpawnOptions::default()) + .spawn_by_name_unchecked("echo", params, SpawnOptions::default()) .await .expect("Failed to spawn task with complex params");