diff --git a/Cargo.toml b/Cargo.toml index 2813b3b..c51132c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ thiserror = "1.0" tokio = { version = "1.8", features = ["parking_lot", "rt", "sync", "io-util", "process", "macros", "fs"], default-features = false, optional = true } tracing = "0.1" which = "4.0" +futures-util = { version = "0.3", optional = true } [dev-dependencies] test-log = { version = "0.2", default-features = false, features = ["trace"] } @@ -31,4 +32,4 @@ tracing-subscriber = { version = "0.3", default-features = false, features = ["e [features] default = [] -tokio-process = ["tokio"] +tokio-process = ["tokio", "futures-util"] diff --git a/src/lib.rs b/src/lib.rs index b4bb117..595c69b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,17 +164,19 @@ pub async fn new_default_process_async() -> TmpPostgrustResult Result<(), Box>, -) -> TmpPostgrustResult { +pub async fn new_default_process_async_with_migrations( + migrate: F, +) -> TmpPostgrustResult +where + F: for<'r> Fn(&'r str) -> futures_util::future::BoxFuture<'r, Result<(), Box>>, +{ let factory_mutex = TOKIO_POSTGRES_FACTORY .get_or_try_init(|| async { - TmpPostgrustFactory::try_new_async().await.map(|factory| { - factory - .run_migrations(migrate) - .expect("Failed to run migrations."); - tokio::sync::Mutex::new(Some(factory)) - }) + let factory = TmpPostgrustFactory::try_new_async().await?; + factory + .run_migrations_async(migrate) + .await?; + Ok(tokio::sync::Mutex::new(Some(factory))) }) .await?; let guard = factory_mutex.lock().await; @@ -294,7 +296,7 @@ impl TmpPostgrustFactory { /// an error. pub fn run_migrations( &self, - migrate: impl FnOnce(&str) -> Result<(), Box>, + migrate: impl Fn(&str) -> Result<(), Box>, ) -> TmpPostgrustResult<()> { let process = self.start_postgresql(&self.cache_dir)?; @@ -303,6 +305,28 @@ impl TmpPostgrustFactory { Ok(()) } + /// Run migrations against the cache directory, will cause all subsequent instances + /// to be run against a version of the database where the migrations have been applied. + /// + /// # Errors + /// + /// Will error if Postgresql is unable to start or if the migrate function returns + /// an error. + #[cfg(feature = "tokio-process")] + pub async fn run_migrations_async( + &self, + migrate: F, + ) -> TmpPostgrustResult<()> + where + F: for<'r> Fn(&'r str) -> futures_util::future::BoxFuture<'r, Result<(), Box>>, + { + let process = self.start_postgresql(&self.cache_dir)?; + + migrate(&process.connection_string()).await.map_err(TmpPostgrustError::MigrationsFailed)?; + + Ok(()) + } + /// Start a new postgresql instance and return a process guard that will ensure it is cleaned /// up when dropped. #[instrument(skip(self))]