diff --git a/rust/impls/Cargo.toml b/rust/impls/Cargo.toml
index 026bbda..37e3c47 100644
--- a/rust/impls/Cargo.toml
+++ b/rust/impls/Cargo.toml
@@ -10,6 +10,7 @@ chrono = "0.4.38"
 tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
 bb8-postgres = "0.7"
 bytes = "1.4.0"
+tokio = { version = "1.38.0", default-features = false, features = ["rt"] }
 
 [dev-dependencies]
-tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] }
+tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros", "sync"] }
diff --git a/rust/impls/src/lib.rs b/rust/impls/src/lib.rs
index 3d4f64f..27844d0 100644
--- a/rust/impls/src/lib.rs
+++ b/rust/impls/src/lib.rs
@@ -11,6 +11,7 @@
 #![deny(rustdoc::private_intra_doc_links)]
 #![deny(missing_docs)]
 
+mod migrations;
 /// Contains [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
 pub mod postgres_store;
 
diff --git a/rust/impls/src/migrations.rs b/rust/impls/src/migrations.rs
new file mode 100644
index 0000000..bab951b
--- /dev/null
+++ b/rust/impls/src/migrations.rs
@@ -0,0 +1,42 @@
+pub(crate) const DB_VERSION_COLUMN: &str = "db_version";
+#[cfg(test)]
+pub(crate) const MIGRATION_LOG_COLUMN: &str = "upgrade_from";
+
+pub(crate) const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
+pub(crate) const INIT_DB_CMD: &str = "CREATE DATABASE";
+#[cfg(test)]
+pub(crate) const DROP_DB_CMD: &str = "DROP DATABASE";
+pub(crate) const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
+pub(crate) const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
+pub(crate) const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";
+#[cfg(test)]
+pub(crate) const GET_MIGRATION_LOG_STMT: &str = "SELECT upgrade_from FROM vss_db_upgrades;";
+
+// APPEND-ONLY list of migration statements.
+//
+// Each statement MUST be applied in-order, and only once per database.
+//
+// We make an exception for the vss_db table creation statement, as users of VSS could have
+// initialized the table themselves.
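+//
+// The version stored in vss_db_version is the index of the next migration to apply, so a fully
+// migrated database stores MIGRATIONS.len().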
+pub(crate) const MIGRATIONS: &[&str] = &[
+	"CREATE TABLE vss_db_version (db_version INTEGER);",
+	"INSERT INTO vss_db_version VALUES(1);",
+	// A write-only log of all the migrations performed on this database, useful for debugging and testing
+	"CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);",
+	// We do not complain if the table already exists, as users of VSS could have already created this table
+	"CREATE TABLE IF NOT EXISTS vss_db (
+		user_token character varying(120) NOT NULL CHECK (user_token <> ''),
+		store_id character varying(120) NOT NULL CHECK (store_id <> ''),
+		key character varying(600) NOT NULL,
+		value bytea NULL,
+		version bigint NOT NULL,
+		created_at TIMESTAMP WITH TIME ZONE,
+		last_updated_at TIMESTAMP WITH TIME ZONE,
+		PRIMARY KEY (user_token, store_id, key)
+	);",
+];
+#[cfg(test)]
+pub(crate) const DUMMY_MIGRATION: &str = "SELECT 1 WHERE FALSE;";
diff --git a/rust/impls/src/postgres_store.rs b/rust/impls/src/postgres_store.rs
index 96f9618..f3b39c3 100644
--- a/rust/impls/src/postgres_store.rs
+++ b/rust/impls/src/postgres_store.rs
@@ -1,3 +1,5 @@
+use crate::migrations::*;
+
 use api::error::VssError;
 use api::kv_store::{KvStore, GLOBAL_VERSION_KEY, INITIAL_RECORD_VERSION};
 use api::types::{
@@ -12,7 +14,7 @@ use chrono::Utc;
 use std::cmp::min;
 use std::io;
 use std::io::{Error, ErrorKind};
-use tokio_postgres::{NoTls, Transaction};
+use tokio_postgres::{error, NoTls, Transaction};
 
 pub(crate) struct VssDbRecord {
 	pub(crate) user_token: String,
@@ -46,17 +48,192 @@ pub struct PostgresBackendImpl {
 	pool: Pool<PostgresConnectionManager<NoTls>>,
 }
 
+async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
+	let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
+	let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
+		.await
+		.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
+	// Connection must be driven on a separate task, and will resolve when the client is dropped
+	tokio::spawn(async move {
+		if let Err(e) = connection.await {
+			eprintln!("Connection error: {}", e);
+		}
+	});
+
+	let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
+		Error::new(
+			ErrorKind::Other,
+			format!("Failed to check presence of database {}: {}", db_name, e),
+		)
+	})?;
+
+	if num_rows == 0 {
+		let stmt = format!("{} {};", INIT_DB_CMD, db_name);
+		client.execute(&stmt, &[]).await.map_err(|e| {
+			Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
+		})?;
+		println!("Created database {}", db_name);
+	}
+
+	Ok(())
+}
+
+#[cfg(test)]
+async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
+	let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
+	let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
+		.await
+		.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
+	// Connection must be driven on a separate task, and will resolve when the client is dropped
+	tokio::spawn(async move {
+		if let Err(e) = connection.await {
+			eprintln!("Connection error: {}", e);
+		}
+	});
+
+	let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
+	let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
+		Error::new(
+			ErrorKind::Other,
+			format!("Failed to drop database {}: {}", db_name, e),
+		)
+	})?;
+	assert_eq!(num_rows, 0);
+
+	Ok(())
+}
+
 impl PostgresBackendImpl {
-	/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
+	/// Constructs a [`PostgresBackendImpl`] using `postgres_endpoint` and `db_name` for PostgreSQL connection information.
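+	///
+	/// If the database `db_name` does not exist on the server yet, it is created first.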
-	pub async fn new(dsn: &str) -> Result<PostgresBackendImpl, Error> {
-		let manager = PostgresConnectionManager::new_from_stringlike(dsn, NoTls).map_err(|e| {
-			Error::new(ErrorKind::Other, format!("Connection manager error: {}", e))
-		})?;
+	pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<PostgresBackendImpl, Error> {
+		initialize_vss_database(postgres_endpoint, db_name).await?;
+
+		let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
+		let manager =
+			PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!("Failed to create PostgresConnectionManager: {}", e),
+				)
+			})?;
+		// By default, Pool maintains 0 long-running connections, so returning a pool
+		// here is no guarantee that Pool established a connection to the database.
+		//
+		// See Builder::min_idle to increase the long-running connection count.
 		let pool = Pool::builder()
 			.build(manager)
 			.await
-			.map_err(|e| Error::new(ErrorKind::Other, format!("Pool build error: {}", e)))?;
-		Ok(PostgresBackendImpl { pool })
+			.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
+		let postgres_backend = PostgresBackendImpl { pool };
+
+		#[cfg(not(test))]
+		postgres_backend.migrate_vss_database(MIGRATIONS).await?;
+
+		Ok(postgres_backend)
 	}
+
+	async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
+		let mut conn = self.pool.get().await.map_err(|e| {
+			Error::new(
+				ErrorKind::Other,
+				format!("Failed to fetch a connection from Pool: {}", e),
+			)
+		})?;
+
+		// Get the next migration to be applied.
+		let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
+			Ok(row) => {
+				let i: i32 = row.get(DB_VERSION_COLUMN);
+				usize::try_from(i).expect("The column should always contain unsigned integers")
+			},
+			Err(e) => {
+				// If the table is not defined, start at migration 0
+				if e.code() == Some(&error::SqlState::UNDEFINED_TABLE) {
+					0
+				} else {
+					return Err(Error::new(
+						ErrorKind::Other,
+						format!("Failed to query the version of the database schema: {}", e),
+					));
+				}
+			},
+		};
+
+		if migration_start == migrations.len() {
+			// No migrations needed, we are done
+			return Ok((migration_start, migrations.len()));
+		} else if migration_start > migrations.len() {
+			panic!("We do not allow downgrades");
+		}
+
+		let tx = conn
+			.transaction()
+			.await
+			.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;
+
+		println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);
+
+		for (idx, &stmt) in migrations[migration_start..].iter().enumerate() {
+			let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!(
+						"Database migration no. {} with stmt {} failed: {}",
+						migration_start + idx,
+						stmt,
+						e
+					),
+				)
+			})?;
+		}
+
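+		// Log the migration index this upgrade started from, for debugging and testing.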
+		let num_rows = tx
+			.execute(
+				LOG_MIGRATION_STMT,
+				&[&i32::try_from(migration_start).expect("Read from an i32 further above")],
+			)
+			.await
+			.map_err(|e| {
+				Error::new(ErrorKind::Other, format!("Failed to log database migration: {}", e))
+			})?;
+		assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");
+
+		let next_migration_start =
+			i32::try_from(migrations.len()).expect("Length is definitely smaller than i32::MAX");
+		let num_rows =
+			tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!("Failed to update the version of the schema: {}", e),
+				)
+			})?;
+		assert_eq!(
+			num_rows, 1,
+			"UPDATE_VERSION_STMT should only update the unique row in the version table"
+		);
+
+		tx.commit().await.map_err(|e| {
+			Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
+		})?;
+
+		Ok((migration_start, migrations.len()))
+	}
+
+	#[cfg(test)]
+	async fn get_schema_version(&self) -> usize {
+		let conn = self.pool.get().await.unwrap();
+		let row = conn.query_one(GET_VERSION_STMT, &[]).await.unwrap();
+		usize::try_from(row.get::<&str, i32>(DB_VERSION_COLUMN)).unwrap()
+	}
+
+	#[cfg(test)]
+	async fn get_upgrades_list(&self) -> Vec<usize> {
+		let conn = self.pool.get().await.unwrap();
+		let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap();
+		rows.iter().map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()).collect()
+	}
 
 	fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -409,12 +586,107 @@ impl KvStore for PostgresBackendImpl {
 mod tests {
 	use crate::postgres_store::PostgresBackendImpl;
 	use api::define_kv_store_tests;
+	use tokio::sync::OnceCell;
+	use super::{MIGRATIONS, DUMMY_MIGRATION, drop_database};
+
+	const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
+	const MIGRATIONS_START: usize = 0;
+	const MIGRATIONS_END: usize = MIGRATIONS.len();
+
+	static START: OnceCell<()> = OnceCell::const_new();
+
+	define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
+		let db_name = "postgres_kv_store_tests";
+		START
+			.get_or_init(|| async {
+				let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
+				let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+				let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+				assert_eq!(start, MIGRATIONS_START);
+				assert_eq!(end, MIGRATIONS_END);
+			})
+			.await;
+		let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+		let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+		assert_eq!(start, MIGRATIONS_END);
+		assert_eq!(end, MIGRATIONS_END);
+		assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
+		assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
+		store
+	});
+
+	#[tokio::test]
+	#[should_panic(expected = "We do not allow downgrades")]
+	async fn panic_on_downgrade() {
+		let db_name = "panic_on_downgrade_test";
+		let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
+		{
+			let mut migrations = MIGRATIONS.to_vec();
+			migrations.push(DUMMY_MIGRATION);
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
+			assert_eq!(start, MIGRATIONS_START);
+			assert_eq!(end, MIGRATIONS_END + 1);
+		};
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+		};
+	}
 
-	define_kv_store_tests!(
-		PostgresKvStoreTest,
-		PostgresBackendImpl,
-		PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432/postgres")
-			.await
-			.unwrap()
-	);
+	#[tokio::test]
+	async fn new_migrations_increments_upgrades() {
+		let db_name = "new_migrations_increments_upgrades_test";
+		let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+			assert_eq!(start, MIGRATIONS_START);
+			assert_eq!(end, MIGRATIONS_END);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
+		};
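+		// Re-running the same migration list must be a no-op and leave the upgrade log unchanged.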
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+			assert_eq!(start, MIGRATIONS_END);
+			assert_eq!(end, MIGRATIONS_END);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
+		};
+
+		let mut migrations = MIGRATIONS.to_vec();
+		migrations.push(DUMMY_MIGRATION);
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
+			assert_eq!(start, MIGRATIONS_END);
+			assert_eq!(end, MIGRATIONS_END + 1);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 1);
+		};
+
+		migrations.push(DUMMY_MIGRATION);
+		migrations.push(DUMMY_MIGRATION);
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
+			assert_eq!(start, MIGRATIONS_END + 1);
+			assert_eq!(end, MIGRATIONS_END + 3);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3);
+		};
+
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let list = store.get_upgrades_list().await;
+			assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
+			let version = store.get_schema_version().await;
+			assert_eq!(version, MIGRATIONS_END + 3);
+		}
+
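+		// Drop the test database so the next run starts from scratch.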
+		drop_database(POSTGRES_ENDPOINT, db_name).await.unwrap();
+	}
 }
diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs
index 849dcee..5a78be6 100644
--- a/rust/server/src/main.rs
+++ b/rust/server/src/main.rs
@@ -67,13 +67,18 @@ fn main() {
 		},
 	};
 	let authorizer = Arc::new(NoopAuthorizer {});
+	let postgresql_config = config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.");
+	let endpoint = postgresql_config.to_postgresql_endpoint();
+	let db_name = postgresql_config.database;
 	let store = Arc::new(
-		PostgresBackendImpl::new(&config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.").to_connection_string())
+		PostgresBackendImpl::new(&endpoint, &db_name)
 			.await
 			.unwrap(),
 	);
+	println!("Connected to PostgreSQL backend, database: {}", db_name);
 	let rest_svc_listener =
 		TcpListener::bind(&addr).await.expect("Failed to bind listening port");
+	println!("Listening for incoming connections on {}", addr);
 	loop {
 		tokio::select! {
 			res = rest_svc_listener.accept() => {
diff --git a/rust/server/src/util/config.rs b/rust/server/src/util/config.rs
index e75f972..cf70daf 100644
--- a/rust/server/src/util/config.rs
+++ b/rust/server/src/util/config.rs
@@ -22,7 +22,7 @@ pub(crate) struct PostgreSQLConfig {
 }
 
 impl PostgreSQLConfig {
-	pub(crate) fn to_connection_string(&self) -> String {
+	pub(crate) fn to_postgresql_endpoint(&self) -> String {
 		let username_env = std::env::var("VSS_POSTGRESQL_USERNAME");
 		let username = username_env.as_ref()
 			.ok()
@@ -34,10 +34,7 @@
 			.or_else(|| self.password.as_ref())
 			.expect("PostgreSQL database password must be provided in config or env var VSS_POSTGRESQL_PASSWORD must be set.");
 
-		format!(
-			"postgresql://{}:{}@{}:{}/{}",
-			username, password, self.host, self.port, self.database
-		)
+		format!("postgresql://{}:{}@{}:{}", username, password, self.host, self.port)
 	}
 }
 