Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions sqlx-postgres/src/migrate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::str::FromStr;
use std::time::Duration;
use std::time::Instant;

Expand All @@ -8,6 +7,7 @@ pub(crate) use sqlx_core::migrate::MigrateError;
pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration};
pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase};
use sqlx_core::sql_str::AssertSqlSafe;
use sqlx_core::Url;

use crate::connection::{ConnectOptions, Connection};
use crate::error::Error;
Expand All @@ -18,7 +18,29 @@ use crate::query_scalar::query_scalar;
use crate::{PgConnectOptions, PgConnection, Postgres};

fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error> {
let mut options = PgConnectOptions::from_str(url)?;
let mut url: Url = url.parse().map_err(Error::config)?;

// check for user provided `?maintenance_database=example`
let mut maintenance_database = None;
if url.query_pairs().any(|(k, _)| k == "maintenance_database") {
let remaining: Vec<(String, String)> = url
.query_pairs()
.into_owned()
.filter_map(|(k, v)| {
if k == "maintenance_database" {
if maintenance_database.is_none() {
maintenance_database = Some(v);
}
None
} else {
Some((k, v))
}
})
.collect();
url.query_pairs_mut().clear().extend_pairs(remaining);
}

let mut options = PgConnectOptions::parse_from_url(&url)?;

// pull out the name of the database to create
let database = options
Expand All @@ -28,13 +50,17 @@ fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error>
.to_owned();

// switch us to the maintenance database
// use `postgres` _unless_ the database is postgres, in which case, use `template1`
// this matches the behavior of the `createdb` util
options.database = if database == "postgres" {
Some("template1".into())
if maintenance_database.is_some() {
options.database = maintenance_database;
} else {
Some("postgres".into())
};
// use `postgres` _unless_ the database is postgres, in which case, use `template1`
// this matches the behavior of the `createdb` util
options.database = if database == "postgres" {
Some("template1".into())
} else {
Some("postgres".into())
};
}

Ok((options, database))
}
Expand Down Expand Up @@ -350,3 +376,32 @@ fn generate_lock_id(database_name: &str) -> i64 {
// 0x3d32ad9e chosen by fair dice roll
0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64)
}

#[cfg(test)]
mod tests {
use super::parse_for_maintenance;

#[test]
fn test_parse_for_maintenance() {
let (opts, db) = parse_for_maintenance("postgres://user:pass@host/mydb").unwrap();
assert_eq!(opts.database.as_deref(), Some("postgres"));
assert_eq!(db, "mydb");

let (opts, db) = parse_for_maintenance("postgres://user:pass@host/postgres").unwrap();
assert_eq!(opts.database.as_deref(), Some("template1"));
assert_eq!(db, "postgres");

let (opts, db) =
parse_for_maintenance("postgres://user:pass@host/mydb?maintenance_database=defaultdb")
.unwrap();
assert_eq!(opts.database.as_deref(), Some("defaultdb"));
assert_eq!(db, "mydb");

let (opts, db) = parse_for_maintenance(
"postgres://user:pass@host/mydb?sslmode=require&maintenance_database=defaultdb",
)
.unwrap();
assert_eq!(opts.database.as_deref(), Some("defaultdb"));
assert_eq!(db, "mydb");
}
}