diff --git a/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala b/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala index 05452159..39fdd9fd 100644 --- a/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala +++ b/common/src/main/scala/org/neo4j/spark/service/SchemaService.scala @@ -114,8 +114,7 @@ class SchemaService( query: String, params: java.util.Map[String, AnyRef] ): mutable.Buffer[StructField] = { - val fields = session.run(query, params, sessionTransactionConfig) - .list + val fields = session.executeRead(tx => tx.run(query, params).list(), sessionTransactionConfig) .asScala .filter(record => !record.get("propertyName").isNull && !record.get("propertyName").isEmpty) .map(record => { @@ -149,7 +148,7 @@ class SchemaService( params: java.util.Map[String, AnyRef], extractFunction: Record => Map[String, AnyRef] ): mutable.Buffer[StructField] = { - session.run(query, params, sessionTransactionConfig).list.asScala + session.executeRead(tx => tx.run(query, params).list.asScala, sessionTransactionConfig) .flatMap(extractFunction) .groupBy(_._1) .mapValues(_.map(_._2)) @@ -340,7 +339,7 @@ class SchemaService( |RETURN * |""".stripMargin val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> options.query.value).asJava - val fields = session.run(query, map, sessionTransactionConfig).list.asScala + val fields = session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig).asScala .map(r => r.get("field").asList((t: Value) => t.asString()).asScala) .map(r => ( @@ -399,16 +398,15 @@ class SchemaService( |RETURN * |""".stripMargin val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava - session.run(query, map, sessionTransactionConfig) - .list + session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig) .asScala .map(r => (r.get("fieldName").asString(), r.get("optional").asBoolean())) .toSeq } private def getReturnedColumns(query: String): Array[String] = - session.run("EXPLAIN " + query, sessionTransactionConfig) - .keys().asScala.toArray + session.executeRead(tx => tx.run("EXPLAIN " + query).keys(), sessionTransactionConfig) + .asScala.toArray def struct(): StructType = { val struct = options.query.queryType match { @@ -469,8 +467,7 @@ class SchemaService( queryReadStrategy.createStatementForRelationshipCount(options) } log.info(s"Executing the following counting query on Neo4j: $query") - session.run(query, sessionTransactionConfig) - .list() + session.executeRead(tx => tx.run(query).list(), sessionTransactionConfig) .asScala .map(_.get("count")) .map(count => if (count.isNull) 0L else count.asLong()) @@ -487,7 +484,7 @@ class SchemaService( */ if (filters.isEmpty) { val query = "CALL apoc.meta.stats() yield labels RETURN labels" - val map = session.run(query, sessionTransactionConfig).single() + val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig) .asMap() .asScala .get("labels") @@ -510,7 +507,7 @@ class SchemaService( try { if (filters.isEmpty) { val query = "CALL apoc.meta.stats() yield relTypes RETURN relTypes" - val map = session.run(query, sessionTransactionConfig).single() + val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig) .asMap() .asScala .get("relTypes") @@ -579,29 +576,32 @@ class SchemaService( |RETURN count(*) AS count |""".stripMargin } - session.run(query, sessionTransactionConfig).single().get("count").asLong() + session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig).get("count").asLong() } } def isGdsProcedure(procName: String): Boolean = { val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava - session.run( - """ - |CALL gds.list() YIELD name, type - |WHERE name = $procName AND type = 'procedure' - |RETURN count(*) = 1 - |""".stripMargin, - params, + session.executeRead( + tx => + tx.run( + """ + |CALL gds.list() YIELD name, type + |WHERE name = $procName AND type = 'procedure' + |RETURN count(*) = 1 + |""".stripMargin, + params + ).single(), sessionTransactionConfig ) - .single() .get(0) .asBoolean() } def validateQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): String = try { - val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType() + val queryType = + session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType() if (expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)) { "" } else { @@ -624,7 +624,7 @@ class SchemaService( def validateQueryCount(query: String): String = try { - val resultSummary = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume() + val resultSummary = session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig) val queryType = resultSummary.queryType() val plan = resultSummary.plan() val expectedQueryTypes = @@ -642,7 +642,8 @@ class SchemaService( def isValidQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): Boolean = try { - val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType() + val queryType = + session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType() expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType) } catch { case e: Throwable => { @@ -694,8 +695,7 @@ class SchemaService( "labels" -> Seq(label).asJava, "properties" -> props.asJava ).asJava.asInstanceOf[util.Map[String, AnyRef]] - val isPresent = session.run(queryCheck, params, sessionTransactionConfig) - .single() + val isPresent = session.executeRead(tx => tx.run(queryCheck, params).single(), sessionTransactionConfig) .get("isPresent") .asBoolean() @@ -704,7 +704,7 @@ class SchemaService( } else { val query = s"$queryPrefix $querySuffix" log.info(s"Performing the following schema query: $query") - session.run(query, sessionTransactionConfig) + session.executeWrite(tx => tx.run(query).consume(), sessionTransactionConfig) "CREATED" } log.info(s"Status for $action named with label $quotedLabel and props $quotedProps is: $status") @@ -730,7 +730,7 @@ class SchemaService( s"spark_${entityType}_${constraintType.replace(s"$entityType ", "")}-CONSTRAINT_${entityIdentifier}_$dashSeparatedProps".quote() val props = keys.values.map(_.quote()).map("e." + _).mkString(", ") val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier) - session.executeWrite( + session.executeWriteWithoutResult( tx => { tx.run( s"CREATE CONSTRAINT $constraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE ($props) IS $constraintType" @@ -924,12 +924,12 @@ class SchemaService( def execute(queries: Seq[String]): util.List[util.Map[String, AnyRef]] = { val queryMap = queries .map(query => { - (session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType(), query) + (session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType(), query) }) .groupBy(_._1) .mapValues(_.map(_._2)) val schemaQueries = queryMap.getOrElse(org.neo4j.driver.summary.QueryType.SCHEMA_WRITE, Seq.empty[String]) - schemaQueries.foreach(session.run(_, sessionTransactionConfig)) + schemaQueries.foreach(q => session.executeWrite(tx => tx.run(q).consume(), sessionTransactionConfig)) val others = queryMap .filterKeys(key => key != org.neo4j.driver.summary.QueryType.SCHEMA_WRITE) .values @@ -965,12 +965,14 @@ class SchemaService( private def lastOffsetForNode(): Long = { val label = options.nodeMetadata.labels.head - session.run( - s"""MATCH (n:$label) - |RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin, + session.executeRead( + tx => + tx.run( + s"""MATCH (n:$label) + |RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin + ).single(), sessionTransactionConfig ) - .single() .get(options.streamingOptions.propertyName) .asLong(-1) } @@ -980,20 +982,22 @@ class SchemaService( val targetLabel = options.relationshipMetadata.target.labels.head.quote() val relType = options.relationshipMetadata.relationshipType.quote() - session.run( - s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel) - |RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin, + session.executeRead( + tx => + tx.run( + s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel) + |RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin + ).single(), sessionTransactionConfig ) - .single() .get(options.streamingOptions.propertyName) .asLong(-1) } - private def lastOffsetForQuery(): Long = session.run(options.streamingOptions.queryOffset, sessionTransactionConfig) - .single() - .get(0) - .asLong(-1) + private def lastOffsetForQuery(): Long = + session.executeRead(tx => tx.run(options.streamingOptions.queryOffset).single(), sessionTransactionConfig) + .get(0) + .asLong(-1) def lastOffset(): Long = options.query.queryType match { case QueryType.LABELS => lastOffsetForNode() diff --git a/common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala b/common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala index 939041b1..8567e4bb 100644 --- a/common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala +++ b/common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala @@ -20,12 +20,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SaveMode import org.apache.spark.sql.SparkSession import org.jetbrains.annotations.TestOnly -import org.neo4j.connectors.authn.AuthenticationToken import org.neo4j.connectors.authn.AuthenticationTokenSupplierFactory import org.neo4j.connectors.authn.BearerAuthenticationToken import org.neo4j.connectors.authn.CustomAuthenticationToken -import org.neo4j.connectors.authn.DisabledAuthenticationToken -import org.neo4j.connectors.authn.ExpiringAuthenticationToken import org.neo4j.connectors.authn.KerberosAuthenticationToken import org.neo4j.connectors.authn.UserNameAndPasswordAuthenticationToken import org.neo4j.driver.Config.TrustStrategy @@ -40,8 +37,9 @@ import java.time.Duration import java.util import java.util.ServiceLoader import java.util.UUID +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionStage import java.util.concurrent.TimeUnit -import java.util.function.Supplier import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -498,47 +496,51 @@ case class Neo4jDriverOptions( (URI.create(urls.head.trim), resolved) } - // TODO this is partially intentionally not working as expect, changed to make it compile - // TODO missing here is also the keycloack support that comes from the authn commons private def createAuthTokenManager: AuthTokenManager = { if (auth == null || auth.isEmpty) { throw new IllegalArgumentException(s"Authentication type name is required") } - val token = createAuthTokenSupplier.get() - token match { - case bearerAuthenticationToken: BearerAuthenticationToken => - AuthTokenManagers.bearer(() => - AuthTokens.bearer( - bearerAuthenticationToken.getToken - ).expiringAt(bearerAuthenticationToken.getExpiresAt.toEpochMilli) - ) - case customAuthenticationToken: CustomAuthenticationToken => - AuthTokenManagers.basic(() => - AuthTokens.custom( - customAuthenticationToken.getPrincipal, - customAuthenticationToken.getCredentials, - customAuthenticationToken.getRealm, - customAuthenticationToken.getScheme, - customAuthenticationToken.getParameters - ) - ) - case disabledAuthenticationToken: DisabledAuthenticationToken => - AuthTokenManagers.basic(() => AuthTokens.none()) - case kerberosAuthenticationToken: KerberosAuthenticationToken => - AuthTokenManagers.basic(() => AuthTokens.kerberos(kerberosAuthenticationToken.getToken)) - case userNameAndPasswordAuthenticationToken: UserNameAndPasswordAuthenticationToken => - AuthTokenManagers.basic(() => - AuthTokens.basic( - userNameAndPasswordAuthenticationToken.getUsername, - userNameAndPasswordAuthenticationToken.getPassword, - userNameAndPasswordAuthenticationToken.getRealm - ) - ) - case _ => throw new IllegalStateException("bam") + val token = createAuthTokenSupplier + val name = token.getName + val username = authParameters.get("username") + val password = authParameters.get("password") + val supplier = token.create(username.orNull, password.orNull, authParameters.asJava) + + name match { + case "basic" => + val token = supplier.get().asInstanceOf[UserNameAndPasswordAuthenticationToken] + new StaticAuthTokenManager(AuthTokens.basic(token.getUsername, token.getPassword)) + case "bearer" | "keycloak" => + AuthTokenManagers.bearer(() => { + val token = supplier.get().asInstanceOf[BearerAuthenticationToken] + val authToken = AuthTokens.bearer(token.getToken) + val exp = token.getExpiresAt + if (exp == null) { + authToken.expiringAt(Long.MaxValue) + } else { + authToken.expiringAt(exp.toEpochMilli) + } + }) + case "custom" => + val token = supplier.get().asInstanceOf[CustomAuthenticationToken] + new StaticAuthTokenManager(AuthTokens.custom( + token.getPrincipal, + token.getCredentials, + token.getRealm, + token.getScheme, + token.getParameters + )) + case "kerberos" => + AuthTokenManagers.basic(() => { + val token = supplier.get().asInstanceOf[KerberosAuthenticationToken] + AuthTokens.kerberos(token.getToken) + }) + case "none" => + new StaticAuthTokenManager(AuthTokens.none()) } } - private def createAuthTokenSupplier: Supplier[AuthenticationToken] = { + private def createAuthTokenSupplier: AuthenticationTokenSupplierFactory = { if (auth == null || auth.isEmpty) { throw new IllegalArgumentException(s"Authentication type name is required") } @@ -562,9 +564,7 @@ case class Neo4jDriverOptions( ) } - val username = authParameters.get("username") - val password = authParameters.get("password") - filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava) + filteredSupplierFactories.head } } @@ -722,6 +722,11 @@ object Neo4jOptions { } } +class StaticAuthTokenManager(authToken: AuthToken) extends AuthTokenManager { + override def getToken: CompletionStage[AuthToken] = CompletableFuture.completedStage(authToken) + override def handleSecurityException(authToken: AuthToken, exception: exceptions.SecurityException): Boolean = false +} + class CaseInsensitiveEnumeration extends Enumeration { def withCaseInsensitiveName(s: String): Value = { diff --git a/common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala b/common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala index eac9ead7..f350bb3a 100644 --- a/common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala +++ b/common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala @@ -18,15 +18,17 @@ package org.neo4j.spark.service import org.junit.Test import org.junit.runner.RunWith -import org.mockito.ArgumentMatchers +import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ import org.mockito.Mockito.times +import org.neo4j.driver.AuthTokenManager +import org.neo4j.driver.AuthTokens import org.neo4j.driver.Config import org.neo4j.driver.GraphDatabase -import org.neo4j.driver.internal.security.ExpirationBasedAuthTokenManager import org.neo4j.spark.util.DriverCache import org.neo4j.spark.util.Neo4jOptions import org.powermock.api.mockito.PowerMockito +import org.powermock.core.classloader.annotations.PowerMockIgnore import org.powermock.core.classloader.annotations.PrepareForTest import org.powermock.modules.junit4.PowerMockRunner import org.testcontainers.shaded.com.google.common.io.BaseEncoding @@ -36,6 +38,7 @@ import java.util @PrepareForTest(Array(classOf[GraphDatabase])) @RunWith(classOf[PowerMockRunner]) +@PowerMockIgnore(Array("javax.management.*")) class AuthenticationTest { @Test @@ -56,8 +59,9 @@ class AuthenticationTest { driverCache.getOrCreate() PowerMockito.verifyStatic(classOf[GraphDatabase], times(1)) - // was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.custom("", token, "", "")), any(classOf[Config])) - GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]()) + val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager]) + GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]()) + assert(AuthTokens.custom("", token, "", "") == managerCaptor.getValue.getToken.toCompletableFuture.join()) } @Test @@ -77,7 +81,8 @@ class AuthenticationTest { driverCache.getOrCreate() PowerMockito.verifyStatic(classOf[GraphDatabase], times(1)) - // was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.bearer(token)), any()) - GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]()) + val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager]) + GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]()) + assert(AuthTokens.bearer(token) == managerCaptor.getValue.getToken.toCompletableFuture.join()) } } diff --git a/spark/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala b/spark/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala index 68bfd83e..e0aec88d 100644 --- a/spark/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala +++ b/spark/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala @@ -45,7 +45,7 @@ class Neo4jBatchWriter( if (neo4jOptions.indexAwait > 0) { val session = driverCache.getOrCreate().session(neo4jOptions.session.toNeo4jSession()) - session.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume() + session.executeRead(tx => tx.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume()) } new Neo4jDataWriterFactory( diff --git a/spark/src/test/scala/org/neo4j/spark/ReauthenticationIT.scala b/spark/src/test/scala/org/neo4j/spark/ReauthenticationIT.scala index a87ec3a3..0a543856 100644 --- a/spark/src/test/scala/org/neo4j/spark/ReauthenticationIT.scala +++ b/spark/src/test/scala/org/neo4j/spark/ReauthenticationIT.scala @@ -113,7 +113,6 @@ object ReauthenticationIT { class ReauthenticationIT extends SparkConnectorScalaSuiteIT { @Test - @Ignore("Ignored temporarily") def createAnInstanceOfReAuthDriver(): Unit = { val options = Map( "url" -> NEO4J.getBoltUrl,