Skip to content

Commit ca986b6

Browse files
authored
Allow fully configurable defaults for tool spawn settings (#44)
* propagate default num attempts to spawned tasks
* add full defaults for task spawning setting
* added a SpawnDefaults struct
* deflake
1 parent 161e9e5 commit ca986b6

File tree

7 files changed

+236
-17
lines changed

7 files changed

+236
-17
lines changed

src/client.rs

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ use uuid::Uuid;
1111
use crate::error::{DurableError, DurableResult};
1212
use crate::task::{Task, TaskRegistry};
1313
use crate::types::{
14-
CancellationPolicy, RetryStrategy, SpawnOptions, SpawnResult, SpawnResultRow, WorkerOptions,
14+
CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult, SpawnResultRow,
15+
WorkerOptions,
1516
};
1617

1718
/// Internal struct for serializing spawn options to the database.
@@ -110,7 +111,7 @@ where
110111
pool: PgPool,
111112
owns_pool: bool,
112113
queue_name: String,
113-
default_max_attempts: u32,
114+
spawn_defaults: SpawnDefaults,
114115
registry: Arc<RwLock<TaskRegistry<State>>>,
115116
state: State,
116117
}
@@ -120,11 +121,23 @@ where
120121
/// # Example
121122
///
122123
/// ```ignore
124+
/// use std::time::Duration;
125+
/// use durable::{Durable, RetryStrategy, CancellationPolicy};
126+
///
123127
/// // Without state
124128
/// let client = Durable::builder()
125129
/// .database_url("postgres://localhost/myapp")
126130
/// .queue_name("orders")
127131
/// .default_max_attempts(3)
132+
/// .default_retry_strategy(RetryStrategy::Exponential {
133+
/// base_delay: Duration::from_secs(5),
134+
/// factor: 2.0,
135+
/// max_backoff: Duration::from_secs(300),
136+
/// })
137+
/// .default_cancellation(CancellationPolicy {
138+
/// max_pending_time: Some(Duration::from_secs(3600)),
139+
/// max_running_time: None,
140+
/// })
128141
/// .build()
129142
/// .await?;
130143
///
@@ -138,7 +151,7 @@ pub struct DurableBuilder {
138151
database_url: Option<String>,
139152
pool: Option<PgPool>,
140153
queue_name: String,
141-
default_max_attempts: u32,
154+
spawn_defaults: SpawnDefaults,
142155
}
143156

144157
impl DurableBuilder {
@@ -147,7 +160,11 @@ impl DurableBuilder {
147160
database_url: None,
148161
pool: None,
149162
queue_name: "default".to_string(),
150-
default_max_attempts: 5,
163+
spawn_defaults: SpawnDefaults {
164+
max_attempts: 5,
165+
retry_strategy: None,
166+
cancellation: None,
167+
},
151168
}
152169
}
153170

@@ -171,7 +188,19 @@ impl DurableBuilder {
171188

172189
/// Set default max attempts for spawned tasks (default: 5)
173190
pub fn default_max_attempts(mut self, attempts: u32) -> Self {
174-
self.default_max_attempts = attempts;
191+
self.spawn_defaults.max_attempts = attempts;
192+
self
193+
}
194+
195+
/// Set default retry strategy for spawned tasks (default: Fixed with 5s delay)
196+
pub fn default_retry_strategy(mut self, strategy: RetryStrategy) -> Self {
197+
self.spawn_defaults.retry_strategy = Some(strategy);
198+
self
199+
}
200+
201+
/// Set default cancellation policy for spawned tasks (default: no auto-cancellation)
202+
pub fn default_cancellation(mut self, policy: CancellationPolicy) -> Self {
203+
self.spawn_defaults.cancellation = Some(policy);
175204
self
176205
}
177206

@@ -226,7 +255,7 @@ impl DurableBuilder {
226255
pool,
227256
owns_pool,
228257
queue_name: self.queue_name,
229-
default_max_attempts: self.default_max_attempts,
258+
spawn_defaults: self.spawn_defaults,
230259
registry: Arc::new(RwLock::new(HashMap::new())),
231260
state,
232261
})
@@ -471,7 +500,19 @@ where
471500
#[cfg(feature = "telemetry")]
472501
tracing::Span::current().record("queue", &self.queue_name);
473502

474-
let max_attempts = options.max_attempts.unwrap_or(self.default_max_attempts);
503+
// Apply defaults if not set
504+
let max_attempts = options
505+
.max_attempts
506+
.unwrap_or(self.spawn_defaults.max_attempts);
507+
let options = SpawnOptions {
508+
retry_strategy: options
509+
.retry_strategy
510+
.or_else(|| self.spawn_defaults.retry_strategy.clone()),
511+
cancellation: options
512+
.cancellation
513+
.or_else(|| self.spawn_defaults.cancellation.clone()),
514+
..options
515+
};
475516

476517
let db_options = Self::serialize_spawn_options(&options, max_attempts)?;
477518

@@ -649,6 +690,7 @@ where
649690
self.registry.clone(),
650691
options,
651692
self.state.clone(),
693+
self.spawn_defaults.clone(),
652694
)
653695
.await)
654696
}

src/context.rs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use uuid::Uuid;
1111
use crate::error::{ControlFlow, TaskError, TaskResult};
1212
use crate::task::{Task, TaskRegistry};
1313
use crate::types::{
14-
AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
15-
SpawnResultRow, TaskHandle,
14+
AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnDefaults,
15+
SpawnOptions, SpawnResultRow, TaskHandle,
1616
};
1717
use crate::worker::LeaseExtender;
1818

@@ -72,6 +72,9 @@ where
7272

7373
/// Task registry for validating spawn_by_name calls.
7474
registry: Arc<RwLock<TaskRegistry<State>>>,
75+
76+
/// Default settings for subtasks spawned via spawn/spawn_by_name.
77+
spawn_defaults: SpawnDefaults,
7578
}
7679

7780
/// Validate that a user-provided step name doesn't use reserved prefix.
@@ -90,6 +93,7 @@ where
9093
{
9194
/// Create a new TaskContext. Called by the worker before executing a task.
9295
/// Loads all existing checkpoints into the cache.
96+
#[allow(clippy::too_many_arguments)]
9397
pub(crate) async fn create(
9498
pool: PgPool,
9599
queue_name: String,
@@ -98,6 +102,7 @@ where
98102
lease_extender: LeaseExtender,
99103
registry: Arc<RwLock<TaskRegistry<State>>>,
100104
state: State,
105+
spawn_defaults: SpawnDefaults,
101106
) -> Result<Self, sqlx::Error> {
102107
// Load all checkpoints for this task into cache
103108
let checkpoints: Vec<CheckpointRow> = sqlx::query_as(
@@ -127,6 +132,7 @@ where
127132
lease_extender,
128133
registry,
129134
state,
135+
spawn_defaults,
130136
})
131137
}
132138

@@ -668,6 +674,22 @@ where
668674
}
669675
}
670676

677+
// Apply defaults if not set
678+
let options = SpawnOptions {
679+
max_attempts: Some(
680+
options
681+
.max_attempts
682+
.unwrap_or(self.spawn_defaults.max_attempts),
683+
),
684+
retry_strategy: options
685+
.retry_strategy
686+
.or_else(|| self.spawn_defaults.retry_strategy.clone()),
687+
cancellation: options
688+
.cancellation
689+
.or_else(|| self.spawn_defaults.cancellation.clone()),
690+
..options
691+
};
692+
671693
// Build options JSON, merging user options with parent_task_id
672694
#[derive(Serialize)]
673695
struct SubtaskOptions<'a> {
@@ -844,6 +866,11 @@ mod tests {
844866
LeaseExtender::dummy_for_tests(),
845867
Arc::new(RwLock::new(TaskRegistry::new())),
846868
(),
869+
SpawnDefaults {
870+
max_attempts: 5,
871+
retry_strategy: None,
872+
cancellation: None,
873+
},
847874
)
848875
.await
849876
.unwrap();

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ pub use context::TaskContext;
107107
pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
108108
pub use task::Task;
109109
pub use types::{
110-
CancellationPolicy, ClaimedTask, RetryStrategy, SpawnOptions, SpawnResult, TaskHandle,
111-
WorkerOptions,
110+
CancellationPolicy, ClaimedTask, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult,
111+
TaskHandle, WorkerOptions,
112112
};
113113
pub use worker::Worker;
114114

src/types.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,21 @@ impl<T> TaskHandle<T> {
315315
}
316316
}
317317

318+
/// Default settings for spawned tasks.
319+
///
320+
/// Groups the default `max_attempts`, `retry_strategy`, and `cancellation`
321+
/// settings that are applied when spawning tasks (either from the client
322+
/// or from within a task context).
323+
#[derive(Debug, Clone, Default)]
324+
pub struct SpawnDefaults {
325+
/// Default max attempts for spawned tasks. `DurableBuilder` initializes this to 5; the derived `Default` impl yields 0.
326+
pub max_attempts: u32,
327+
/// Default retry strategy for spawned tasks
328+
pub retry_strategy: Option<RetryStrategy>,
329+
/// Default cancellation policy for spawned tasks
330+
pub cancellation: Option<CancellationPolicy>,
331+
}
332+
318333
/// Terminal status of a child task.
319334
///
320335
/// This enum represents the possible terminal states a subtask can be in

src/worker.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use uuid::Uuid;
1111
use crate::context::TaskContext;
1212
use crate::error::{ControlFlow, TaskError, serialize_task_error};
1313
use crate::task::TaskRegistry;
14-
use crate::types::{ClaimedTask, ClaimedTaskRow, WorkerOptions};
14+
use crate::types::{ClaimedTask, ClaimedTaskRow, SpawnDefaults, WorkerOptions};
1515

1616
/// Notifies the worker that the lease has been extended.
1717
/// Used by TaskContext to reset warning/fatal timers.
@@ -67,6 +67,7 @@ impl Worker {
6767
registry: Arc<RwLock<TaskRegistry<State>>>,
6868
options: WorkerOptions,
6969
state: State,
70+
spawn_defaults: SpawnDefaults,
7071
) -> Self
7172
where
7273
State: Clone + Send + Sync + 'static,
@@ -92,6 +93,7 @@ impl Worker {
9293
worker_id,
9394
shutdown_rx,
9495
state,
96+
spawn_defaults,
9597
));
9698

9799
Self {
@@ -109,6 +111,7 @@ impl Worker {
109111
let _ = self.handle.await;
110112
}
111113

114+
#[allow(clippy::too_many_arguments)]
112115
async fn run_loop<State>(
113116
pool: PgPool,
114117
queue_name: String,
@@ -117,6 +120,7 @@ impl Worker {
117120
worker_id: String,
118121
mut shutdown_rx: broadcast::Receiver<()>,
119122
state: State,
123+
spawn_defaults: SpawnDefaults,
120124
) where
121125
State: Clone + Send + Sync + 'static,
122126
{
@@ -190,6 +194,7 @@ impl Worker {
190194
let registry = registry.clone();
191195
let done_tx = done_tx.clone();
192196
let state = state.clone();
197+
let spawn_defaults = spawn_defaults.clone();
193198

194199
tokio::spawn(async move {
195200
Self::execute_task(
@@ -200,6 +205,7 @@ impl Worker {
200205
claim_timeout,
201206
fatal_on_lease_timeout,
202207
state,
208+
spawn_defaults,
203209
).await;
204210

205211
drop(permit);
@@ -258,6 +264,7 @@ impl Worker {
258264
Ok(tasks)
259265
}
260266

267+
#[allow(clippy::too_many_arguments)]
261268
async fn execute_task<State>(
262269
pool: PgPool,
263270
queue_name: String,
@@ -266,6 +273,7 @@ impl Worker {
266273
claim_timeout: Duration,
267274
fatal_on_lease_timeout: bool,
268275
state: State,
276+
spawn_defaults: SpawnDefaults,
269277
) where
270278
State: Clone + Send + Sync + 'static,
271279
{
@@ -295,11 +303,13 @@ impl Worker {
295303
claim_timeout,
296304
fatal_on_lease_timeout,
297305
state,
306+
spawn_defaults,
298307
)
299308
.instrument(span)
300309
.await
301310
}
302311

312+
#[allow(clippy::too_many_arguments)]
303313
async fn execute_task_inner<State>(
304314
pool: PgPool,
305315
queue_name: String,
@@ -308,6 +318,7 @@ impl Worker {
308318
claim_timeout: Duration,
309319
fatal_on_lease_timeout: bool,
310320
state: State,
321+
spawn_defaults: SpawnDefaults,
311322
) where
312323
State: Clone + Send + Sync + 'static,
313324
{
@@ -333,6 +344,7 @@ impl Worker {
333344
lease_extender,
334345
registry.clone(),
335346
state.clone(),
347+
spawn_defaults,
336348
)
337349
.await
338350
{

tests/crash_test.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
22

33
use sqlx::{AssertSqlSafe, PgPool};
4-
use std::time::Duration;
4+
use std::time::{Duration, Instant};
55

66
mod common;
77

@@ -538,14 +538,33 @@ async fn test_slow_task_outlives_lease(pool: PgPool) -> sqlx::Result<()> {
538538
// Wait for real time to pass the lease timeout
539539
tokio::time::sleep(claim_timeout + Duration::from_secs(2)).await;
540540

541-
// Verify a new run was created (reclaim happened)
542-
let run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
541+
// Second worker polls to reclaim the expired lease.
542+
let worker2 = client
543+
.start_worker(WorkerOptions {
544+
poll_interval: Duration::from_millis(50),
545+
claim_timeout: Duration::from_secs(10),
546+
..Default::default()
547+
})
548+
.await
549+
.unwrap();
550+
551+
// Verify a new run was created (reclaim happened), with bounded polling.
552+
let deadline = Instant::now() + Duration::from_secs(5);
553+
let mut run_count = 0;
554+
while Instant::now() < deadline {
555+
run_count = count_runs_for_task(&pool, "crash_slow", spawn_result.task_id).await?;
556+
if run_count >= 2 {
557+
break;
558+
}
559+
tokio::time::sleep(Duration::from_millis(50)).await;
560+
}
543561
assert!(
544562
run_count >= 2,
545563
"Should have at least 2 runs after lease expiration, got {}",
546564
run_count
547565
);
548566

567+
worker2.shutdown().await;
549568
worker.shutdown().await;
550569

551570
Ok(())

0 commit comments

Comments
 (0)