From d921e8edaf2a10a1d11b1e35fa6ef822ed510ca3 Mon Sep 17 00:00:00 2001
From: Viraj Mehta
Date: Thu, 1 Jan 2026 17:34:34 -0500
Subject: [PATCH 1/4] propagate default num attempts to spawned tasks

---
 src/client.rs        |  1 +
 src/context.rs       | 12 ++++++++++++
 src/worker.rs        |  8 ++++++++
 tests/fanout_test.rs | 51 +++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/src/client.rs b/src/client.rs
index 84042b9..7c3f08b 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -617,6 +617,7 @@ where
             self.registry.clone(),
             options,
             self.state.clone(),
+            self.default_max_attempts,
         )
         .await)
     }

diff --git a/src/context.rs b/src/context.rs
index 3bb942d..2e43237 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -72,6 +72,9 @@ where

     /// Task registry for validating spawn_by_name calls.
     registry: Arc<RwLock<TaskRegistry<State>>>,
+
+    /// Default max attempts for subtasks spawned via spawn_by_name.
+    default_max_attempts: u32,
 }

 /// Validate that a user-provided step name doesn't use reserved prefix.
@@ -98,6 +101,7 @@ where
         lease_extender: LeaseExtender,
         registry: Arc<RwLock<TaskRegistry<State>>>,
         state: State,
+        default_max_attempts: u32,
     ) -> Result<Self, TaskError> {
         // Load all checkpoints for this task into cache
         let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -127,6 +131,7 @@
             lease_extender,
             registry,
             state,
+            default_max_attempts,
         })
     }

@@ -668,6 +673,12 @@ where
             }
         }

+        // Apply default max_attempts if not set
+        let options = SpawnOptions {
+            max_attempts: Some(options.max_attempts.unwrap_or(self.default_max_attempts)),
+            ..options
+        };
+
         // Build options JSON, merging user options with parent_task_id
         #[derive(Serialize)]
         struct SubtaskOptions<'a> {
@@ -844,6 +855,7 @@ mod tests {
             LeaseExtender::dummy_for_tests(),
             Arc::new(RwLock::new(TaskRegistry::new())),
             (),
+            5, // default_max_attempts
         )
         .await
         .unwrap();

diff --git a/src/worker.rs b/src/worker.rs
index 8714715..f286ff0 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -67,6 +67,7 @@ impl Worker {
         registry: Arc<RwLock<TaskRegistry<State>>>,
         options: WorkerOptions,
         state: State,
+        default_max_attempts: u32,
     ) -> Self
     where
         State: Clone + Send + Sync + 'static,
@@ -92,6 +93,7 @@ impl Worker {
             worker_id,
             shutdown_rx,
             state,
+            default_max_attempts,
         ));

         Self {
@@ -117,6 +119,7 @@ impl Worker {
         worker_id: String,
         mut shutdown_rx: broadcast::Receiver<()>,
         state: State,
+        default_max_attempts: u32,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -200,6 +203,7 @@ impl Worker {
                         claim_timeout,
                         fatal_on_lease_timeout,
                         state,
+                        default_max_attempts,
                     ).await;

                     drop(permit);
@@ -266,6 +270,7 @@ impl Worker {
         claim_timeout: Duration,
         fatal_on_lease_timeout: bool,
         state: State,
+        default_max_attempts: u32,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -295,6 +300,7 @@ impl Worker {
             claim_timeout,
             fatal_on_lease_timeout,
             state,
+            default_max_attempts,
         )
         .instrument(span)
         .await
@@ -308,6 +314,7 @@ impl Worker {
         claim_timeout: Duration,
         fatal_on_lease_timeout: bool,
         state: State,
+        default_max_attempts: u32,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -333,6 +340,7 @@ impl Worker {
             lease_extender,
             registry.clone(),
             state.clone(),
+            default_max_attempts,
         )
         .await
         {

diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs
index 4a41267..5a75d93 100644
--- a/tests/fanout_test.rs
+++ b/tests/fanout_test.rs
@@ -453,9 +453,38 @@ async fn test_cascade_cancel_when_parent_auto_cancelled_by_max_duration(
 // spawn_by_name Tests
 // ============================================================================

+/// Helper to query max_attempts from the database.
+async fn get_task_max_attempts(
+    pool: &PgPool,
+    queue_name: &str,
+    task_id: uuid::Uuid,
+) -> Option<i32> {
+    #[derive(sqlx::FromRow)]
+    struct TaskMaxAttempts {
+        max_attempts: Option<i32>,
+    }
+    let query = AssertSqlSafe(format!(
+        "SELECT max_attempts FROM durable.t_{queue_name} WHERE task_id = $1"
+    ));
+    let result: TaskMaxAttempts = sqlx::query_as(query)
+        .bind(task_id)
+        .fetch_one(pool)
+        .await
+        .expect("Failed to query task max_attempts");
+    result.max_attempts
+}
+
 #[sqlx::test(migrator = "MIGRATOR")]
 async fn test_spawn_by_name_from_task_context(pool: PgPool) -> sqlx::Result<()> {
-    let client = create_client(pool.clone(), "fanout_by_name").await;
+    // Use custom default_max_attempts to verify subtasks inherit it
+    let client = Durable::builder()
+        .pool(pool.clone())
+        .queue_name("fanout_by_name")
+        .default_max_attempts(7) // Custom default to verify inheritance
+        .build()
+        .await
+        .expect("Failed to create client");
+
     client.create_queue(None).await.unwrap();
     client.register::().await.unwrap();
     client.register::().await.unwrap();
@@ -497,6 +526,24 @@ async fn test_spawn_by_name_from_task_context(pool: PgPool) -> sqlx::Result<()>
         "Child should have doubled 21 to 42 (spawned via spawn_by_name)"
     );

+    // Find the child task and verify it inherited the default_max_attempts
+    let child_query = "SELECT task_id FROM durable.t_fanout_by_name WHERE parent_task_id = $1";
+    let child_ids: Vec<(uuid::Uuid,)> = sqlx::query_as(child_query)
+        .bind(spawn_result.task_id)
+        .fetch_all(&pool)
+        .await?;
+
+    assert_eq!(child_ids.len(), 1, "Should have exactly one child task");
+    let child_task_id = child_ids[0].0;
+
+    // Verify child task has the default max_attempts from the client config
+    let child_max_attempts = get_task_max_attempts(&pool, "fanout_by_name", child_task_id).await;
+    assert_eq!(
+        child_max_attempts,
+        Some(7),
+        "Child task spawned via spawn_by_name should inherit default_max_attempts=7"
+    );
+
     Ok(())
 }

@@ -728,9 +775,11 @@ async fn test_join_timeout_when_parent_claim_expires(pool: PgPool) -> sqlx::Resu

     let error_name = failed_payload.get("name").and_then(|v| v.as_str());
     // Could be Timeout or other error depending on how the timeout manifests
+    // ChildFailed is also valid when child tasks have bounded max_attempts
     assert!(
         error_name == Some("Timeout")
             || error_name == Some("ChildCancelled")
+            || error_name == Some("ChildFailed")
             || error_name == Some("TaskInternal"),
         "Expected timeout-related error, got: {:?}",
         error_name

From 5547d42766098906fd73ea4a100135bf6d6d89ff Mon Sep 17 00:00:00 2001
From: Viraj Mehta
Date: Thu, 1 Jan 2026 17:54:32 -0500
Subject: [PATCH 2/4] add full defaults for task spawning settings

---
 src/client.rs        | 32 ++++++++++++++++++++++++++++
 src/context.rs       | 26 ++++++++++++++++++++---
 src/worker.rs        | 20 +++++++++++++++++-
 tests/fanout_test.rs | 61 +++++++++++++++++++++++++++++++++++++++++---
 4 files changed, 131 insertions(+), 8 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index 7c3f08b..fef15c5 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -111,6 +111,8 @@ where
     owns_pool: bool,
     queue_name: String,
     default_max_attempts: u32,
+    default_retry_strategy: Option<RetryStrategy>,
+    default_cancellation: Option<CancellationPolicy>,
     registry: Arc<RwLock<TaskRegistry<State>>>,
     state: State,
 }
@@ -139,6 +141,8 @@ pub struct DurableBuilder {
     pool: Option<PgPool>,
     queue_name: String,
     default_max_attempts: u32,
+    default_retry_strategy: Option<RetryStrategy>,
+    default_cancellation: Option<CancellationPolicy>,
 }

 impl DurableBuilder {
@@ -148,6 +152,8 @@
             pool: None,
             queue_name: "default".to_string(),
             default_max_attempts: 5,
+            default_retry_strategy: None,
+            default_cancellation: None,
         }
     }

@@ -175,6 +181,18 @@ impl DurableBuilder {
         self
     }

+    /// Set default retry strategy for spawned tasks (default: Fixed with 5s delay)
+    pub fn default_retry_strategy(mut self, strategy: RetryStrategy) -> Self {
+        self.default_retry_strategy = Some(strategy);
+        self
+    }
+
+    /// Set default cancellation policy for spawned tasks (default: no auto-cancellation)
+    pub fn default_cancellation(mut self, policy: CancellationPolicy) -> Self {
+        self.default_cancellation = Some(policy);
+        self
+    }
+
     /// Build the Durable client without application state.
     ///
     /// Use this when your tasks don't need access to shared resources
@@ -227,6 +245,8 @@ impl DurableBuilder {
             owns_pool,
             queue_name: self.queue_name,
             default_max_attempts: self.default_max_attempts,
+            default_retry_strategy: self.default_retry_strategy,
+            default_cancellation: self.default_cancellation,
             registry: Arc::new(RwLock::new(HashMap::new())),
             state,
         })
@@ -439,7 +459,17 @@ where
         #[cfg(feature = "telemetry")]
         tracing::Span::current().record("queue", &self.queue_name);

+        // Apply defaults if not set
         let max_attempts = options.max_attempts.unwrap_or(self.default_max_attempts);
+        let options = SpawnOptions {
+            retry_strategy: options
+                .retry_strategy
+                .or_else(|| self.default_retry_strategy.clone()),
+            cancellation: options
+                .cancellation
+                .or_else(|| self.default_cancellation.clone()),
+            ..options
+        };

         let db_options = Self::serialize_spawn_options(&options, max_attempts)?;

@@ -618,6 +648,8 @@ where
             options,
             self.state.clone(),
             self.default_max_attempts,
+            self.default_retry_strategy.clone(),
+            self.default_cancellation.clone(),
         )
         .await)
     }

diff --git a/src/context.rs b/src/context.rs
index 2e43237..cdf163d 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -11,8 +11,8 @@ use uuid::Uuid;
 use crate::error::{ControlFlow, TaskError, TaskResult};
 use crate::task::{Task, TaskRegistry};
 use crate::types::{
-    AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
-    SpawnResultRow, TaskHandle,
+    AwaitEventResult, CancellationPolicy, CheckpointRow, ChildCompletePayload, ChildStatus,
+    ClaimedTask, RetryStrategy, SpawnOptions, SpawnResultRow, TaskHandle,
 };
 use crate::worker::LeaseExtender;

@@ -75,6 +75,12 @@ where

     /// Default max attempts for subtasks spawned via spawn_by_name.
     default_max_attempts: u32,
+
+    /// Default retry strategy for subtasks spawned via spawn_by_name.
+    default_retry_strategy: Option<RetryStrategy>,
+
+    /// Default cancellation policy for subtasks spawned via spawn_by_name.
+    default_cancellation: Option<CancellationPolicy>,
 }

 /// Validate that a user-provided step name doesn't use reserved prefix.
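The precedence rule these default fields implement is worth stating explicitly: a field set on SpawnOptions always wins, and a client-level default only fills a field left as None. A minimal standalone sketch of that fall-through, with the option types simplified to u32 and String rather than the crate's real RetryStrategy and CancellationPolicy:

    // Sketch of the defaults fall-through used by spawn/spawn_by_name.
    // Types are simplified stand-ins, not the crate's SpawnOptions.
    fn apply_defaults(
        max_attempts: Option<u32>,
        retry: Option<String>,
        default_max: u32,
        default_retry: Option<String>,
    ) -> (u32, Option<String>) {
        (
            // An explicit value always wins; a default only fills a None.
            max_attempts.unwrap_or(default_max),
            retry.or(default_retry),
        )
    }

    fn main() {
        // max_attempts unset -> inherits the default (7);
        // retry set -> kept, the default is ignored.
        let (attempts, retry) =
            apply_defaults(None, Some("fixed".to_string()), 7, Some("exponential".to_string()));
        assert_eq!(attempts, 7);
        assert_eq!(retry, Some("fixed".to_string()));
    }
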
@@ -102,6 +108,8 @@ where
         registry: Arc<RwLock<TaskRegistry<State>>>,
         state: State,
         default_max_attempts: u32,
+        default_retry_strategy: Option<RetryStrategy>,
+        default_cancellation: Option<CancellationPolicy>,
     ) -> Result<Self, TaskError> {
         // Load all checkpoints for this task into cache
         let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -132,6 +140,8 @@ where
             registry,
             state,
             default_max_attempts,
+            default_retry_strategy,
+            default_cancellation,
         })
     }

@@ -673,9 +683,15 @@ where
             }
         }

-        // Apply default max_attempts if not set
+        // Apply defaults if not set
         let options = SpawnOptions {
             max_attempts: Some(options.max_attempts.unwrap_or(self.default_max_attempts)),
+            retry_strategy: options
+                .retry_strategy
+                .or_else(|| self.default_retry_strategy.clone()),
+            cancellation: options
+                .cancellation
+                .or_else(|| self.default_cancellation.clone()),
             ..options
         };

@@ -855,7 +871,9 @@ mod tests {
             LeaseExtender::dummy_for_tests(),
             Arc::new(RwLock::new(TaskRegistry::new())),
             (),
-            5, // default_max_attempts
+            5,    // default_max_attempts
+            None, // default_retry_strategy
+            None, // default_cancellation
         )
         .await
         .unwrap();

diff --git a/src/worker.rs b/src/worker.rs
index f286ff0..6011c31 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -11,7 +11,7 @@ use uuid::Uuid;
 use crate::context::TaskContext;
 use crate::error::{ControlFlow, TaskError, serialize_task_error};
 use crate::task::TaskRegistry;
-use crate::types::{ClaimedTask, ClaimedTaskRow, WorkerOptions};
+use crate::types::{CancellationPolicy, ClaimedTask, ClaimedTaskRow, RetryStrategy, WorkerOptions};

 /// Notifies the worker that the lease has been extended.
 /// Used by TaskContext to reset warning/fatal timers.
@@ -68,6 +68,8 @@ impl Worker {
         options: WorkerOptions,
         state: State,
         default_max_attempts: u32,
+        default_retry_strategy: Option<RetryStrategy>,
+        default_cancellation: Option<CancellationPolicy>,
     ) -> Self
     where
         State: Clone + Send + Sync + 'static,
@@ -94,6 +96,8 @@ impl Worker {
             shutdown_rx,
             state,
             default_max_attempts,
+            default_retry_strategy,
+            default_cancellation,
         ));

         Self {
@@ -120,6 +124,8 @@ impl Worker {
         mut shutdown_rx: broadcast::Receiver<()>,
         state: State,
         default_max_attempts: u32,
+        default_retry_strategy: Option<RetryStrategy>,
+        default_cancellation: Option<CancellationPolicy>,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -193,6 +199,8 @@ impl Worker {
                 let registry = registry.clone();
                 let done_tx = done_tx.clone();
                 let state = state.clone();
+                let default_retry_strategy = default_retry_strategy.clone();
+                let default_cancellation = default_cancellation.clone();

                 tokio::spawn(async move {
                     Self::execute_task(
@@ -204,6 +212,8 @@ impl Worker {
                         fatal_on_lease_timeout,
                         state,
                         default_max_attempts,
+                        default_retry_strategy,
+                        default_cancellation,
                     ).await;

                     drop(permit);
@@ -271,6 +281,8 @@ impl Worker {
         fatal_on_lease_timeout: bool,
         state: State,
         default_max_attempts: u32,
+        default_retry_strategy: Option<RetryStrategy>,
+        default_cancellation: Option<CancellationPolicy>,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -301,6 +313,8 @@ impl Worker {
             fatal_on_lease_timeout,
             state,
             default_max_attempts,
+            default_retry_strategy,
+            default_cancellation,
         )
         .instrument(span)
         .await
@@ -315,6 +329,8 @@ impl Worker {
         fatal_on_lease_timeout: bool,
         state: State,
         default_max_attempts: u32,
+        default_retry_strategy: Option<RetryStrategy>,
+        default_cancellation: Option<CancellationPolicy>,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -341,6 +357,8 @@ impl Worker {
             registry.clone(),
             state.clone(),
             default_max_attempts,
+            default_retry_strategy,
+            default_cancellation,
         )
         .await
         {

diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs
index 5a75d93..9247516 100644
--- a/tests/fanout_test.rs
+++ b/tests/fanout_test.rs
@@ -8,7 +8,8 @@ use
 common::tasks::{
     SpawnByNameParams, SpawnByNameTask, SpawnFailingChildTask, SpawnSlowChildParams,
     SpawnSlowChildTask,
 };
-use durable::{Durable, MIGRATOR, WorkerOptions};
+use durable::{CancellationPolicy, Durable, MIGRATOR, RetryStrategy, WorkerOptions};
+use serde_json::Value as JsonValue;
 use sqlx::{AssertSqlSafe, PgPool};
 use std::time::Duration;

@@ -474,13 +475,43 @@ async fn get_task_max_attempts(
     result.max_attempts
 }

+/// Helper to query retry_strategy from the database.
+async fn get_task_retry_strategy(
+    pool: &PgPool,
+    queue_name: &str,
+    task_id: uuid::Uuid,
+) -> Option<JsonValue> {
+    #[derive(sqlx::FromRow)]
+    struct TaskRetryStrategy {
+        retry_strategy: Option<JsonValue>,
+    }
+    let query = AssertSqlSafe(format!(
+        "SELECT retry_strategy FROM durable.t_{queue_name} WHERE task_id = $1"
+    ));
+    let result: TaskRetryStrategy = sqlx::query_as(query)
+        .bind(task_id)
+        .fetch_one(pool)
+        .await
+        .expect("Failed to query task retry_strategy");
+    result.retry_strategy
+}
+
 #[sqlx::test(migrator = "MIGRATOR")]
 async fn test_spawn_by_name_from_task_context(pool: PgPool) -> sqlx::Result<()> {
-    // Use custom default_max_attempts to verify subtasks inherit it
+    // Use custom defaults to verify subtasks inherit them
     let client = Durable::builder()
         .pool(pool.clone())
         .queue_name("fanout_by_name")
-        .default_max_attempts(7) // Custom default to verify inheritance
+        .default_max_attempts(7)
+        .default_retry_strategy(RetryStrategy::Exponential {
+            base_delay: Duration::from_secs(10),
+            factor: 3.0,
+            max_backoff: Duration::from_secs(600),
+        })
+        .default_cancellation(CancellationPolicy {
+            max_pending_time: Some(Duration::from_secs(3600)),
+            max_running_time: None,
+        })
         .build()
         .await
         .expect("Failed to create client");
@@ -544,6 +575,30 @@
         "Child task spawned via spawn_by_name should inherit default_max_attempts=7"
     );

+    // Verify child task has the default retry_strategy from the client config
+    let child_retry_strategy =
+        get_task_retry_strategy(&pool, "fanout_by_name", child_task_id).await;
+    assert!(
+        child_retry_strategy.is_some(),
+        "Child task should have a retry_strategy"
+    );
+    let strategy = child_retry_strategy.unwrap();
+    assert_eq!(
+        strategy.get("kind").and_then(|v| v.as_str()),
+        Some("exponential"),
+        "Child task should inherit exponential retry strategy"
+    );
+    assert_eq!(
+        strategy.get("base_seconds").and_then(|v| v.as_u64()),
+        Some(10),
+        "Child task should inherit base_delay=10s"
+    );
+    assert_eq!(
+        strategy.get("factor").and_then(|v| v.as_f64()),
+        Some(3.0),
+        "Child task should inherit factor=3.0"
+    );
+
     Ok(())
 }

From 2a67bcb9f1142d312a6b29422b4564d7c6b6c2ff Mon Sep 17 00:00:00 2001
From: Viraj Mehta
Date: Thu, 1 Jan 2026 20:51:09 -0500
Subject: [PATCH 3/4] add a SpawnDefaults struct

---
 src/client.rs  | 53 +++++++++++++++++++++++++++++---------------------
 src/context.rs | 41 ++++++++++++++++++--------------------
 src/lib.rs     |  4 ++--
 src/types.rs   | 15 ++++++++++++++
 src/worker.rs  | 40 +++++++++++++------------------------
 5 files changed, 80 insertions(+), 73 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index fef15c5..b988e8b 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -11,7 +11,8 @@ use uuid::Uuid;
 use crate::error::{DurableError, DurableResult};
 use crate::task::{Task, TaskRegistry};
 use crate::types::{
-    CancellationPolicy, RetryStrategy, SpawnOptions, SpawnResult, SpawnResultRow, WorkerOptions,
+    CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult, SpawnResultRow,
+    WorkerOptions,
 };

 /// Internal struct for serializing spawn options to the database.
@@ -110,9 +111,7 @@ where
     pool: PgPool,
     owns_pool: bool,
     queue_name: String,
-    default_max_attempts: u32,
-    default_retry_strategy: Option<RetryStrategy>,
-    default_cancellation: Option<CancellationPolicy>,
+    spawn_defaults: SpawnDefaults,
     registry: Arc<RwLock<TaskRegistry<State>>>,
     state: State,
 }
@@ -122,11 +121,23 @@ where
 /// # Example
 ///
 /// ```ignore
+/// use std::time::Duration;
+/// use durable::{Durable, RetryStrategy, CancellationPolicy};
+///
 /// // Without state
 /// let client = Durable::builder()
 ///     .database_url("postgres://localhost/myapp")
 ///     .queue_name("orders")
 ///     .default_max_attempts(3)
+///     .default_retry_strategy(RetryStrategy::Exponential {
+///         base_delay: Duration::from_secs(5),
+///         factor: 2.0,
+///         max_backoff: Duration::from_secs(300),
+///     })
+///     .default_cancellation(CancellationPolicy {
+///         max_pending_time: Some(Duration::from_secs(3600)),
+///         max_running_time: None,
+///     })
 ///     .build()
 ///     .await?;
 ///
@@ -140,9 +151,7 @@ pub struct DurableBuilder {
     database_url: Option<String>,
     pool: Option<PgPool>,
     queue_name: String,
-    default_max_attempts: u32,
-    default_retry_strategy: Option<RetryStrategy>,
-    default_cancellation: Option<CancellationPolicy>,
+    spawn_defaults: SpawnDefaults,
 }

 impl DurableBuilder {
@@ -151,9 +160,11 @@ impl DurableBuilder {
             database_url: None,
             pool: None,
             queue_name: "default".to_string(),
-            default_max_attempts: 5,
-            default_retry_strategy: None,
-            default_cancellation: None,
+            spawn_defaults: SpawnDefaults {
+                max_attempts: 5,
+                retry_strategy: None,
+                cancellation: None,
+            },
         }
     }

@@ -177,19 +188,19 @@ impl DurableBuilder {

     /// Set default max attempts for spawned tasks (default: 5)
     pub fn default_max_attempts(mut self, attempts: u32) -> Self {
-        self.default_max_attempts = attempts;
+        self.spawn_defaults.max_attempts = attempts;
         self
     }

     /// Set default retry strategy for spawned tasks (default: Fixed with 5s delay)
     pub fn default_retry_strategy(mut self, strategy: RetryStrategy) -> Self {
-        self.default_retry_strategy = Some(strategy);
+        self.spawn_defaults.retry_strategy = Some(strategy);
         self
     }

     /// Set default cancellation policy for spawned tasks (default: no auto-cancellation)
     pub fn default_cancellation(mut self, policy: CancellationPolicy) -> Self {
-        self.default_cancellation = Some(policy);
+        self.spawn_defaults.cancellation = Some(policy);
         self
     }

@@ -244,9 +255,7 @@ impl DurableBuilder {
             pool,
             owns_pool,
             queue_name: self.queue_name,
-            default_max_attempts: self.default_max_attempts,
-            default_retry_strategy: self.default_retry_strategy,
-            default_cancellation: self.default_cancellation,
+            spawn_defaults: self.spawn_defaults,
             registry: Arc::new(RwLock::new(HashMap::new())),
             state,
         })
@@ -460,14 +469,16 @@ where
         tracing::Span::current().record("queue", &self.queue_name);

         // Apply defaults if not set
-        let max_attempts = options.max_attempts.unwrap_or(self.default_max_attempts);
+        let max_attempts = options
+            .max_attempts
+            .unwrap_or(self.spawn_defaults.max_attempts);
         let options = SpawnOptions {
             retry_strategy: options
                 .retry_strategy
-                .or_else(|| self.default_retry_strategy.clone()),
+                .or_else(|| self.spawn_defaults.retry_strategy.clone()),
             cancellation: options
                 .cancellation
-                .or_else(|| self.default_cancellation.clone()),
+                .or_else(|| self.spawn_defaults.cancellation.clone()),
             ..options
         };

@@ -647,9 +658,7 @@ where
             self.registry.clone(),
             options,
             self.state.clone(),
-            self.default_max_attempts,
-            self.default_retry_strategy.clone(),
-            self.default_cancellation.clone(),
+            self.spawn_defaults.clone(),
         )
         .await)
     }
diff --git a/src/context.rs b/src/context.rs
index cdf163d..825944b 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -11,8 +11,8 @@ use uuid::Uuid;
 use crate::error::{ControlFlow, TaskError, TaskResult};
 use crate::task::{Task, TaskRegistry};
 use crate::types::{
-    AwaitEventResult, CancellationPolicy, CheckpointRow, ChildCompletePayload, ChildStatus,
-    ClaimedTask, RetryStrategy, SpawnOptions, SpawnResultRow, TaskHandle,
+    AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnDefaults,
+    SpawnOptions, SpawnResultRow, TaskHandle,
 };
 use crate::worker::LeaseExtender;

@@ -73,14 +73,8 @@ where
     /// Task registry for validating spawn_by_name calls.
     registry: Arc<RwLock<TaskRegistry<State>>>,

-    /// Default max attempts for subtasks spawned via spawn_by_name.
-    default_max_attempts: u32,
-
-    /// Default retry strategy for subtasks spawned via spawn_by_name.
-    default_retry_strategy: Option<RetryStrategy>,
-
-    /// Default cancellation policy for subtasks spawned via spawn_by_name.
-    default_cancellation: Option<CancellationPolicy>,
+    /// Default settings for subtasks spawned via spawn/spawn_by_name.
+    spawn_defaults: SpawnDefaults,
 }

 /// Validate that a user-provided step name doesn't use reserved prefix.
@@ -99,6 +93,7 @@ where
 {
     /// Create a new TaskContext. Called by the worker before executing a task.
     /// Loads all existing checkpoints into the cache.
+    #[allow(clippy::too_many_arguments)]
     pub(crate) async fn create(
         pool: PgPool,
         queue_name: String,
@@ -107,9 +102,7 @@ where
         lease_extender: LeaseExtender,
         registry: Arc<RwLock<TaskRegistry<State>>>,
         state: State,
-        default_max_attempts: u32,
-        default_retry_strategy: Option<RetryStrategy>,
-        default_cancellation: Option<CancellationPolicy>,
+        spawn_defaults: SpawnDefaults,
     ) -> Result<Self, TaskError> {
         // Load all checkpoints for this task into cache
         let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -139,9 +132,7 @@ where
             lease_extender,
             registry,
             state,
-            default_max_attempts,
-            default_retry_strategy,
-            default_cancellation,
+            spawn_defaults,
         })
     }

@@ -685,13 +676,17 @@ where

         // Apply defaults if not set
         let options = SpawnOptions {
-            max_attempts: Some(options.max_attempts.unwrap_or(self.default_max_attempts)),
+            max_attempts: Some(
+                options
+                    .max_attempts
+                    .unwrap_or(self.spawn_defaults.max_attempts),
+            ),
             retry_strategy: options
                 .retry_strategy
-                .or_else(|| self.default_retry_strategy.clone()),
+                .or_else(|| self.spawn_defaults.retry_strategy.clone()),
             cancellation: options
                 .cancellation
-                .or_else(|| self.default_cancellation.clone()),
+                .or_else(|| self.spawn_defaults.cancellation.clone()),
             ..options
         };

@@ -871,9 +866,11 @@ mod tests {
             LeaseExtender::dummy_for_tests(),
             Arc::new(RwLock::new(TaskRegistry::new())),
             (),
-            5,    // default_max_attempts
-            None, // default_retry_strategy
-            None, // default_cancellation
+            SpawnDefaults {
+                max_attempts: 5,
+                retry_strategy: None,
+                cancellation: None,
+            },
         )
         .await
         .unwrap();

diff --git a/src/lib.rs b/src/lib.rs
index b56ad69..1c5d404 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -107,8 +107,8 @@ pub use context::TaskContext;
 pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
 pub use task::Task;
 pub use types::{
-    CancellationPolicy, ClaimedTask, RetryStrategy, SpawnOptions, SpawnResult, TaskHandle,
-    WorkerOptions,
+    CancellationPolicy, ClaimedTask, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult,
+    TaskHandle, WorkerOptions,
 };
 pub use worker::Worker;

diff --git a/src/types.rs b/src/types.rs
index 5861750..5319dcf 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -315,6 +315,21 @@ impl TaskHandle {
     }
 }

+/// Default settings for spawned tasks.
+///
+/// Groups the default `max_attempts`, `retry_strategy`, and `cancellation`
+/// settings that are applied when spawning tasks (either from the client
+/// or from within a task context).
+#[derive(Debug, Clone, Default)]
+pub struct SpawnDefaults {
+    /// Default max attempts for spawned tasks (default: 5)
+    pub max_attempts: u32,
+    /// Default retry strategy for spawned tasks
+    pub retry_strategy: Option<RetryStrategy>,
+    /// Default cancellation policy for spawned tasks
+    pub cancellation: Option<CancellationPolicy>,
+}
+
 /// Terminal status of a child task.
 ///
 /// This enum represents the possible terminal states a subtask can be in

diff --git a/src/worker.rs b/src/worker.rs
index 6011c31..e04bd9b 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -11,7 +11,7 @@ use uuid::Uuid;
 use crate::context::TaskContext;
 use crate::error::{ControlFlow, TaskError, serialize_task_error};
 use crate::task::TaskRegistry;
-use crate::types::{CancellationPolicy, ClaimedTask, ClaimedTaskRow, RetryStrategy, WorkerOptions};
+use crate::types::{ClaimedTask, ClaimedTaskRow, SpawnDefaults, WorkerOptions};

 /// Notifies the worker that the lease has been extended.
 /// Used by TaskContext to reset warning/fatal timers.
@@ -67,9 +67,7 @@ impl Worker {
         registry: Arc<RwLock<TaskRegistry<State>>>,
         options: WorkerOptions,
         state: State,
-        default_max_attempts: u32,
-        default_retry_strategy: Option<RetryStrategy>,
-        default_cancellation: Option<CancellationPolicy>,
+        spawn_defaults: SpawnDefaults,
     ) -> Self
     where
         State: Clone + Send + Sync + 'static,
@@ -95,9 +93,7 @@ impl Worker {
             worker_id,
             shutdown_rx,
             state,
-            default_max_attempts,
-            default_retry_strategy,
-            default_cancellation,
+            spawn_defaults,
         ));

         Self {
@@ -115,6 +111,7 @@ impl Worker {
         let _ = self.handle.await;
     }

+    #[allow(clippy::too_many_arguments)]
     async fn run_loop(
         pool: PgPool,
         queue_name: String,
@@ -123,9 +120,7 @@ impl Worker {
         worker_id: String,
         mut shutdown_rx: broadcast::Receiver<()>,
         state: State,
-        default_max_attempts: u32,
-        default_retry_strategy: Option<RetryStrategy>,
-        default_cancellation: Option<CancellationPolicy>,
+        spawn_defaults: SpawnDefaults,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -199,8 +194,7 @@ impl Worker {
                 let registry = registry.clone();
                 let done_tx = done_tx.clone();
                 let state = state.clone();
-                let default_retry_strategy = default_retry_strategy.clone();
-                let default_cancellation = default_cancellation.clone();
+                let spawn_defaults = spawn_defaults.clone();

                 tokio::spawn(async move {
                     Self::execute_task(
@@ -211,9 +205,7 @@ impl Worker {
                         claim_timeout,
                         fatal_on_lease_timeout,
                         state,
-                        default_max_attempts,
-                        default_retry_strategy,
-                        default_cancellation,
+                        spawn_defaults,
                     ).await;

                     drop(permit);
@@ -272,6 +264,7 @@ impl Worker {
         Ok(tasks)
     }

+    #[allow(clippy::too_many_arguments)]
     async fn execute_task(
         pool: PgPool,
         queue_name: String,
@@ -280,9 +273,7 @@ impl Worker {
         claim_timeout: Duration,
         fatal_on_lease_timeout: bool,
         state: State,
-        default_max_attempts: u32,
-        default_retry_strategy: Option<RetryStrategy>,
-        default_cancellation: Option<CancellationPolicy>,
+        spawn_defaults: SpawnDefaults,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -312,14 +303,13 @@ impl Worker {
             claim_timeout,
             fatal_on_lease_timeout,
             state,
-            default_max_attempts,
-            default_retry_strategy,
-            default_cancellation,
+            spawn_defaults,
         )
         .instrument(span)
         .await
     }

+    #[allow(clippy::too_many_arguments)]
     async fn execute_task_inner(
         pool: PgPool,
         queue_name: String,
@@ -328,9 +318,7 @@ impl Worker {
         claim_timeout: Duration,
         fatal_on_lease_timeout: bool,
         state: State,
-        default_max_attempts: u32,
-        default_retry_strategy: Option<RetryStrategy>,
-        default_cancellation: Option<CancellationPolicy>,
+        spawn_defaults: SpawnDefaults,
     ) where
         State: Clone + Send + Sync + 'static,
     {
@@ -356,9 +344,7 @@ impl Worker {
             lease_extender,
             registry.clone(),
             state.clone(),
-            default_max_attempts,
-            default_retry_strategy,
-            default_cancellation,
+            spawn_defaults,
         )
         .await
         {

From eb6eab385faa021a0afc2e0cdf518931c00fe5c2 Mon Sep 17 00:00:00 2001
From: Viraj Mehta
Date: Thu, 1 Jan 2026 21:12:08 -0500
Subject: [PATCH 4/4] deflake

---
 tests/crash_test.rs | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/tests/crash_test.rs b/tests/crash_test.rs
index bedca5d..b6eef1c 100644
--- a/tests/crash_test.rs
+++ b/tests/crash_test.rs
@@ -1,7 +1,7 @@
 #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]

 use sqlx::{AssertSqlSafe, PgPool};
-use std::time::Duration;
+use std::time::{Duration, Instant};

 mod common;

@@ -538,14 +538,33 @@ async fn test_slow_task_outlives_lease(pool: PgPool) -> sqlx::Result<()> {
     // Wait for real time to pass the lease timeout
     tokio::time::sleep(claim_timeout + Duration::from_secs(2)).await;

-    // Verify a new run was created (reclaim happened)
-    let run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
+    // Second worker polls to reclaim the expired lease.
+    let worker2 = client
+        .start_worker(WorkerOptions {
+            poll_interval: Duration::from_millis(50),
+            claim_timeout: Duration::from_secs(10),
+            ..Default::default()
+        })
+        .await
+        .unwrap();
+
+    // Verify a new run was created (reclaim happened), with bounded polling.
+    let deadline = Instant::now() + Duration::from_secs(5);
+    let mut run_count = 0;
+    while Instant::now() < deadline {
+        run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
+        if run_count >= 2 {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(50)).await;
+    }
     assert!(
         run_count >= 2,
         "Should have at least 2 runs after lease expiration, got {}",
         run_count
     );

+    worker2.shutdown().await;
     worker.shutdown().await;

     Ok(())
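The deflake pattern above, polling a condition on a short interval and failing only after a hard deadline instead of asserting once after a fixed sleep, generalizes to other timing-sensitive tests. A sketch of a reusable helper under the same tokio runtime the tests already use; wait_until is an illustrative name, not part of the crate:

    use std::time::{Duration, Instant};

    /// Re-run `check` every `every` until it returns true or `timeout` elapses.
    async fn wait_until<F, Fut>(mut check: F, timeout: Duration, every: Duration) -> bool
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if check().await {
                return true;
            }
            tokio::time::sleep(every).await;
        }
        false
    }

    #[tokio::main]
    async fn main() {
        let start = Instant::now();
        // Condition becomes true after ~200ms; the helper polls until then.
        let ok = wait_until(
            || async move { start.elapsed() > Duration::from_millis(200) },
            Duration::from_secs(5),
            Duration::from_millis(50),
        )
        .await;
        assert!(ok);
    }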