diff --git a/README.md b/README.md index 5582fee..461866b 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,9 @@ When using JDBC, add the following to your `build.gradle.kts` file: ```kotlin plugins { - kotlin("jvm") version "2.1.0" // Currently the plugin is only available for Kotlin-JVM - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + kotlin("jvm") version "2.2.0" // Currently the plugin is only available for Kotlin-JVM + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" } dependencies { @@ -116,9 +116,9 @@ For Android development, add the following to your `build.gradle.kts` file: ```kotlin plugins { - kotlin("android") version "2.1.0" - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + kotlin("android") version "2.2.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" } dependencies { @@ -192,9 +192,9 @@ val person: List = Sql("SELECT * FROM Person").queryOf().runOn(c For iOS, OSX, Linux and Windows development, with Kotlin Multiplatform, add the following to your `build.gradle.kts` file: ```kotlin plugins { - kotlin("multiplatform") version "2.1.0" - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + kotlin("multiplatform") version "2.2.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" } kotlin { diff --git a/build.gradle.kts b/build.gradle.kts index 457003b..37a0ee0 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,9 +1,9 @@ plugins { `maven-publish` signing - kotlin("jvm") version "2.1.0" apply false + kotlin("jvm") version "2.2.0" apply false id("io.github.gradle-nexus.publish-plugin") version "1.1.0" apply false - kotlin("multiplatform") version "2.1.0" apply false + kotlin("multiplatform") version "2.2.0" apply false id("com.android.library") version "8.2.0" apply false id("org.jetbrains.dokka") version "1.9.10" apply false } @@ -85,9 +85,9 @@ subprojects { artifact(javadocJar) pom { - name.set("decomat") - description.set("DecoMat - Deconstructive Pattern Matching for Kotlin") - url.set("https://github.com/exoquery/decomat") + name.set("terpal-sql") + description.set("Safe and fun SQL buliding with interpolated strings in Kotlin") + url.set("https://github.com/exoquery/terpal-sql") licenses { license { diff --git a/controller-android/build.gradle.kts b/controller-android/build.gradle.kts index 94a72ad..9a7e177 100644 --- a/controller-android/build.gradle.kts +++ b/controller-android/build.gradle.kts @@ -8,7 +8,7 @@ plugins { id("conventions") kotlin("multiplatform") id("com.android.library") - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.2.0" // Already on the classpath //id("org.jetbrains.kotlin.android") version "1.9.23" } @@ -39,8 +39,10 @@ kotlin { androidTarget { compilations.all { - kotlinOptions { - jvmTarget = "17" + compileTaskProvider { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_17) + } } } publishLibraryVariants("release", "debug") diff --git a/controller-android/src/androidMain/kotlin/io/exoquery/controller/android/AndroidDatabaseController.kt b/controller-android/src/androidMain/kotlin/io/exoquery/controller/android/AndroidDatabaseController.kt index 158aa60..2e00283 100644 --- a/controller-android/src/androidMain/kotlin/io/exoquery/controller/android/AndroidDatabaseController.kt +++ b/controller-android/src/androidMain/kotlin/io/exoquery/controller/android/AndroidDatabaseController.kt @@ -151,7 +151,7 @@ class AndroidDatabaseController internal constructor( } // Is there an open writer? - override fun CoroutineContext.hasOpenConnection(): Boolean { + override suspend fun CoroutineContext.hasOpenConnection(): Boolean { val session = get(sessionKey)?.session return session != null && session.isWriter && !isClosedSession(session) } @@ -401,7 +401,7 @@ interface WithReadOnlyVerbs: RequiresSession accessStmtReturning( sql: String, @@ -64,7 +64,7 @@ interface HasSessionAndroid: RequiresSession { // Methods that implementors need to provide val sessionKey: CoroutineContext.Key> abstract suspend fun newSession(executionOptions: ExecutionOpts): Session - abstract fun closeSession(session: Session): Unit - abstract fun isClosedSession(session: Session): Boolean + abstract suspend fun closeSession(session: Session): Unit + abstract suspend fun isClosedSession(session: Session): Boolean suspend fun accessStmt(sql: String, conn: Session, block: suspend (Stmt) -> R): R suspend fun accessStmtReturning(sql: String, conn: Session, options: ExecutionOpts, returningColumns: List, block: suspend (Stmt) -> R): R - fun CoroutineContext.hasOpenConnection(): Boolean { + suspend fun CoroutineContext.hasOpenConnection(): Boolean { val session = get(sessionKey)?.session return session != null && !isClosedSession(session) } @@ -102,7 +102,10 @@ interface RequiresSession { } else { val session = newSession(executionOptions) try { - withContext(CoroutineSession(session, sessionKey) + Dispatchers.IO) { block() } + withContext(CoroutineSession(session, sessionKey) + Dispatchers.IO) { + val output = block() + output + } } finally { closeSession(session) } } } diff --git a/controller-core/src/commonMain/kotlin/io/exoquery/controller/ControllerError.kt b/controller-core/src/commonMain/kotlin/io/exoquery/controller/ControllerError.kt new file mode 100644 index 0000000..dddfe70 --- /dev/null +++ b/controller-core/src/commonMain/kotlin/io/exoquery/controller/ControllerError.kt @@ -0,0 +1,3 @@ +package io.exoquery.controller + +class ControllerError(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/controller-jdbc/src/main/kotlin/io/exoquery/controller/JavaEncoding.kt b/controller-core/src/jvmMain/kotlin/JavaSqlEncoding.kt similarity index 94% rename from controller-jdbc/src/main/kotlin/io/exoquery/controller/JavaEncoding.kt rename to controller-core/src/jvmMain/kotlin/JavaSqlEncoding.kt index 9bc2ac7..b3c0555 100644 --- a/controller-jdbc/src/main/kotlin/io/exoquery/controller/JavaEncoding.kt +++ b/controller-core/src/jvmMain/kotlin/JavaSqlEncoding.kt @@ -1,9 +1,5 @@ package io.exoquery.controller -import io.exoquery.controller.SqlDecoder -import io.exoquery.controller.SqlEncoder -import io.exoquery.controller.SqlEncoding -import io.exoquery.controller.TimeEncoding import java.time.* import java.math.BigDecimal import java.util.Date diff --git a/controller-jdbc/build.gradle.kts b/controller-jdbc/build.gradle.kts index 227d1e1..2b5e62b 100644 --- a/controller-jdbc/build.gradle.kts +++ b/controller-jdbc/build.gradle.kts @@ -1,10 +1,10 @@ import org.gradle.api.tasks.testing.logging.TestExceptionFormat -import org.jetbrains.kotlin.gradle.dsl.KotlinCompile +import org.jetbrains.kotlin.gradle.tasks.KotlinCompilationTask plugins { id("conventions") kotlin("multiplatform") - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.2.0" } version = extra["controllerVersion"].toString() diff --git a/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcContextMixins.kt b/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcContextMixins.kt index 90d70cd..c0367c2 100644 --- a/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcContextMixins.kt +++ b/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcContextMixins.kt @@ -63,8 +63,8 @@ interface HasSessionJdbc: RequiresSession accessStmtReturning(sql: String, conn: Connection, options: JdbcExecutionOptions, returningColumns: List, block: suspend (PreparedStatement) -> R): R { val stmt = diff --git a/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcController.kt b/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcController.kt index f5b9a77..6cc351e 100644 --- a/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcController.kt +++ b/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcController.kt @@ -2,68 +2,10 @@ package io.exoquery.controller.jdbc import io.exoquery.controller.* import kotlinx.coroutines.flow.* -import kotlinx.serialization.json.Json -import kotlinx.serialization.modules.EmptySerializersModule -import kotlinx.serialization.modules.SerializersModule import javax.sql.DataSource -import kotlinx.datetime.TimeZone import java.sql.* -/** - * Most constructions will want to specify default values from AdditionalJdbcEncoding for additionalEncoders/decoders, - * and they should have a simple construction JdbcEncodingConfig(...). Use `Empty` to make a config that does not - * include these defaults. For this reason the real constructor is private. - */ -data class JdbcEncodingConfig private constructor( - override val additionalEncoders: Set>, - override val additionalDecoders: Set>, - override val json: Json, - // If you want to use any primitive-wrapped contextual encoders you need to add them here - override val module: SerializersModule, - override val timezone: TimeZone, override val debugMode: Boolean -): EncodingConfig { - companion object { - val Default get() = - Default( - AdditionalJdbcEncoding.encoders, - AdditionalJdbcEncoding.decoders - ) - - fun Default( - additionalEncoders: Set> = setOf(), - additionalDecoders: Set> = setOf(), - json: Json = Json, - module: SerializersModule = EmptySerializersModule(), - timezone: TimeZone = TimeZone.currentSystemDefault(), - debugMode: Boolean = false - ) = JdbcEncodingConfig( - additionalEncoders + AdditionalJdbcEncoding.encoders, - additionalDecoders + AdditionalJdbcEncoding.decoders, - json, - module, - timezone, - debugMode - ) - - operator fun invoke( - additionalEncoders: Set> = setOf(), - additionalDecoders: Set> = setOf(), - json: Json = Json, - module: SerializersModule = EmptySerializersModule(), - timezone: TimeZone = TimeZone.currentSystemDefault() - ) = Default(additionalEncoders, additionalDecoders, json, module, timezone) - - fun Empty( - additionalEncoders: Set> = setOf(), - additionalDecoders: Set> = setOf(), - json: Json = Json, - module: SerializersModule = EmptySerializersModule(), - timezone: TimeZone = TimeZone.currentSystemDefault() - ) = JdbcEncodingConfig(additionalEncoders, additionalDecoders, json, module, timezone) - } -} - /** * This is a Terpal Driver, NOT a JDBC driver! It is the base class for all JDBC-based implementations of the * Terpal Driver base class `io.exoquery.sql.Driver`. This naming follows the conventions of SQL Delight diff --git a/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcEncodingConfig.kt b/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcEncodingConfig.kt new file mode 100644 index 0000000..13c57f1 --- /dev/null +++ b/controller-jdbc/src/main/kotlin/io/exoquery/controller/jdbc/JdbcEncodingConfig.kt @@ -0,0 +1,66 @@ +package io.exoquery.controller.jdbc + +import io.exoquery.controller.EncodingConfig +import io.exoquery.controller.SqlDecoder +import io.exoquery.controller.SqlEncoder +import kotlinx.datetime.TimeZone +import kotlinx.serialization.json.Json +import kotlinx.serialization.modules.EmptySerializersModule +import kotlinx.serialization.modules.SerializersModule +import java.sql.Connection +import java.sql.PreparedStatement +import java.sql.ResultSet + +/** + * Most constructions will want to specify default values from AdditionalJdbcEncoding for additionalEncoders/decoders, + * and they should have a simple construction JdbcEncodingConfig(...). Use `Empty` to make a config that does not + * include these defaults. For this reason the real constructor is private. + */ +data class JdbcEncodingConfig private constructor( + override val additionalEncoders: Set>, + override val additionalDecoders: Set>, + override val json: Json, + // If you want to use any primitive-wrapped contextual encoders you need to add them here + override val module: SerializersModule, + override val timezone: TimeZone, override val debugMode: Boolean +): EncodingConfig { + companion object { + val Default get() = + Default( + AdditionalJdbcEncoding.encoders, + AdditionalJdbcEncoding.decoders + ) + + fun Default( + additionalEncoders: Set> = setOf(), + additionalDecoders: Set> = setOf(), + json: Json = Json.Default, + module: SerializersModule = EmptySerializersModule(), + timezone: TimeZone = TimeZone.Companion.currentSystemDefault(), + debugMode: Boolean = false + ) = JdbcEncodingConfig( + additionalEncoders + AdditionalJdbcEncoding.encoders, + additionalDecoders + AdditionalJdbcEncoding.decoders, + json, + module, + timezone, + debugMode + ) + + operator fun invoke( + additionalEncoders: Set> = setOf(), + additionalDecoders: Set> = setOf(), + json: Json = Json.Default, + module: SerializersModule = EmptySerializersModule(), + timezone: TimeZone = TimeZone.Companion.currentSystemDefault() + ) = Default(additionalEncoders, additionalDecoders, json, module, timezone) + + fun Empty( + additionalEncoders: Set> = setOf(), + additionalDecoders: Set> = setOf(), + json: Json = Json.Default, + module: SerializersModule = EmptySerializersModule(), + timezone: TimeZone = TimeZone.Companion.currentSystemDefault() + ) = JdbcEncodingConfig(additionalEncoders, additionalDecoders, json, module, timezone) + } +} diff --git a/controller-native/build.gradle.kts b/controller-native/build.gradle.kts index fd7f018..a365f3c 100644 --- a/controller-native/build.gradle.kts +++ b/controller-native/build.gradle.kts @@ -8,7 +8,7 @@ import org.jetbrains.kotlin.konan.target.HostManager plugins { id("conventions") kotlin("multiplatform") - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.2.0" id("nativebuild") } diff --git a/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeContextMixins.kt b/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeContextMixins.kt index f7f074f..9455e42 100644 --- a/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeContextMixins.kt +++ b/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeContextMixins.kt @@ -26,8 +26,8 @@ interface HasSessionNative: RequiresSession { // Use this for the transactor pool (that's what the RequiresTransactionality interface is for) // for reader connections we borrow readers override suspend fun newSession(options: UnusedOpts): Connection = pool.borrowWriter() - override fun closeSession(session: Connection): Unit = session.close() - override fun isClosedSession(session: Connection): Boolean = !session.isOpen() + override suspend fun closeSession(session: Connection): Unit = session.close() + override suspend fun isClosedSession(session: Connection): Boolean = !session.isOpen() override suspend fun accessStmtReturning(sql: String, conn: Connection, options: UnusedOpts, returningColumns: List, block: suspend (Statement) -> R): R { val stmt = conn.value.createStatement(sql) @@ -60,7 +60,7 @@ interface HasSessionNative: RequiresSession { // reader-needs-writer,writer-needs-reader scenario since the the coroutine that has // the writer session will use it as the reader (see hasOpenReadOnlyConnection which // doesn't care where the thing it has is a reader or writer). - override fun CoroutineContext.hasOpenConnection(): Boolean { + override suspend fun CoroutineContext.hasOpenConnection(): Boolean { val session = get(sessionKey)?.session return session != null && session.isWriter && !isClosedSession(session) } diff --git a/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeDatabaseController.kt b/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeDatabaseController.kt index 5640647..41bf96d 100644 --- a/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeDatabaseController.kt +++ b/controller-native/src/commonMain/kotlin/io/exoquery/controller/native/NativeDatabaseController.kt @@ -152,7 +152,7 @@ class NativeDatabaseController internal constructor( } // Is there an open writer? - override fun CoroutineContext.hasOpenConnection(): Boolean { + override suspend fun CoroutineContext.hasOpenConnection(): Boolean { val session = get(sessionKey)?.session //if (session != null) // println("--------- (${currentThreadId()}) Found session: ${if (session.isWriter) "WRITER" else "JUST READER, needs promotion" } - isClosed: ${isClosedSession(session)}") @@ -326,7 +326,7 @@ interface WithReadOnlyVerbs: RequiresSession suspend fun newReadOnlySession(): Connection = pool.borrowReader() // Check if there is at least a reader on th context, if it has a writer that's fine too - fun CoroutineContext.hasOpenReadOrWriteConnection(): Boolean { + suspend fun CoroutineContext.hasOpenReadOrWriteConnection(): Boolean { val session = get(sessionKey)?.session return session != null && !isClosedSession(session) } diff --git a/controller-r2dbc/build.gradle.kts b/controller-r2dbc/build.gradle.kts new file mode 100644 index 0000000..eefd5a1 --- /dev/null +++ b/controller-r2dbc/build.gradle.kts @@ -0,0 +1,45 @@ +import org.gradle.api.tasks.testing.logging.TestExceptionFormat + +plugins { + id("conventions") + kotlin("multiplatform") +} + +version = extra["controllerVersion"].toString() + +repositories { + mavenCentral() + mavenLocal() +} + +kotlin { + jvmToolchain(17) + jvm { + } + java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + } + sourceSets { + val jvmMain by getting { + kotlin.srcDir("src/main/kotlin") + resources.srcDir("src/main/resources") + dependencies { + api(project(":controller-core")) + // Coroutines + reactive bridge + api("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.8.1") + // R2DBC SPI only (no specific driver) + api("io.r2dbc:r2dbc-spi:1.0.0.RELEASE") + compileOnly("org.postgresql:r2dbc-postgresql:1.0.5.RELEASE") + } + } + val jvmTest by getting { + kotlin.srcDir("src/test/kotlin") + resources.srcDir("src/test/resources") + dependencies { + implementation(kotlin("test")) + } + } + } +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcController.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcController.kt new file mode 100644 index 0000000..d6e0f0c --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcController.kt @@ -0,0 +1,204 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.* +import io.r2dbc.spi.Connection +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.emptyFlow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.reactive.asFlow +import kotlinx.coroutines.reactive.awaitFirstOrNull +import kotlinx.coroutines.reactive.collect + +open class R2dbcController( + override val encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(), + override val connectionFactory: ConnectionFactory +): + ControllerCanonical, + WithEncoding, + ControllerVerbs, + HasTransactionalityR2dbc +{ + override fun DefaultOpts(): R2dbcExecutionOptions = R2dbcExecutionOptions.Default() + + override val encodingApi: R2dbcSqlEncoding = + object: JavaSqlEncoding, + BasicEncoding by R2dbcBasicEncoding, + JavaTimeEncoding by R2dbcTimeEncoding, + JavaUuidEncoding by R2dbcUuidEncoding {} + + override val allEncoders: Set> by lazy { encodingApi.computeEncoders() + encodingConfig.additionalEncoders } + override val allDecoders: Set> by lazy { encodingApi.computeDecoders() + encodingConfig.additionalDecoders } + + private fun changePlaceholders(sql: String): String { + // R2DBC uses $1, $2... for placeholders + val sb = StringBuilder() + var paramIndex = 1 + var i = 0 + while (i < sql.length) { + val c = sql[i] + if (c == '?') { + sb.append('$').append(paramIndex) + paramIndex++ + i++ + } else { + sb.append(c) + i++ + } + } + return sb.toString() + } + + override fun extractColumnInfo(row: Row): List? { + val meta = row.metadata + val cols = meta.columnMetadatas + return cols.map { cmd -> + ColumnInfo(cmd.name, cmd.type.name) + } + } + + override suspend fun stream(act: ControllerQuery, options: R2dbcExecutionOptions): Flow = + flowWithConnection(options) { + val conn = localConnection() + val preparedSql = changePlaceholders(act.sql) + accessStmt(preparedSql, conn) { stmt -> + prepare(stmt, conn, act.params) + tryCatchQuery(preparedSql) { + val pub = stmt.execute() + val outputFlow = pub.awaitFirstOrNull()?.map { row, meta -> + val resultMaker = act.resultMaker.makeExtractor(QueryDebugInfo(act.sql)) + PubResult(resultMaker(conn, row)) + }?.asFlow()?.map { it.value } ?: emptyFlow() + emitAll(outputFlow) + } + } + } + + override suspend fun stream(act: ControllerBatchActionReturning, options: R2dbcExecutionOptions): Flow = + flowWithConnection(options) { + val conn = localConnection() + val preparedSql = changePlaceholders(act.sql) + // Create and execute a query for each param set and emit results from all queries into the flow + act.params.forEach { params -> + accessStmtReturning(preparedSql, conn, options, act.returningColumns) { stmt -> + prepare(stmt, conn, params) + tryCatchQuery(preparedSql) { + val pub = stmt.execute().awaitFirstOrNull() + val outputFlow = pub?.map { row, _ -> + val resultMaker = act.resultMaker.makeExtractor(QueryDebugInfo(act.sql)) + PubResult(resultMaker(conn, row)) + }?.asFlow()?.map { it.value } ?: emptyFlow() + // Need to actually emit the flow into the surrounding flow that holds the connection + emitAll(outputFlow) + } + } + } + } + + override suspend fun stream(act: ControllerActionReturning, options: R2dbcExecutionOptions): Flow = + flowWithConnection(options) { + val conn = localConnection() + val preparedSql = changePlaceholders(act.sql) + accessStmtReturning(preparedSql, conn, options, act.returningColumns) { stmt -> + prepare(stmt, conn, act.params) + tryCatchQuery(preparedSql) { + val pub = stmt.execute().awaitFirstOrNull() + val outputFlow = pub?.map { row, _ -> + val resultMaker = act.resultMaker.makeExtractor(QueryDebugInfo(act.sql)) + PubResult(resultMaker(conn, row)) + }?.asFlow()?.map { it.value } ?: emptyFlow() + // Need to actually emit the flow into the surrounding flow that holds the connection + emitAll(outputFlow) + } + } + } + + /** Need a temporary wrapper to work around limitation of pub-result being not-nullable */ + @JvmInline + private value class PubResult(val value: T) + + override suspend fun run(query: ControllerQuery, options: R2dbcExecutionOptions): List = + stream(query, options).toList() + + override suspend fun run(act: ControllerAction, options: R2dbcExecutionOptions): Long = + flowWithConnection(options) { + val conn = localConnection() + val preparedSql = changePlaceholders(act.sql) + accessStmt(preparedSql, conn) { stmt -> + prepare(stmt, conn, act.params) + // Execute and sum rowsUpdated across possibly multiple results + tryCatchQuery(preparedSql) { + val pub = stmt.execute() + val numRows = pub.awaitFirstOrNull()?.rowsUpdated?.awaitFirstOrNull() ?: 0L + emit(numRows) + } + } + }.first() + + override suspend fun run(query: ControllerBatchAction, options: R2dbcExecutionOptions): List = + flowWithConnection(options) { + val conn = localConnection() + // TODO this statement works very well with caching, should look into reusing statements across calls + val preparedSql = changePlaceholders(query.sql) + accessStmtReturning(preparedSql, conn, options, emptyList()) { stmt -> + tryCatchQuery(preparedSql) { + val iter = query.params.iterator() + while (iter.hasNext()) { + val batch = iter.next() + prepare(stmt, conn, batch) + // We need to put a `add` after every batch except for the last one + if (iter.hasNext()) { + stmt.add() + } + } + val pub = stmt.execute() + // Here using the asFlow and connect actually makes sense because multiple results are expected + pub.asFlow().collect { result -> + val updated = result.rowsUpdated.awaitFirstOrNull() ?: 0 + emit(updated) + } + } + } + }.toList() + + private inline fun tryCatchQuery(sql: String, op: () -> T): T = + try { + op() + } catch (e: Exception) { + if (e is ControllerError) throw e + else throw ControllerError("Error executing query: ${sql}", e) + } + + override suspend fun run(query: ControllerActionReturning, options: R2dbcExecutionOptions): T = + stream(query, options).toList().first() + + override suspend fun run(query: ControllerBatchActionReturning, options: R2dbcExecutionOptions): List = + stream(query, options).toList() + + override suspend fun runRaw(query: ControllerQuery, options: R2dbcExecutionOptions) = + flowWithConnection(options) { + val conn = localConnection() + val preparedSql = changePlaceholders(query.sql) + accessStmt(preparedSql, conn) { stmt -> + prepare(stmt, conn, query.params) + tryCatchQuery(preparedSql) { + val pub = stmt.execute() + val outputFlow = + pub.awaitFirstOrNull()?.map { row, meta -> + val cols = meta.columnMetadatas + cols.mapIndexed { i, md -> + val name = md.name + val value = row.get(i, Any::class.java) + name to value?.toString() + } + }?.asFlow() ?: emptyFlow>>() + emitAll(outputFlow) + } + } + }.toList() +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcControllerMixins.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcControllerMixins.kt new file mode 100644 index 0000000..64e07d0 --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcControllerMixins.kt @@ -0,0 +1,89 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.ControllerError +import io.exoquery.controller.CoroutineSession +import io.exoquery.controller.RequiresSession +import io.exoquery.controller.RequiresTransactionality +import io.exoquery.controller.jdbc.CoroutineTransaction +import io.r2dbc.spi.Connection +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement +import io.r2dbc.spi.ValidationDepth +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.reactive.awaitFirstOrNull +import kotlinx.coroutines.reactive.awaitSingle +import kotlinx.coroutines.withContext +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.coroutineContext + +object R2dbcCoroutineContext: CoroutineContext.Key> {} + +interface HasTransactionalityR2dbc: RequiresTransactionality, HasSessionR2dbc { + override val sessionKey: CoroutineContext.Key> get() = R2dbcCoroutineContext + + override suspend fun runTransactionally(block: suspend CoroutineScope.() -> T): T { + val session = coroutineContext[sessionKey]?.session ?: error("No connection found") + session.runWithManualCommit { + val transaction = CoroutineTransaction() + try { + val result = withContext(transaction) { block() } + commitTransaction() + return result + } catch (ex: Throwable) { + rollbackTransaction() + throw ex + } finally { + transaction.complete() + } + } + } +} + +internal inline fun Connection.runWithManualCommit(block: Connection.() -> T): T { + val before = this.isAutoCommit + + return try { + this.setAutoCommit(false) + this.run(block) + } finally { + this.setAutoCommit(before) + } +} + +interface HasSessionR2dbc: RequiresSession { + override val sessionKey: CoroutineContext.Key> get() = R2dbcCoroutineContext + val connectionFactory: io.r2dbc.spi.ConnectionFactory + + override suspend fun newSession(executionOptions: R2dbcExecutionOptions): Connection = + connectionFactory.create().awaitFirstOrNull() ?: error("Failed to create R2DBC connection") + + override suspend fun closeSession(session: Connection) = + session.close().awaitFirstOrNull().run { Unit } + + override suspend fun isClosedSession(session: Connection): Boolean = + session.validate(ValidationDepth.REMOTE).awaitFirstOrNull()?.let { it == false } ?: true // if null returned treat as closed + + override suspend fun accessStmt(sql: String, conn: Connection, block: suspend (Statement) -> R): R = + try { + block(conn.createStatement(sql)) + } catch (ex: Throwable) { + throw ControllerError("Error preparing statement: $sql", ex) + } + + override suspend fun accessStmtReturning(sql: String, conn: Connection, options: R2dbcExecutionOptions, returningColumns: List, block: suspend (Statement) -> R): R = + conn.createStatement(sql).let { + val preparedWithColumns = + if (returningColumns.isNotEmpty()) { + it.returnGeneratedValues(*returningColumns.toTypedArray()) + } else { + it + } + + val fetchSize = options.fetchSize + val preparedWithOptions = + (fetchSize?.let { preparedWithColumns.fetchSize(it) } ?: preparedWithColumns) + + block(preparedWithOptions) + } + +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcControllers.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcControllers.kt new file mode 100644 index 0000000..e1f4ed4 --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcControllers.kt @@ -0,0 +1,33 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.BasicEncoding +import io.exoquery.controller.JavaSqlEncoding +import io.exoquery.controller.JavaTimeEncoding +import io.exoquery.controller.JavaUuidEncoding +import io.exoquery.controller.SqlDecoder +import io.exoquery.controller.SqlEncoder +import io.r2dbc.spi.Connection +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement + +object R2dbcControllers { + class Postgres( + encodingConfig: R2dbcEncodingConfig = R2dbcEncodingConfig.Default(), + override val connectionFactory: ConnectionFactory + ): R2dbcController(encodingConfig,connectionFactory) { + + override val encodingConfig = + encodingConfig.copy( + additionalEncoders = encodingConfig.additionalEncoders + R2dbcPostgresAdditionalEncoding.encoders, + additionalDecoders = encodingConfig.additionalDecoders + R2dbcPostgresAdditionalEncoding.decoders + ) + + override val encodingApi: R2dbcSqlEncoding = + object: JavaSqlEncoding, + BasicEncoding by R2dbcBasicEncoding, + JavaTimeEncoding by R2dbcTimeEncoding, + JavaUuidEncoding by R2dbcUuidEncoding {} + } + +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcDecoders.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcDecoders.kt new file mode 100644 index 0000000..904acba --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcDecoders.kt @@ -0,0 +1,58 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.DecoderAny +import io.exoquery.controller.SqlDecoder +import io.r2dbc.spi.Connection +import io.r2dbc.spi.Row +import kotlinx.datetime.toKotlinInstant +import kotlinx.datetime.toKotlinLocalDate +import kotlinx.datetime.toKotlinLocalDateTime +import kotlinx.datetime.toKotlinLocalTime +import java.time.* +import java.util.* +import kotlin.reflect.KClass + +class R2dbcDecoderAny( + override val type: KClass, + override val f: (R2dbcDecodingContext, Int) -> T? +): DecoderAny( + type, + { index, row -> + row.get(index) == null + }, + f + ) { +} + + +object R2dbcDecoders { + @Suppress("UNCHECKED_CAST") + val decoders: Set> = setOf( + R2dbcBasicEncoding.BooleanDecoder, + R2dbcBasicEncoding.ByteDecoder, + R2dbcBasicEncoding.CharDecoder, + R2dbcBasicEncoding.DoubleDecoder, + R2dbcBasicEncoding.FloatDecoder, + R2dbcBasicEncoding.IntDecoder, + R2dbcBasicEncoding.LongDecoder, + R2dbcBasicEncoding.ShortDecoder, + R2dbcBasicEncoding.StringDecoder, + R2dbcBasicEncoding.ByteArrayDecoder, + + R2dbcTimeEncoding.LocalDateDecoder, + R2dbcTimeEncoding.LocalDateTimeDecoder, + R2dbcTimeEncoding.LocalTimeDecoder, + R2dbcTimeEncoding.InstantDecoder, + R2dbcTimeEncoding.JLocalDateDecoder, + R2dbcTimeEncoding.JLocalTimeDecoder, + R2dbcTimeEncoding.JLocalDateTimeDecoder, + R2dbcTimeEncoding.JZonedDateTimeDecoder, + R2dbcTimeEncoding.JInstantDecoder, + R2dbcTimeEncoding.JOffsetTimeDecoder, + R2dbcTimeEncoding.JOffsetDateTimeDecoder, + R2dbcTimeEncoding.JDateDecoder, + R2dbcUuidEncoding.JUuidDecoder, + + R2dbcAdditionalEncoding.BigDecimalDecoder + ) +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncoders.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncoders.kt new file mode 100644 index 0000000..f4c7dbd --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncoders.kt @@ -0,0 +1,216 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.BasicEncoding +import io.exoquery.controller.EncoderAny +import io.exoquery.controller.JavaTimeEncoding +import io.exoquery.controller.JavaUuidEncoding +import io.exoquery.controller.SqlDecoder +import io.exoquery.controller.SqlEncoder +import io.exoquery.controller.SqlJson +import io.exoquery.controller.r2dbc.R2dbcTimeEncoding.NA +import io.r2dbc.spi.Connection +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement +import kotlinx.datetime.toJavaLocalDate +import kotlinx.datetime.toJavaLocalDateTime +import kotlinx.datetime.toJavaLocalTime +import kotlinx.datetime.toJavaInstant +import kotlinx.datetime.toJavaZoneId +import kotlinx.datetime.toKotlinInstant +import kotlinx.datetime.toKotlinLocalDate +import kotlinx.datetime.toKotlinLocalDateTime +import kotlinx.datetime.toKotlinLocalTime +import java.sql.Types +import java.time.* +import java.util.* +import kotlin.reflect.KClass + +// Note: R2DBC has no java.sql.Types. We keep an Int id for compatibility but do not use it. +class R2dbcEncoderAny( + override val dataType: Int, + override val type: KClass, + override val f: (R2dbcEncodingContext, T, Int) -> Unit, +): EncoderAny( + dataType, type, + { i, stmt, _ -> + // Always use boxed reference types for nulls to satisfy R2DBC drivers (e.g., Postgres) + stmt.bindNull(i, type.javaObjectType) + }, + f +) + +object R2dbcBasicEncoding: BasicEncoding { + private const val NA = 0 + + override val BooleanEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Boolean::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val ByteEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Byte::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val CharEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Char::class) { ctx, v, i -> ctx.stmt.bind(i, v.toString()) } + override val DoubleEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Double::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val FloatEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Float::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val IntEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Int::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val LongEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Long::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val ShortEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Short::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val StringEncoder: SqlEncoder = + R2dbcEncoderAny(NA, String::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val ByteArrayEncoder: SqlEncoder = + R2dbcEncoderAny(NA, ByteArray::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + + override fun preview(index: Int, row: Row): String? = + row.get(index)?.let { it.toString() } + override fun isNull(index: Int, row: Row): Boolean = + row.get(index) == null + + override val BooleanDecoder: SqlDecoder = + R2dbcDecoderAny(Boolean::class) { ctx, i -> ctx.row.get(i, java.lang.Boolean::class.java)?.booleanValue() } + override val ByteDecoder: SqlDecoder = + R2dbcDecoderAny(Byte::class) { ctx, i -> ctx.row.get(i, java.lang.Byte::class.java)?.toByte() } + override val CharDecoder: SqlDecoder = + R2dbcDecoderAny(Char::class) { ctx, i -> ctx.row.get(i, String::class.java)?.let { it[0] } ?: Char.MIN_VALUE } + override val DoubleDecoder: SqlDecoder = + R2dbcDecoderAny(Double::class) { ctx, i -> ctx.row.get(i, java.lang.Double::class.java)?.toDouble() } + override val FloatDecoder: SqlDecoder = + R2dbcDecoderAny(Float::class) { ctx, i -> ctx.row.get(i, java.lang.Float::class.java)?.toFloat() } + override val IntDecoder: SqlDecoder = + R2dbcDecoderAny(Int::class) { ctx, i -> ctx.row.get(i, java.lang.Integer::class.java)?.toInt() } + override val LongDecoder: SqlDecoder = + R2dbcDecoderAny(Long::class) { ctx, i -> ctx.row.get(i, java.lang.Long::class.java)?.toLong() } + override val ShortDecoder: SqlDecoder = + R2dbcDecoderAny(Short::class) { ctx, i -> ctx.row.get(i, java.lang.Short::class.java)?.toShort() } + override val StringDecoder: SqlDecoder = + R2dbcDecoderAny(String::class) { ctx, i -> ctx.row.get(i, String::class.java) } + override val ByteArrayDecoder: SqlDecoder = + R2dbcDecoderAny(ByteArray::class) { ctx, i -> ctx.row.get(i, ByteArray::class.java) } +} + +private fun kotlinx.datetime.TimeZone.toJava(): TimeZone = TimeZone.getTimeZone(this.toJavaZoneId()) + +object R2dbcTimeEncoding: JavaTimeEncoding { + private const val NA = 0 + + // KMP datetime -> convert to java.time before binding + override val LocalDateEncoder: SqlEncoder = + R2dbcEncoderAny(NA, kotlinx.datetime.LocalDate::class) { ctx, v, i -> + ctx.stmt.bind(i, v.toJavaLocalDate()) + } + override val LocalDateTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, kotlinx.datetime.LocalDateTime::class) { ctx, v, i -> + ctx.stmt.bind(i, v.toJavaLocalDateTime()) + } + override val LocalTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, kotlinx.datetime.LocalTime::class) { ctx, v, i -> + ctx.stmt.bind(i, v.toJavaLocalTime()) + } + override val InstantEncoder: SqlEncoder = + R2dbcEncoderAny(NA, kotlinx.datetime.Instant::class) { ctx, v, i -> + ctx.stmt.bind(i, v.toJavaInstant()) + } + + // Java time types can be bound directly + override val JLocalDateEncoder: SqlEncoder = + R2dbcEncoderAny(NA, LocalDate::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val JLocalTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, LocalTime::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val JLocalDateTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, LocalDateTime::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val JZonedDateTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, ZonedDateTime::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val JInstantEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Instant::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val JOffsetTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, OffsetTime::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + override val JOffsetDateTimeEncoder: SqlEncoder = + R2dbcEncoderAny(NA, OffsetDateTime::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + + // java.util.Date -> bind as Instant (supported type) + override val JDateEncoder: SqlEncoder = + R2dbcEncoderAny(NA, Date::class) { ctx, v, i -> ctx.stmt.bind(i, Instant.ofEpochMilli(v.getTime())) } + + // KMP datetime decoders via java.time + override val LocalDateDecoder: SqlDecoder = + R2dbcDecoderAny(kotlinx.datetime.LocalDate::class) { ctx, i -> ctx.row.get(i, LocalDate::class.java)?.toKotlinLocalDate() } + override val LocalDateTimeDecoder: SqlDecoder = + R2dbcDecoderAny(kotlinx.datetime.LocalDateTime::class) { ctx, i -> ctx.row.get(i, LocalDateTime::class.java)?.toKotlinLocalDateTime() } + override val LocalTimeDecoder: SqlDecoder = + R2dbcDecoderAny(kotlinx.datetime.LocalTime::class) { ctx, i -> ctx.row.get(i, LocalTime::class.java)?.toKotlinLocalTime() } + override val InstantDecoder: SqlDecoder = + R2dbcDecoderAny(kotlinx.datetime.Instant::class) { ctx, i -> ctx.row.get(i, OffsetDateTime::class.java)?.toInstant()?.toKotlinInstant() } + + // Java time decoders + override val JLocalDateDecoder: SqlDecoder = + R2dbcDecoderAny(LocalDate::class) { ctx, i -> ctx.row.get(i, LocalDate::class.java) } + override val JLocalTimeDecoder: SqlDecoder = + R2dbcDecoderAny(LocalTime::class) { ctx, i -> ctx.row.get(i, LocalTime::class.java) } + override val JLocalDateTimeDecoder: SqlDecoder = + R2dbcDecoderAny(LocalDateTime::class) { ctx, i -> ctx.row.get(i, LocalDateTime::class.java) } + override val JZonedDateTimeDecoder: SqlDecoder = + R2dbcDecoderAny(ZonedDateTime::class) { ctx, i -> ctx.row.get(i, OffsetDateTime::class.java)?.toZonedDateTime() } + override val JInstantDecoder: SqlDecoder = + R2dbcDecoderAny(Instant::class) { ctx, i -> ctx.row.get(i, OffsetDateTime::class.java)?.toInstant() } + override val JOffsetTimeDecoder: SqlDecoder = + R2dbcDecoderAny(OffsetTime::class) { ctx, i -> ctx.row.get(i, OffsetTime::class.java) } + override val JOffsetDateTimeDecoder: SqlDecoder = + R2dbcDecoderAny(OffsetDateTime::class) { ctx, i -> ctx.row.get(i, OffsetDateTime::class.java) } + + // java.util.Date from Instant + override val JDateDecoder: SqlDecoder = + R2dbcDecoderAny(Date::class) { ctx, i -> ctx.row.get(i, Instant::class.java)?.let { Date.from(it) } } +} + +object R2dbcUuidEncoding: JavaUuidEncoding { + private const val NA = 0 + + override val JUuidEncoder: SqlEncoder = + R2dbcEncoderAny(NA, UUID::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + + override val JUuidDecoder: SqlDecoder = + R2dbcDecoderAny(UUID::class) { ctx, i -> ctx.row.get(i, UUID::class.java) } +} + +object R2dbcAdditionalEncoding { + private const val NA = 0 + + val BigDecimalEncoder: R2dbcEncoderAny = + R2dbcEncoderAny(NA, java.math.BigDecimal::class) { ctx, v, i -> ctx.stmt.bind(i, v) } + val BigDecimalDecoder: R2dbcDecoderAny = + R2dbcDecoderAny(java.math.BigDecimal::class) { ctx, i -> ctx.row.get(i, java.math.BigDecimal::class.java) } +} + +object R2dbcEncoders { + @Suppress("UNCHECKED_CAST") + val encoders: Set> = setOf( + R2dbcBasicEncoding.BooleanEncoder, + R2dbcBasicEncoding.ByteEncoder, + R2dbcBasicEncoding.CharEncoder, + R2dbcBasicEncoding.DoubleEncoder, + R2dbcBasicEncoding.FloatEncoder, + R2dbcBasicEncoding.IntEncoder, + R2dbcBasicEncoding.LongEncoder, + R2dbcBasicEncoding.ShortEncoder, + R2dbcBasicEncoding.StringEncoder, + R2dbcBasicEncoding.ByteArrayEncoder, + + R2dbcTimeEncoding.LocalDateEncoder, + R2dbcTimeEncoding.LocalDateTimeEncoder, + R2dbcTimeEncoding.LocalTimeEncoder, + R2dbcTimeEncoding.InstantEncoder, + R2dbcTimeEncoding.JLocalDateEncoder, + R2dbcTimeEncoding.JLocalTimeEncoder, + R2dbcTimeEncoding.JLocalDateTimeEncoder, + R2dbcTimeEncoding.JZonedDateTimeEncoder, + R2dbcTimeEncoding.JInstantEncoder, + R2dbcTimeEncoding.JOffsetTimeEncoder, + R2dbcTimeEncoding.JOffsetDateTimeEncoder, + R2dbcTimeEncoding.JDateEncoder, + R2dbcUuidEncoding.JUuidEncoder, + + R2dbcAdditionalEncoding.BigDecimalEncoder + ) +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncodingConfig.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncodingConfig.kt new file mode 100644 index 0000000..00db3ea --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncodingConfig.kt @@ -0,0 +1,65 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.EncodingConfig +import io.exoquery.controller.SqlDecoder +import io.exoquery.controller.SqlEncoder +import io.r2dbc.spi.Connection +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement +import kotlinx.datetime.TimeZone +import kotlinx.serialization.json.Json +import kotlinx.serialization.modules.EmptySerializersModule +import kotlinx.serialization.modules.SerializersModule + +/** + * Mirrors JdbcEncodingConfig: provides factory helpers and defaults that include built-in + * R2DBC encoders/decoders unless Empty() is used. + */ + data class R2dbcEncodingConfig private constructor( + override val additionalEncoders: Set>, + override val additionalDecoders: Set>, + override val json: Json, + override val module: SerializersModule, + override val timezone: TimeZone, + override val debugMode: Boolean +): EncodingConfig { + companion object { + val Default get() = + Default( + R2dbcEncoders.encoders, + R2dbcDecoders.decoders + ) + + fun Default( + additionalEncoders: Set> = setOf(), + additionalDecoders: Set> = setOf(), + json: Json = Json, + module: SerializersModule = EmptySerializersModule(), + timezone: TimeZone = TimeZone.currentSystemDefault(), + debugMode: Boolean = false + ) = R2dbcEncodingConfig( + additionalEncoders + R2dbcEncoders.encoders, + additionalDecoders + R2dbcDecoders.decoders, + json, + module, + timezone, + debugMode + ) + + operator fun invoke( + additionalEncoders: Set> = setOf(), + additionalDecoders: Set> = setOf(), + json: Json = Json, + module: SerializersModule = EmptySerializersModule(), + timezone: TimeZone = TimeZone.currentSystemDefault() + ) = Default(additionalEncoders, additionalDecoders, json, module, timezone) + + fun Empty( + additionalEncoders: Set> = setOf(), + additionalDecoders: Set> = setOf(), + json: Json = Json, + module: SerializersModule = EmptySerializersModule(), + timezone: TimeZone = TimeZone.currentSystemDefault() + ) = R2dbcEncodingConfig(additionalEncoders, additionalDecoders, json, module, timezone, debugMode = false) + } +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncodingContext.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncodingContext.kt new file mode 100644 index 0000000..6c08656 --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcEncodingContext.kt @@ -0,0 +1,10 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.DecodingContext +import io.exoquery.controller.EncodingContext +import io.r2dbc.spi.Connection +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement + +typealias R2dbcEncodingContext = EncodingContext +typealias R2dbcDecodingContext = DecodingContext diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcExecutionOptions.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcExecutionOptions.kt new file mode 100644 index 0000000..fc23d39 --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcExecutionOptions.kt @@ -0,0 +1,15 @@ +package io.exoquery.controller.r2dbc + +import com.sun.jdi.connect.spi.Connection +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement + +data class R2dbcExecutionOptions( + val sessionTimeout: Int? = null, + val fetchSize: Int? = null, + val queryTimeout: Int? = null +) { + companion object { + fun Default() = R2dbcExecutionOptions() + } +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcPostgresAdditionalEncoding.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcPostgresAdditionalEncoding.kt new file mode 100644 index 0000000..44740d3 --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcPostgresAdditionalEncoding.kt @@ -0,0 +1,12 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.SqlJson + +object R2dbcPostgresAdditionalEncoding { + private const val NA = 0 + val SqlJsonEncoder: R2dbcEncoderAny = R2dbcEncoderAny(NA, SqlJson::class) { ctx, v, i -> ctx.stmt.bind(i, io.r2dbc.postgresql.codec.Json.of(v.value)) } + val SqlJsonDecoder: R2dbcDecoderAny = R2dbcDecoderAny(SqlJson::class) { ctx, i -> SqlJson(ctx.row.get(i, io.r2dbc.postgresql.codec.Json::class.java).asString()) } + + val encoders: Set> = setOf(SqlJsonEncoder) + val decoders: Set> = setOf(SqlJsonDecoder) +} diff --git a/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcSqlEncoding.kt b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcSqlEncoding.kt new file mode 100644 index 0000000..040b391 --- /dev/null +++ b/controller-r2dbc/src/main/kotlin/io/exoquery/controller/r2dbc/R2dbcSqlEncoding.kt @@ -0,0 +1,8 @@ +package io.exoquery.controller.r2dbc + +import io.exoquery.controller.SqlEncoding +import io.r2dbc.spi.Connection +import io.r2dbc.spi.Row +import io.r2dbc.spi.Statement + +typealias R2dbcSqlEncoding = SqlEncoding diff --git a/scripts/start.sh b/scripts/start.sh index 0c2efd2..35849e7 100755 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -1,3 +1,3 @@ #!/bin/bash -docker-compose down && docker-compose build && docker-compose run --rm --service-ports setup +docker compose down && docker compose build && docker compose run --rm --service-ports setup diff --git a/settings.gradle.kts b/settings.gradle.kts index 139684f..c74a6e8 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -13,12 +13,13 @@ include("controller-core") include("controller-jdbc") include("controller-native") include("controller-android") +include("controller-r2dbc") include("terpal-sql-core") include("terpal-sql-core-testing") include("terpal-sql-jdbc") include("terpal-sql-native") include("terpal-sql-android") - +include("terpal-sql-r2dbc") rootProject.name = "terpal-sql" diff --git a/terpal-sql-android/build.gradle.kts b/terpal-sql-android/build.gradle.kts index 19b35e2..addb7ee 100644 --- a/terpal-sql-android/build.gradle.kts +++ b/terpal-sql-android/build.gradle.kts @@ -8,8 +8,8 @@ plugins { id("conventions") kotlin("multiplatform") id("com.android.library") - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" // Already on the classpath //id("org.jetbrains.kotlin.android") version "1.9.23" } @@ -52,8 +52,10 @@ kotlin { androidTarget { compilations.all { - kotlinOptions { - jvmTarget = "17" + compileTaskProvider { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_17) + } } } publishLibraryVariants("release", "debug") diff --git a/terpal-sql-core-testing/build.gradle.kts b/terpal-sql-core-testing/build.gradle.kts index 053f6b5..77e2b28 100644 --- a/terpal-sql-core-testing/build.gradle.kts +++ b/terpal-sql-core-testing/build.gradle.kts @@ -4,8 +4,8 @@ import org.jetbrains.kotlin.gradle.dsl.JvmTarget plugins { kotlin("multiplatform") - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" id("nativebuild") } diff --git a/terpal-sql-core/build.gradle.kts b/terpal-sql-core/build.gradle.kts index ba444bd..50f5e59 100644 --- a/terpal-sql-core/build.gradle.kts +++ b/terpal-sql-core/build.gradle.kts @@ -5,8 +5,8 @@ import org.jetbrains.kotlin.gradle.dsl.JvmTarget plugins { id("conventions") kotlin("multiplatform") - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" id("nativebuild") } diff --git a/terpal-sql-core/src/commonMain/kotlin/io/exoquery/sql/Statement.kt b/terpal-sql-core/src/commonMain/kotlin/io/exoquery/sql/Statement.kt index 7117ecd..d5b36a5 100644 --- a/terpal-sql-core/src/commonMain/kotlin/io/exoquery/sql/Statement.kt +++ b/terpal-sql-core/src/commonMain/kotlin/io/exoquery/sql/Statement.kt @@ -8,6 +8,8 @@ import io.exoquery.controller.TerpalSqlInternal import kotlinx.serialization.KSerializer import kotlinx.serialization.serializer + + fun Param.toStatementParam(): StatementParam = StatementParam(serializer, cls, value) diff --git a/terpal-sql-jdbc/build.gradle.kts b/terpal-sql-jdbc/build.gradle.kts index c1f8eca..7a8e997 100644 --- a/terpal-sql-jdbc/build.gradle.kts +++ b/terpal-sql-jdbc/build.gradle.kts @@ -1,11 +1,10 @@ import org.gradle.api.tasks.testing.logging.TestExceptionFormat -import org.jetbrains.kotlin.gradle.dsl.KotlinCompile plugins { id("conventions") kotlin("multiplatform") - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" } val thisVersion = version diff --git a/terpal-sql-native/build.gradle.kts b/terpal-sql-native/build.gradle.kts index 6531d4b..acdbd56 100644 --- a/terpal-sql-native/build.gradle.kts +++ b/terpal-sql-native/build.gradle.kts @@ -8,8 +8,8 @@ import org.jetbrains.kotlin.konan.target.HostManager plugins { id("conventions") kotlin("multiplatform") - id("io.exoquery.terpal-plugin") version "2.1.0-2.0.0.PL" - kotlin("plugin.serialization") version "2.1.0" + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" id("nativebuild") } diff --git a/terpal-sql-r2dbc/build.gradle.kts b/terpal-sql-r2dbc/build.gradle.kts new file mode 100644 index 0000000..76f9a13 --- /dev/null +++ b/terpal-sql-r2dbc/build.gradle.kts @@ -0,0 +1,81 @@ +import org.gradle.api.tasks.testing.logging.TestExceptionFormat + +plugins { + id("conventions") + kotlin("multiplatform") + id("io.exoquery.terpal-plugin") version "2.2.0-2.0.0.PL" + kotlin("plugin.serialization") version "2.2.0" +} + +val thisVersion = version + +tasks.withType().configureEach { + compilerOptions { + freeCompilerArgs.add("-Xcontext-receivers") + java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + } + } +} + +repositories { + mavenCentral() + mavenLocal() +} + +tasks.withType().configureEach { + useJUnitPlatform() + testLogging { + exceptionFormat = TestExceptionFormat.FULL + } +} + +kotlin { + jvmToolchain(17) + jvm { + } + + java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + } + + sourceSets { + val jvmMain by getting { + kotlin.srcDir("src/main/kotlin") + resources.srcDir("src/main/resources") + + dependencies { + api(project(":terpal-sql-core")) + api(project(":controller-r2dbc")) + + api("org.jetbrains.kotlinx:kotlinx-serialization-core:1.6.2") + api("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.2") + api("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.8.1") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.8.1") + + // R2DBC Postgres driver + api("org.postgresql:r2dbc-postgresql:1.0.5.RELEASE") + } + } + val jvmTest by getting { + kotlin.srcDir("src/test/kotlin") + resources.srcDir("src/test/resources") + + dependencies { + api(project(":controller-r2dbc")) + api(project(":terpal-sql-core-testing")) + + implementation("io.exoquery:pprint-kotlin:2.0.2") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.8.1") + implementation("io.kotest:kotest-runner-junit5:5.9.1") + + // Embedded Postgres for tests (same as JDBC module) + implementation("io.zonky.test:embedded-postgres:2.0.7") + implementation("io.zonky.test.postgres:embedded-postgres-binaries-linux-amd64:16.2.0") + implementation("org.flywaydb:flyway-core:7.15.0") + } + } + } +} diff --git a/terpal-sql-r2dbc/src/main/kotlin/io/exoquery/sql/ParamExtensions.kt b/terpal-sql-r2dbc/src/main/kotlin/io/exoquery/sql/ParamExtensions.kt new file mode 100644 index 0000000..ab16636 --- /dev/null +++ b/terpal-sql-r2dbc/src/main/kotlin/io/exoquery/sql/ParamExtensions.kt @@ -0,0 +1,25 @@ +package io.exoquery.sql + +import kotlinx.serialization.ContextualSerializer +import java.math.BigDecimal +import java.sql.Date +import java.sql.Time +import java.sql.Timestamp +import java.time.ZonedDateTime +import java.util.* +import java.time.* +import kotlinx.serialization.ExperimentalSerializationApi as SerApi + +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: LocalDate?): Param = Param(ContextualSerializer(LocalDate::class), LocalDate::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: LocalTime?): Param = Param(ContextualSerializer(LocalTime::class), LocalTime::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: LocalDateTime?): Param = Param(ContextualSerializer(LocalDateTime::class), LocalDateTime::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: Instant?): Param = Param(ContextualSerializer(Instant::class), Instant::class, value) + +@OptIn(SerApi::class) fun Param.Companion.fromUtilDate(value: java.util.Date?): Param = Param(ContextualSerializer(java.util.Date::class), java.util.Date::class, value) +@OptIn(SerApi::class) fun Param.Companion.fromSqlDate(value: java.sql.Date?): Param = Param(ContextualSerializer(java.sql.Date::class), java.sql.Date::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: java.sql.Time?): Param = Param(ContextualSerializer(java.sql.Time::class), java.sql.Time::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: java.sql.Timestamp?): Param = Param(ContextualSerializer(Timestamp::class), Timestamp::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: BigDecimal?): Param = Param(ContextualSerializer(BigDecimal::class), BigDecimal::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: ZonedDateTime?): Param = Param(ContextualSerializer(ZonedDateTime::class), ZonedDateTime::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: OffsetTime?): Param = Param(ContextualSerializer(OffsetTime::class), OffsetTime::class, value) +@OptIn(SerApi::class) operator fun Param.Companion.invoke(value: OffsetDateTime?): Param = Param(ContextualSerializer(OffsetDateTime::class), OffsetDateTime::class, value) diff --git a/terpal-sql-r2dbc/src/main/kotlin/io/exoquery/sql/Wrappers.kt b/terpal-sql-r2dbc/src/main/kotlin/io/exoquery/sql/Wrappers.kt new file mode 100644 index 0000000..76565b1 --- /dev/null +++ b/terpal-sql-r2dbc/src/main/kotlin/io/exoquery/sql/Wrappers.kt @@ -0,0 +1,30 @@ +package io.exoquery.sql + +import io.exoquery.terpal.StrictType +import java.math.BigDecimal +import java.time.* + +fun SqlInterpolator.wrap(value: BigDecimal?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: ByteArray?): Param = Param(value) + +// It's a bit crazy but all the java.sql.* types are a subtype of this +// so we want it to only match a strict java.util.Date parameter +@StrictType +fun SqlInterpolator.wrap(value: java.util.Date?): Param = Param.fromUtilDate(value) + +fun SqlInterpolator.wrap(value: java.sql.Date?): Param = Param.fromSqlDate(value) +fun SqlInterpolator.wrap(value: java.sql.Time?): Param = Param(value) +fun SqlInterpolator.wrap(value: java.sql.Timestamp?): Param = Param(value) + +fun SqlInterpolator.wrap(value: kotlinx.datetime.LocalDate?): Param = Param(value) +fun SqlInterpolator.wrap(value: kotlinx.datetime.LocalTime?): Param = Param(value) +fun SqlInterpolator.wrap(value: kotlinx.datetime.LocalDateTime?): Param = Param(value) +fun SqlInterpolator.wrap(value: kotlinx.datetime.Instant?): Param = Param(value) + +fun SqlInterpolator.wrap(value: LocalDate?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: LocalTime?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: LocalDateTime?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: ZonedDateTime?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: Instant?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: OffsetTime?): Param = Param.contextual(value) +fun SqlInterpolator.wrap(value: OffsetDateTime?): Param = Param.contextual(value) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/BatchActionSpecData.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/BatchActionSpecData.kt new file mode 100644 index 0000000..23de474 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/BatchActionSpecData.kt @@ -0,0 +1,44 @@ +package io.exoquery.r2dbc + +import io.exoquery.sql.Sql +import io.exoquery.sql.SqlBatch + +private fun id(t: T) = t + +object Ex1_BatchInsertNormal { + val products = makeProducts(22) + val op = + SqlBatch { p: Product -> "INSERT INTO Product (id, description, sku) VALUES (${p.id}, ${p.description}, ${p.sku})" } + .values(products.asSequence()).action() + val get = Sql("SELECT id, description, sku FROM Product").queryOf() + val result = products +} + +object Ex2_BatchInsertMixed { + val products = makeProducts(20) + val op = + SqlBatch { p: Product -> "INSERT INTO Product (id, description, sku) VALUES (${p.id}, ${id("BlahBlah")}, ${p.sku})" } + .values(products.asSequence()).action() + val get = Sql("SELECT id, description, sku FROM Product").queryOf() + val result = products.map { it.copy(description = "BlahBlah") } +} + +object Ex3_BatchReturnIds { + val products = makeProducts(20) + val op = + SqlBatch { p: Product -> "INSERT INTO Product (description, sku) VALUES (${p.description}, ${p.sku}) RETURNING id" } + .values(products.asSequence()).actionReturning() + val get = Sql("SELECT id, description, sku FROM Product").queryOf() + val opResult = (1..20).toList() + val result = products.mapIndexed { i, p -> p.copy(id = i + 1) } +} + +object Ex4_BatchReturnRecord { + val products = makeProducts(20) + val op = + SqlBatch { p: Product -> "INSERT INTO Product (description, sku) VALUES (${p.description}, ${p.sku}) RETURNING id, description, sku" } + .values(products.asSequence()).actionReturning() + val get = Sql("SELECT id, description, sku FROM Product").queryOf() + val opResult = products.mapIndexed { i, p -> p.copy(id = i + 1) } + val result = opResult +} diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/KotestProjectConfig.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/KotestProjectConfig.kt new file mode 100644 index 0000000..1694685 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/KotestProjectConfig.kt @@ -0,0 +1,15 @@ +package io.exoquery.r2dbc + +import io.kotest.core.config.AbstractProjectConfig + +object KotestProjectConfig : AbstractProjectConfig() { + override suspend fun beforeProject() { + // Ensure EmbeddedPostgres is started before any specs run + TestDatabasesR2dbc.embeddedPostgres + } + + override suspend fun afterProject() { + // Ensure EmbeddedPostgres is closed after all specs complete + try { TestDatabasesR2dbc.embeddedPostgres.close() } catch (_: Throwable) {} + } +} diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/Model.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/Model.kt new file mode 100644 index 0000000..7a2a8fa --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/Model.kt @@ -0,0 +1,8 @@ +package io.exoquery.r2dbc + +import kotlinx.serialization.Serializable + +@Serializable +data class Product(val id: Int, val description: String, val sku: Long) + +fun makeProducts(num: Int): List = (1..num).map { Product(it, "Product-$it", it.toLong()) } diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/TestDatabasesR2dbc.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/TestDatabasesR2dbc.kt new file mode 100644 index 0000000..e27d29f --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/TestDatabasesR2dbc.kt @@ -0,0 +1,42 @@ +package io.exoquery.r2dbc + +import io.r2dbc.spi.ConnectionFactory +import io.r2dbc.spi.ConnectionFactories +import io.r2dbc.spi.ConnectionFactoryOptions +import io.zonky.test.db.postgres.embedded.EmbeddedPostgres + +object TestDatabasesR2dbc { + val embeddedPostgres: EmbeddedPostgres by lazy { + val started = EmbeddedPostgres.builder().start() + val postgresScriptsPath = "/db/postgres-schema.sql" + val resource = this::class.java.getResource(postgresScriptsPath) + if (resource == null) throw NullPointerException("The postgres script path `$postgresScriptsPath` was not found") + val postgresScript = resource.readText() + started.postgresDatabase.connection.use { conn -> + val commands = postgresScript.split(';') + commands.filter { it.isNotBlank() }.forEach { cmd -> + conn.prepareStatement(cmd).execute() + } + } + started + } + + val postgres: ConnectionFactory by lazy { + val ep = embeddedPostgres + val host = "localhost" + val port = ep.port + val db = "postgres" + val user = "postgres" + ConnectionFactories.get( + ConnectionFactoryOptions.builder() + .option(ConnectionFactoryOptions.DRIVER, "postgresql") + .option(ConnectionFactoryOptions.HOST, host) + .option(ConnectionFactoryOptions.PORT, port) + .option(ConnectionFactoryOptions.DATABASE, db) + .option(ConnectionFactoryOptions.USER, user) + // Provide password if needed; EmbeddedPostgres default often doesn't require it + // .option(io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD, "password") + .build() + ) + } +} diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/JavaEntities.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/JavaEntities.kt new file mode 100644 index 0000000..237f582 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/JavaEntities.kt @@ -0,0 +1,61 @@ +package io.exoquery.r2dbc.encodingdata + +import io.exoquery.controller.ControllerAction +import io.exoquery.sql.Param +import io.exoquery.sql.Sql +import io.exoquery.sql.encodingdata.shouldBeEqual +import io.exoquery.sql.encodingdata.shouldBeEqualNullable +import io.kotest.matchers.bigdecimal.shouldBeEqualIgnoringScale +import kotlinx.serialization.Contextual +import kotlinx.serialization.Serializable +import java.math.BigDecimal +import java.sql.Date +import java.time.LocalDateTime +import java.time.ZoneOffset +import java.util.* + +@Serializable +data class JavaTestEntity( + @Contextual val bigDecimalMan: BigDecimal, + @Contextual val javaUtilDateMan: java.util.Date, + @Contextual val uuidMan: UUID, + @Contextual val bigDecimalOpt: BigDecimal?, + @Contextual val javaUtilDateOpt: java.util.Date?, + @Contextual val uuidOpt: UUID? +) { + companion object { + val regular = + JavaTestEntity( + BigDecimal("1.1"), + Date.from(LocalDateTime.of(2013, 11, 23, 0, 0, 0, 0).toInstant(ZoneOffset.UTC)), + UUID.randomUUID(), + BigDecimal("1.1"), + Date.from(LocalDateTime.of(2013, 11, 23, 0, 0, 0, 0).toInstant(ZoneOffset.UTC)), + UUID.randomUUID() + ) + + val empty = + JavaTestEntity( + BigDecimal.ZERO, + Date(0), + UUID(0, 0), + null, + null, + null + ) + } +} + +fun insert(e: JavaTestEntity): ControllerAction { + fun wrap(value: UUID?): Param = Param.ctx(value) + return Sql("INSERT INTO JavaTestEntity VALUES (${e.bigDecimalMan}, ${e.javaUtilDateMan}, ${wrap(e.uuidMan)}, ${e.bigDecimalOpt}, ${e.javaUtilDateOpt}, ${wrap(e.uuidOpt)})").action() +} + +fun verify(e: JavaTestEntity, expected: JavaTestEntity) { + e.bigDecimalMan shouldBeEqualIgnoringScale expected.bigDecimalMan + e.javaUtilDateMan shouldBeEqual expected.javaUtilDateMan + e.uuidMan shouldBeEqual expected.uuidMan + e.bigDecimalOpt shouldBeEqualIgnoringScaleNullable expected.bigDecimalOpt + e.javaUtilDateOpt shouldBeEqualNullable expected.javaUtilDateOpt + e.uuidOpt shouldBeEqualNullable expected.uuidOpt +} diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/MiscOpsR2dbc.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/MiscOpsR2dbc.kt new file mode 100644 index 0000000..0fe95c4 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/MiscOpsR2dbc.kt @@ -0,0 +1,19 @@ +package io.exoquery.r2dbc.encodingdata + +import io.exoquery.controller.r2dbc.R2dbcEncoderAny +import io.exoquery.controller.r2dbc.R2dbcEncodingConfig +import io.exoquery.sql.encodingdata.SerializeableTestType +import io.kotest.matchers.bigdecimal.shouldBeEqualIgnoringScale +import org.junit.jupiter.api.Assertions.assertEquals +import java.math.BigDecimal + +val encodingConfig = R2dbcEncodingConfig( + setOf( + R2dbcEncoderAny(0, SerializeableTestType::class) { ctx, v, i -> ctx.stmt.bind(i, v.value) } + ) +) + +public infix fun BigDecimal?.shouldBeEqualIgnoringScaleNullable(expected: BigDecimal?) = + if (this == null && expected == null) Unit + else if (this == null || expected == null) assertEquals(this, expected) + else this.shouldBeEqualIgnoringScale(expected) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/TimeEntities.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/TimeEntities.kt new file mode 100644 index 0000000..129dfd1 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/encodingdata/TimeEntities.kt @@ -0,0 +1,86 @@ +package io.exoquery.r2dbc.encodingdata + +import io.exoquery.controller.ControllerAction +import io.exoquery.sql.Sql +import kotlinx.serialization.Contextual +import kotlinx.serialization.Serializable +import java.time.* + +@Serializable +data class SimpleTimeEntity( + // Remove java.sql.* types. They are mostly deprecated and not recommended for use in new code. + //@Contextual val sqlDate: Date, // DATE + //@Contextual val sqlTime: Time, // TIME + //@Contextual val sqlTimestamp: Timestamp, // DATETIME + @Contextual val timeLocalDate: LocalDate, // DATE + @Contextual val timeLocalTime: LocalTime, // TIME + @Contextual val timeLocalDateTime: LocalDateTime, // DATETIME + @Contextual val timeZonedDateTime: ZonedDateTime, // DATETIMEOFFSET + @Contextual val timeInstant: Instant, // DATETIMEOFFSET + @Contextual val timeOffsetTime: OffsetTime, // TIME + @Contextual val timeOffsetDateTime: OffsetDateTime // DATETIMEOFFSET +) { + override fun equals(other: Any?): Boolean = + when (other) { + is SimpleTimeEntity -> + this.timeLocalDate == other.timeLocalDate && + this.timeLocalTime == other.timeLocalTime && + this.timeLocalDateTime == other.timeLocalDateTime && + this.timeZonedDateTime.isEqual(other.timeZonedDateTime) && + this.timeInstant == other.timeInstant && + this.timeOffsetTime.isEqual(other.timeOffsetTime) && + this.timeOffsetDateTime.isEqual(other.timeOffsetDateTime) + else -> false + } + + data class TimeEntityInput(val year: Int, val month: Int, val day: Int, val hour: Int, val minute: Int, val second: Int, val nano: Int) { + fun toLocalDate() = LocalDateTime.of(year, month, day, hour, minute, second, nano) + companion object { + val default = TimeEntityInput(2022, 1, 2, 3, 4, 6, 0) + } + } + + companion object { + fun make(zoneIdRaw: ZoneId, timeEntity: TimeEntityInput = TimeEntityInput.default) = run { + val zoneId = zoneIdRaw.normalized() + val nowInstant = timeEntity.toLocalDate().atZone(zoneId).toInstant() + val nowDateTime = LocalDateTime.ofInstant(nowInstant, zoneId) + val nowDate = nowDateTime.toLocalDate() + val nowTime = nowDateTime.toLocalTime() + val nowZoned = ZonedDateTime.of(nowDateTime, zoneId) + SimpleTimeEntity( + nowDate, + nowTime, + nowDateTime, + nowZoned, + nowInstant, + OffsetTime.ofInstant(nowInstant, zoneId), + OffsetDateTime.ofInstant(nowInstant, zoneId) + ) + } + } +} + +fun insert(e: SimpleTimeEntity): ControllerAction { + return Sql( + """ + INSERT INTO TimeEntity ( + timeLocalDate, + timeLocalTime, + timeLocalDateTime, + timeZonedDateTime, + timeInstant, + timeOffsetTime, + timeOffsetDateTime + ) VALUES ( + ${e.timeLocalDate}, + ${e.timeLocalTime}, + ${e.timeLocalDateTime}, + ${e.timeZonedDateTime}, + ${e.timeInstant}, + ${e.timeOffsetTime}, + ${e.timeOffsetDateTime} + ) + """ + ).action() +} diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BasicActionSpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BasicActionSpec.kt new file mode 100644 index 0000000..4b6b9d3 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BasicActionSpec.kt @@ -0,0 +1,64 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.controller.runOn +import io.exoquery.controller.runActions +import io.exoquery.sql.Sql +import io.kotest.core.spec.style.FreeSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.r2dbc.TestDatabasesR2dbc + +class BasicActionSpec : FreeSpec({ + // Start EmbeddedPostgres and build an R2DBC ConnectionFactory from its port + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcController(connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeEach { + runActions( + """ + TRUNCATE TABLE Person RESTART IDENTITY CASCADE; + TRUNCATE TABLE Address RESTART IDENTITY CASCADE; + """.trimIndent() + ) + } + + @Serializable + data class Person(val id: Int, val firstName: String, val lastName: String, val age: Int) + + val joe = Person(1, "Joe", "Bloggs", 111) + val jim = Person(2, "Jim", "Roogs", 222) + + "Basic Insert" { + Sql("INSERT INTO Person (id, firstName, lastName, age) VALUES (${joe.id}, ${joe.firstName}, ${joe.lastName}, ${joe.age})").action().runOn(ctx) + Sql("INSERT INTO Person (id, firstName, lastName, age) VALUES (${jim.id}, ${jim.firstName}, ${jim.lastName}, ${jim.age})").action().runOn(ctx) + Sql("SELECT id, firstName, lastName, age FROM Person").queryOf().runOn(ctx) shouldBe listOf(joe, jim) + } + + "Insert Returning" { + val id1 = Sql("INSERT INTO Person (firstName, lastName, age) VALUES (${joe.firstName}, ${joe.lastName}, ${joe.age}) RETURNING id").actionReturning().runOn(ctx) + val id2 = Sql("INSERT INTO Person (firstName, lastName, age) VALUES (${jim.firstName}, ${jim.lastName}, ${jim.age}) RETURNING id").actionReturning().runOn(ctx) + id1 shouldBe 1 + id2 shouldBe 2 + Sql("SELECT id, firstName, lastName, age FROM Person").queryOf().runOn(ctx) shouldBe listOf(joe, jim) + } + + "Insert Returning Record" { + val person1 = Sql("INSERT INTO Person (firstName, lastName, age) VALUES (${joe.firstName}, ${joe.lastName}, ${joe.age}) RETURNING id, firstName, lastName, age").actionReturning().runOn(ctx) + val person2 = Sql("INSERT INTO Person (firstName, lastName, age) VALUES (${jim.firstName}, ${jim.lastName}, ${jim.age}) RETURNING id, firstName, lastName, age").actionReturning().runOn(ctx) + person1 shouldBe joe + person2 shouldBe jim + Sql("SELECT id, firstName, lastName, age FROM Person").queryOf().runOn(ctx) shouldBe listOf(joe, jim) + } + + "Insert Returning Ids" { + val id1 = Sql("INSERT INTO Person (firstName, lastName, age) VALUES (${joe.firstName}, ${joe.lastName}, ${joe.age})").actionReturningId("id").runOn(ctx) + val id2 = Sql("INSERT INTO Person (firstName, lastName, age) VALUES (${jim.firstName}, ${jim.lastName}, ${jim.age})").actionReturningId("id").runOn(ctx) + id1 shouldBe 1 + id2 shouldBe 2 + Sql("SELECT id, firstName, lastName, age FROM Person").queryOf().runOn(ctx) shouldBe listOf(joe, jim) + } +}) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BasicQuerySpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BasicQuerySpec.kt new file mode 100644 index 0000000..007e1d1 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BasicQuerySpec.kt @@ -0,0 +1,140 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.controller.runOn +import io.exoquery.controller.runActions +import io.exoquery.sql.Sql +import io.kotest.core.spec.style.FreeSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.r2dbc.TestDatabasesR2dbc + +class BasicQuerySpec : FreeSpec({ + + // Start EmbeddedPostgres and build an R2DBC ConnectionFactory from its port + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcController(connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeSpec { + runActions( + """ + DELETE FROM Person; + DELETE FROM Address; + INSERT INTO Person (id, firstName, lastName, age) VALUES (1, 'Joe', 'Bloggs', 111); + INSERT INTO Person (id, firstName, lastName, age) VALUES (2, 'Jim', 'Roogs', 222); + INSERT INTO Address (ownerId, street, zip) VALUES (1, '123 Main St', '12345'); + """.trimIndent() + ) + } + + + "SELECT Person - simple" { + @Serializable + data class Person(val id: Int, val firstName: String, val lastName: String, val age: Int) + Sql("SELECT id, firstName, lastName, age FROM Person").queryOf().runOn(ctx) shouldBe listOf( + Person(1, "Joe", "Bloggs", 111), + Person(2, "Jim", "Roogs", 222) + ) + } + + "joins" - { + @Serializable + data class Person(val id: Int, val firstName: String, val lastName: String, val age: Int) + @Serializable + data class Address(val ownerId: Int, val street: String, val zip: String) + + "SELECT Person, Address - join" { + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.ownerId, a.street, a.zip FROM Person p JOIN Address a ON p.id = a.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Person(1, "Joe", "Bloggs", 111) to Address(1, "123 Main St", "12345") + ) + } + + "SELECT Person, Address - leftJoin + null" { + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.ownerId, a.street, a.zip FROM Person p LEFT JOIN Address a ON p.id = a.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Person(1, "Joe", "Bloggs", 111) to Address(1, "123 Main St", "12345"), + Person(2, "Jim", "Roogs", 222) to null + ) + } + + "SELECT Person, Address - leftJoin + null (Triple(NN,null,null))" { + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.ownerId, a.street, a.zip, aa.ownerId, aa.street, aa.zip FROM Person p LEFT JOIN Address a ON p.id = a.ownerId LEFT JOIN Address aa ON p.id = aa.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Triple(Person(1, "Joe", "Bloggs", 111), Address(1, "123 Main St", "12345"), Address(1, "123 Main St", "12345")), + Triple(Person(2, "Jim", "Roogs", 222), null, null) + ) + } + + // This is a test for the RowEncoder to advance number of null elements (in the child decoder) that are needed when all rows are null + "SELECT Person, Address - leftJoin + null (Triple(NN,null,NN))" { + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.ownerId, a.street, a.zip, aa.ownerId, aa.street, aa.zip FROM Person p LEFT JOIN Address a ON p.id = a.ownerId LEFT JOIN Address aa ON 1 = aa.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Triple(Person(1, "Joe", "Bloggs", 111), Address(1, "123 Main St", "12345"), Address(1, "123 Main St", "12345")), + Triple(Person(2, "Jim", "Roogs", 222), null, Address(1, "123 Main St", "12345")) + ) + } + + @Serializable + data class CustomRow1(val Person: Person, val Address: Address) + @Serializable + data class CustomRow2(val Person: Person, val Address: Address?) + + "SELECT Person, Address - join - custom row" { + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.ownerId, a.street, a.zip FROM Person p JOIN Address a ON p.id = a.ownerId").queryOf().runOn(ctx) shouldBe listOf( + CustomRow1(Person(1, "Joe", "Bloggs", 111), Address(1, "123 Main St", "12345")) + ) + } + + "SELECT Person, Address - leftJoin + null - custom row" { + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.ownerId, a.street, a.zip FROM Person p LEFT JOIN Address a ON p.id = a.ownerId").queryOf().runOn(ctx) shouldBe listOf( + CustomRow2(Person(1, "Joe", "Bloggs", 111), Address(1, "123 Main St", "12345")), + CustomRow2(Person(2, "Jim", "Roogs", 222), null) + ) + } + } + + "joins + null complex" - { + @Serializable + data class Person(val id: Int, val firstName: String?, val lastName: String, val age: Int) + @Serializable + data class Address(val ownerId: Int?, val street: String, val zip: String) + + "SELECT Person, Address - join" { + Sql("SELECT p.id, null as firstName, p.lastName, p.age, null as ownerId, a.street, a.zip FROM Person p JOIN Address a ON p.id = a.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Person(1, null, "Bloggs", 111) to Address(null, "123 Main St", "12345") + ) + } + + "SELECT Person, Address - leftJoin + null" { + Sql("SELECT p.id, null as firstName, p.lastName, p.age, null as ownerId, a.street, a.zip FROM Person p LEFT JOIN Address a ON p.id = a.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Person(1, null, "Bloggs", 111) to Address(null, "123 Main St", "12345"), + Person(2, null, "Roogs", 222) to null + ) + } + } + + "SELECT Person - nested" { + @Serializable + data class Name(val firstName: String, val lastName: String) + @Serializable + data class Person(val id: Int, val name: Name, val age: Int) + + Sql("SELECT id, firstName, lastName, age FROM Person").queryOf().runOn(ctx) shouldBe listOf( + Person(1, Name("Joe", "Bloggs"), 111), + Person(2, Name("Jim", "Roogs"), 222) + ) + } + + "SELECT Person - nested with join" { + @Serializable + data class Name(val firstName: String, val lastName: String) + @Serializable + data class Person(val id: Int, val name: Name, val age: Int) + @Serializable + data class Address(val street: String, val zip: String) + + Sql("SELECT p.id, p.firstName, p.lastName, p.age, a.street, a.zip FROM Person p JOIN Address a ON p.id = a.ownerId").queryOf>().runOn(ctx) shouldBe listOf( + Person(1, Name("Joe", "Bloggs"), 111) to Address("123 Main St", "12345") + ) + } +}) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BatchValuesSpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BatchValuesSpec.kt new file mode 100644 index 0000000..650457b --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/BatchValuesSpec.kt @@ -0,0 +1,44 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.controller.runActions +import io.exoquery.controller.runOn +import io.exoquery.r2dbc.Ex1_BatchInsertNormal +import io.exoquery.r2dbc.Ex2_BatchInsertMixed +import io.exoquery.r2dbc.Ex3_BatchReturnIds +import io.exoquery.r2dbc.Ex4_BatchReturnRecord +import io.exoquery.r2dbc.TestDatabasesR2dbc +import io.kotest.core.spec.style.FreeSpec +import io.kotest.matchers.shouldBe + +class BatchValuesSpec: FreeSpec ({ + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcController(connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeEach { + runActions("TRUNCATE TABLE Product RESTART IDENTITY CASCADE") + } + + "Ex 1 - Batch Insert Normal" { + Ex1_BatchInsertNormal.op.runOn(ctx) + Ex1_BatchInsertNormal.get.runOn(ctx) shouldBe Ex1_BatchInsertNormal.result + } + + "Ex 2 - Batch Insert Mixed" { + Ex2_BatchInsertMixed.op.runOn(ctx) + Ex2_BatchInsertMixed.get.runOn(ctx) shouldBe Ex2_BatchInsertMixed.result + } + + "Ex 3 - Batch Return Ids" { + Ex3_BatchReturnIds.op.runOn(ctx) shouldBe Ex3_BatchReturnIds.opResult + Ex3_BatchReturnIds.get.runOn(ctx) shouldBe Ex3_BatchReturnIds.result + } + + "Ex 4 - Batch Return Record" { + Ex4_BatchReturnRecord.op.runOn(ctx) shouldBe Ex4_BatchReturnRecord.opResult + Ex4_BatchReturnRecord.get.runOn(ctx) shouldBe Ex4_BatchReturnRecord.result + } +}) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/EncodingSpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/EncodingSpec.kt new file mode 100644 index 0000000..08ad46d --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/EncodingSpec.kt @@ -0,0 +1,154 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.sql.encodingdata.* +import io.exoquery.sql.Sql +import io.exoquery.controller.runOn +import io.exoquery.controller.runActions +import io.kotest.core.spec.style.FreeSpec +import java.time.ZoneId +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.r2dbc.TestDatabasesR2dbc +import io.exoquery.r2dbc.encodingdata.JavaTestEntity +import io.exoquery.r2dbc.encodingdata.SimpleTimeEntity +import io.exoquery.r2dbc.encodingdata.encodingConfig +import io.exoquery.r2dbc.encodingdata.insert +import io.exoquery.r2dbc.encodingdata.verify + +class EncodingSpec: FreeSpec({ + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcController(encodingConfig = encodingConfig, connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeEach { + // The main table used across many tests + runActions("DELETE FROM EncodingTestEntity") + } + + "encodes and decodes nullables - not nulls" { + insert(EncodingTestEntity.regular).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res.first(), EncodingTestEntity.regular) + } + + "encodes and decodes custom impls nullables - not nulls" { + insert(EncodingTestEntityImp.regular).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res.first(), EncodingTestEntityImp.regular) + } + + "encodes and decodes custom impls nullables - nulls" { + insert(EncodingTestEntityImp.empty).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res.first(), EncodingTestEntityImp.empty) + } + + "encodes and decodes custom value-classes nullables - not nulls" { + insert(EncodingTestEntityVal.regular).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res.first(), EncodingTestEntityVal.regular) + } + + "encodes and decodes custom value-classes nullables - nulls" { + insert(EncodingTestEntityVal.empty).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res.first(), EncodingTestEntityVal.empty) + } + + "encodes and decodes batch" { + insertBatch(listOf(EncodingTestEntity.regular, EncodingTestEntity.regular)).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res[0], EncodingTestEntity.regular) + verify(res[1], EncodingTestEntity.regular) + } + + "encodes and decodes nullables - nulls" { + insert(EncodingTestEntity.empty).runOn(ctx) + val res = Sql("SELECT * FROM EncodingTestEntity").queryOf().runOn(ctx) + verify(res.first(), EncodingTestEntity.empty) + } + + "Encode/Decode Additional Java Types - regular" { + runActions("DELETE FROM JavaTestEntity") + insert(JavaTestEntity.Companion.regular).runOn(ctx) + val actual = Sql("SELECT * FROM JavaTestEntity").queryOf().runOn(ctx).first() + verify(actual, JavaTestEntity.Companion.regular) + } + + "Encode/Decode Additional Java Types - empty" { + runActions("DELETE FROM JavaTestEntity") + insert(JavaTestEntity.Companion.empty).runOn(ctx) + val actual = Sql("SELECT * FROM JavaTestEntity").queryOf().runOn(ctx).first() + verify(actual, JavaTestEntity.Companion.empty) + } + + "Encode/Decode KMP Types" { + runActions("DELETE FROM KmpTestEntity") + insert(KmpTestEntity.regular).runOn(ctx) + val actual = Sql("SELECT * FROM KmpTestEntity").queryOf().runOn(ctx).first() + verify(actual, KmpTestEntity.regular) + } + + "Encode/Decode Other Time Types" { + runActions("DELETE FROM TimeEntity") + val zid = ZoneId.systemDefault() + val timeEntity = SimpleTimeEntity.Companion.make(zid) + insert(timeEntity).runOn(ctx) + val actual = Sql(""" + SELECT + timeLocalDate, + timeLocalTime, + timeLocalDateTime, + timeZonedDateTime, + timeInstant, + timeOffsetTime, + timeOffsetDateTime + FROM TimeEntity + """).queryOf().runOn(ctx).first() + assert(timeEntity == actual) + } + + "Encode/Decode Other Time Types ordering" { + runActions("DELETE FROM TimeEntity") + + val zid = ZoneId.systemDefault() + val timeEntityA = SimpleTimeEntity.make(zid, SimpleTimeEntity.TimeEntityInput(2022, 1, 1, 1, 1, 1, 0)) + val timeEntityB = SimpleTimeEntity.make(zid, SimpleTimeEntity.TimeEntityInput(2022, 2, 2, 2, 2, 2, 0)) + + insert(timeEntityA).runOn(ctx) + insert(timeEntityB).runOn(ctx) + + assert(timeEntityB.timeLocalDate > timeEntityA.timeLocalDate) + assert(timeEntityB.timeLocalTime > timeEntityA.timeLocalTime) + assert(timeEntityB.timeLocalDateTime > timeEntityA.timeLocalDateTime) + assert(timeEntityB.timeZonedDateTime > timeEntityA.timeZonedDateTime) + assert(timeEntityB.timeInstant > timeEntityA.timeInstant) + assert(timeEntityB.timeOffsetTime > timeEntityA.timeOffsetTime) + assert(timeEntityB.timeOffsetDateTime > timeEntityA.timeOffsetDateTime) + + val actual = + Sql(""" + SELECT + timeLocalDate, + timeLocalTime, + timeLocalDateTime, + timeZonedDateTime, + timeInstant, + timeOffsetTime, + timeOffsetDateTime + FROM TimeEntity + WHERE + timeLocalDate > ${timeEntityA.timeLocalDate} + AND timeLocalTime > ${timeEntityA.timeLocalTime} + AND timeLocalDateTime > ${timeEntityA.timeLocalDateTime} + AND timeZonedDateTime > ${timeEntityA.timeZonedDateTime} + AND timeInstant > ${timeEntityA.timeInstant} + AND timeOffsetTime > ${timeEntityA.timeOffsetTime} + AND timeOffsetDateTime > ${timeEntityA.timeOffsetDateTime} + """ + ).queryOf().runOn(ctx).first() + + assert(actual == timeEntityB) + } +}) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/InQuerySpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/InQuerySpec.kt new file mode 100644 index 0000000..c21e80d --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/InQuerySpec.kt @@ -0,0 +1,64 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.controller.runActions +import io.exoquery.controller.runOn +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.r2dbc.TestDatabasesR2dbc +import io.exoquery.sql.Params +import io.exoquery.sql.Sql +import io.kotest.core.spec.style.FreeSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable + +class InQuerySpec : FreeSpec({ + // Start EmbeddedPostgres and build an R2DBC ConnectionFactory from its port + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcController(connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeSpec { + runActions( + """ + DELETE FROM Person; + DELETE FROM Address; + INSERT INTO Person (id, firstName, lastName, age) VALUES (1, 'Joe', 'Bloggs', 111); + INSERT INTO Person (id, firstName, lastName, age) VALUES (2, 'Jim', 'Roogs', 222); + INSERT INTO Person (id, firstName, lastName, age) VALUES (3, 'Jill', 'Doogs', 222); + INSERT INTO Address (ownerId, street, zip) VALUES (1, '123 Main St', '12345'); + """.trimIndent() + ) + } + + @Serializable + data class Person(val id: Int, val firstName: String, val lastName: String, val age: Int) + + "Person IN (names) - simple" { + val sql = Sql("SELECT id, firstName, lastName, age FROM Person WHERE firstName IN ${Params("Joe", "Jim")}").queryOf() + sql.sql shouldBe "SELECT id, firstName, lastName, age FROM Person WHERE firstName IN (?, ?)" + sql.runOn(ctx) shouldBe listOf( + Person(1, "Joe", "Bloggs", 111), + Person(2, "Jim", "Roogs", 222) + ) + } + + "Person IN (names) - single" { + val sql = Sql("SELECT id, firstName, lastName, age FROM Person WHERE firstName IN ${Params("Joe")}").queryOf() + sql.sql shouldBe "SELECT id, firstName, lastName, age FROM Person WHERE firstName IN (?)" + sql.runOn(ctx) shouldBe listOf( + Person(1, "Joe", "Bloggs", 111) + ) + } + + "Person IN (names) - empty" { + val sql = Sql("SELECT id, firstName, lastName, age FROM Person WHERE firstName IN ${Params.empty()}").queryOf() + sql.sql shouldBe "SELECT id, firstName, lastName, age FROM Person WHERE firstName IN (null)" + sql.runOn(ctx) shouldBe listOf() + } + + "Person IN (names) - empty list" { + val names: List = emptyList() + Sql("SELECT id, firstName, lastName, age FROM Person WHERE firstName IN ${Params.list(names)}").queryOf().runOn(ctx) shouldBe listOf() + } +}) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/InjectionSpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/InjectionSpec.kt new file mode 100644 index 0000000..f52b0c4 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/InjectionSpec.kt @@ -0,0 +1,38 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.controller.runActions +import io.exoquery.controller.runOn +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.r2dbc.TestDatabasesR2dbc +import io.exoquery.sql.Param +import io.exoquery.sql.Sql +import io.exoquery.sql.encodingdata.EncodingTestEntity +import io.exoquery.sql.encodingdata.insert +import io.kotest.core.spec.style.FreeSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable + +class InjectionSpec: FreeSpec({ + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcController(connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeEach { + runActions("DELETE FROM Person") + runActions("INSERT INTO Person (id, firstName, lastName, age) VALUES (1, 'Joe', 'Blogs', 123)") + } + + @Serializable + data class Person(val id: Int, val firstName: String, val lastName: String, val age: Int) + + "escapes column meant to be an injection attack" { + insert(EncodingTestEntity.regular).runOn(ctx) + val name = "'Joe'; DROP TABLE Person;" + Sql("SELECT * FROM Person WHERE firstName = ${Param.withSer(name)}").queryOf().runOn(ctx) shouldBe listOf() + + // verify table still exists and is intact + Sql("SELECT * FROM Person").queryOf().runOn(ctx) shouldBe listOf(Person(1, "Joe", "Blogs", 123)) + } +}) diff --git a/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/JsonSpec.kt b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/JsonSpec.kt new file mode 100644 index 0000000..c32bec1 --- /dev/null +++ b/terpal-sql-r2dbc/src/test/kotlin/io/exoquery/r2dbc/postgres/JsonSpec.kt @@ -0,0 +1,165 @@ +package io.exoquery.r2dbc.postgres + +import io.exoquery.controller.JsonValue +import io.exoquery.controller.SqlJsonValue +import io.exoquery.controller.r2dbc.R2dbcController +import io.exoquery.controller.r2dbc.R2dbcControllers +import io.exoquery.controller.runActions +import io.exoquery.controller.runOn +import io.exoquery.r2dbc.TestDatabasesR2dbc +import io.exoquery.sql.Param +import io.exoquery.sql.Sql +import io.kotest.core.spec.style.FreeSpec +import io.kotest.matchers.shouldBe +import kotlinx.serialization.Serializable + +typealias MyPersonJson = @Serializable @SqlJsonValue JsonSpecData.A.MyPerson + +object JsonSpecData { + object A { + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class Example(val id: Int, val value: MyPersonJson) + } +} + +class JsonSpec: FreeSpec({ + + val cf = TestDatabasesR2dbc.postgres + val ctx: R2dbcController by lazy { R2dbcControllers.Postgres(connectionFactory = cf) } + + suspend fun runActions(actions: String) = ctx.runActions(actions) + + beforeEach { + runActions("DELETE FROM JsonbExample") + runActions("DELETE FROM JsonbExample2") + runActions("DELETE FROM JsonbExample3") + runActions("DELETE FROM JsonExample") + } + + "SqlJsonValue annotation works on" - { + "inner data class" - { + @SqlJsonValue + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class Example(val id: Int, val value: MyPerson) + + val je = Example(1, MyPerson("Alice", 30)) + + "should encode in jsonb and decode" { + Sql("INSERT INTO JsonbExample (id, value) VALUES (1, ${Param.withSer(je.value)})").action().runOn(ctx) + Sql("SELECT id, value FROM JsonbExample").queryOf().runOn(ctx) shouldBe listOf(je) + } + + "should encode in jsonb and decode as atom" { + Sql("INSERT INTO JsonbExample (id, value) VALUES (1, ${Param.withSer(je.value)})").action().runOn(ctx) + Sql("SELECT value FROM JsonbExample").queryOf().runOn(ctx) shouldBe listOf(je.value) + } + + "should encode in json (with explicit serializer) and decode" { + Sql("INSERT INTO JsonExample (id, value) VALUES (1, ${Param.withSer(je.value, MyPerson.serializer())})").action().runOn(ctx) + Sql("SELECT id, value FROM JsonExample").queryOf().runOn(ctx) shouldBe listOf(je) + } + } + + "annotated field" { + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class Example(val id: Int, @SqlJsonValue val value: MyPerson) + + val je = Example(1, MyPerson("Joe", 123)) + Sql("""INSERT INTO JsonbExample (id, value) VALUES (1, '{"name":"Joe", "age":123}')""").action().runOn(ctx) + val customers = Sql("SELECT id, value FROM JsonbExample").queryOf().runOn(ctx) + customers shouldBe listOf(je) + } + + "outer typealias - A".config(enabled = false) { + val je = JsonSpecData.A.Example(1, JsonSpecData.A.MyPerson("Joe", 123)) + Sql("""INSERT INTO JsonbExample (id, value) VALUES (1, '{"name":"Joe", "age":123}')""").action().runOn(ctx) + val customers = Sql("SELECT id, value FROM JsonbExample").queryOf().runOn(ctx) + customers shouldBe listOf(je) + } + } + + "JsonValue object works on" - { + "inner data class" - { + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class JsonbExample(val id: Int, val jsonValue: JsonValue) + + "field value" { + val je = JsonbExample(1, JsonValue(MyPerson("Alice", 30))) + Sql("INSERT INTO JsonbExample (id, value) VALUES (1, ${Param.withSer(je.jsonValue)})").action().runOn(ctx) + Sql("SELECT id, value FROM JsonbExample").queryOf().runOn(ctx) shouldBe listOf(je) + } + + "leaf value" { + val je = JsonbExample(1, JsonValue(MyPerson("Alice", 30))) + Sql("INSERT INTO JsonbExample (id, value) VALUES (1, ${Param.withSer(je.jsonValue)})").action().runOn(ctx) + Sql("SELECT value FROM JsonbExample").queryOf>().runOn(ctx) shouldBe listOf(je.jsonValue) + } + } + "complex data classes" - { + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class JsonbExample(val id: Int, val jsonValue: JsonValue>) + + val people = listOf(MyPerson("Joe", 30), MyPerson("Jack", 31)) + val je = JsonbExample(1, JsonValue(people)) + + "field value" { + Sql("INSERT INTO JsonbExample (id, value) VALUES (1, ${Param.withSer(je.jsonValue)})").action().runOn(ctx) + Sql("SELECT id, value FROM JsonbExample").queryOf().runOn(ctx) shouldBe listOf(je) + } + + "leaf value" { + Sql("INSERT INTO JsonbExample (id, value) VALUES (1, ${Param.withSer(je.jsonValue)})").action().runOn(ctx) + Sql("SELECT value FROM JsonbExample").queryOf>>().runOn(ctx) shouldBe listOf(je.jsonValue) + } + } + } + "multiple complex data classes" - { + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class MyJob(val job: String, val salary: Long) + + @Serializable + data class JsonbExample2(val id: Int, val jsonValue1: JsonValue>, val jsonValue2: JsonValue>) + + val people = listOf(MyPerson("Joe", 30), MyPerson("Jack", 31)) + val jobs = listOf(MyJob("job1", 100), MyJob("job2", 200)) + val je = JsonbExample2(1, JsonValue(people), JsonValue(jobs)) + + "field value" { + Sql("INSERT INTO JsonbExample2 (id, value1, value2) VALUES (1, ${Param.withSer(je.jsonValue1)}, ${Param.withSer(je.jsonValue2)})").action().runOn(ctx) + Sql("SELECT id, value1, value2 FROM JsonbExample2").queryOf().runOn(ctx) shouldBe listOf(je) + } + } + "complex data classes before primitive column" - { + @Serializable + data class MyPerson(val name: String, val age: Int) + + @Serializable + data class JsonbExample3(val id: Int, val jsonValue1: JsonValue>, val sample: Int) + + val people = listOf(MyPerson("Joe", 30), MyPerson("Jack", 31)) + val je = JsonbExample3(1, JsonValue(people), 100) + + "field value" { + Sql("INSERT INTO JsonbExample3 (id, value, sample) VALUES (1, ${Param.withSer(je.jsonValue1)}, 100)").action().runOn(ctx) + Sql("SELECT id, value, sample FROM JsonbExample3").queryOf().runOn(ctx) shouldBe listOf(je) + } + } +}) diff --git a/terpal-sql-r2dbc/src/test/resources/db/postgres-schema.sql b/terpal-sql-r2dbc/src/test/resources/db/postgres-schema.sql new file mode 100644 index 0000000..dac126d --- /dev/null +++ b/terpal-sql-r2dbc/src/test/resources/db/postgres-schema.sql @@ -0,0 +1,103 @@ +CREATE TABLE person ( + id SERIAL PRIMARY KEY, + firstName VARCHAR(255), + lastName VARCHAR(255), + age INT +); + +CREATE TABLE address ( + ownerId INT, + street VARCHAR, + zip INT +); + +CREATE TABLE Product( + description VARCHAR(255), + id SERIAL PRIMARY KEY, + sku BIGINT +); + +CREATE TABLE KmpTestEntity( + timeLocalDate DATE, -- java.time.LocalDate + timeLocalTime TIME, -- java.time.LocalTime + timeLocalDateTime TIMESTAMP, -- java.time.LocalDateTime + timeInstant TIMESTAMP WITH TIME ZONE, -- java.time.Instant + timeLocalDateOpt DATE, + timeLocalTimeOpt TIME, + timeLocalDateTimeOpt TIMESTAMP, + timeInstantOpt TIMESTAMP WITH TIME ZONE +); + +CREATE TABLE TimeEntity( + sqlDate DATE, -- java.sql.Date + sqlTime TIME, -- java.sql.Time + sqlTimestamp TIMESTAMP, -- java.sql.Timestamp + timeLocalDate DATE, -- java.time.LocalDate + timeLocalTime TIME, -- java.time.LocalTime + timeLocalDateTime TIMESTAMP, -- java.time.LocalDateTime + timeZonedDateTime TIMESTAMP WITH TIME ZONE, -- java.time.ZonedDateTime + timeInstant TIMESTAMP WITH TIME ZONE, -- java.time.Instant + -- Postgres actually has a notion of a Time+Timezone type unlike most DBs + timeOffsetTime TIME WITH TIME ZONE, -- java.time.OffsetTime + timeOffsetDateTime TIMESTAMP WITH TIME ZONE -- java.time.OffsetDateTime +); + +CREATE TABLE EncodingTestEntity( + stringMan VARCHAR(255), + booleanMan BOOLEAN, + byteMan SMALLINT, + shortMan SMALLINT, + intMan INTEGER, + longMan BIGINT, + floatMan FLOAT, + doubleMan DOUBLE PRECISION, + byteArrayMan BYTEA, + customMan VARCHAR(255), + stringOpt VARCHAR(255), + booleanOpt BOOLEAN, + byteOpt SMALLINT, + shortOpt SMALLINT, + intOpt INTEGER, + longOpt BIGINT, + floatOpt FLOAT, + doubleOpt DOUBLE PRECISION, + byteArrayOpt BYTEA, + customOpt VARCHAR(255) +); + +CREATE TABLE JsonbExample( + id SERIAL PRIMARY KEY, + value JSONB +); + +CREATE TABLE JsonbExample2( + id SERIAL PRIMARY KEY, + value1 JSONB, + value2 JSONB +); + +CREATE TABLE JsonbExample3( + id SERIAL PRIMARY KEY, + value JSONB, + sample SERIAL +); + +CREATE TABLE JsonExample( + id SERIAL PRIMARY KEY, + value JSON +); + + +CREATE TABLE MiscTest ( + id INTEGER NOT NULL PRIMARY KEY, + value TEXT NOT NULL +); + +CREATE TABLE JavaTestEntity( + bigDecimalMan DECIMAL(5,2), + javaUtilDateMan TIMESTAMP, + uuidMan UUID, + bigDecimalOpt DECIMAL(5,2), + javaUtilDateOpt TIMESTAMP, + uuidOpt UUID +);