Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 46 additions & 42 deletions common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ class SchemaService(
query: String,
params: java.util.Map[String, AnyRef]
): mutable.Buffer[StructField] = {
val fields = session.run(query, params, sessionTransactionConfig)
.list
val fields = session.executeRead(tx => tx.run(query, params).list(), sessionTransactionConfig)
.asScala
.filter(record => !record.get("propertyName").isNull && !record.get("propertyName").isEmpty)
.map(record => {
Expand Down Expand Up @@ -149,7 +148,7 @@ class SchemaService(
params: java.util.Map[String, AnyRef],
extractFunction: Record => Map[String, AnyRef]
): mutable.Buffer[StructField] = {
session.run(query, params, sessionTransactionConfig).list.asScala
session.executeRead(tx => tx.run(query, params).list.asScala, sessionTransactionConfig)
.flatMap(extractFunction)
.groupBy(_._1)
.mapValues(_.map(_._2))
Expand Down Expand Up @@ -340,7 +339,7 @@ class SchemaService(
|RETURN *
|""".stripMargin
val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> options.query.value).asJava
val fields = session.run(query, map, sessionTransactionConfig).list.asScala
val fields = session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig).asScala
.map(r => r.get("field").asList((t: Value) => t.asString()).asScala)
.map(r =>
(
Expand Down Expand Up @@ -399,16 +398,15 @@ class SchemaService(
|RETURN *
|""".stripMargin
val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
session.run(query, map, sessionTransactionConfig)
.list
session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig)
.asScala
.map(r => (r.get("fieldName").asString(), r.get("optional").asBoolean()))
.toSeq
}

private def getReturnedColumns(query: String): Array[String] =
session.run("EXPLAIN " + query, sessionTransactionConfig)
.keys().asScala.toArray
session.executeRead(tx => tx.run("EXPLAIN " + query).keys(), sessionTransactionConfig)
.asScala.toArray

def struct(): StructType = {
val struct = options.query.queryType match {
Expand Down Expand Up @@ -469,8 +467,7 @@ class SchemaService(
queryReadStrategy.createStatementForRelationshipCount(options)
}
log.info(s"Executing the following counting query on Neo4j: $query")
session.run(query, sessionTransactionConfig)
.list()
session.executeRead(tx => tx.run(query).list(), sessionTransactionConfig)
.asScala
.map(_.get("count"))
.map(count => if (count.isNull) 0L else count.asLong())
Expand All @@ -487,7 +484,7 @@ class SchemaService(
*/
if (filters.isEmpty) {
val query = "CALL apoc.meta.stats() yield labels RETURN labels"
val map = session.run(query, sessionTransactionConfig).single()
val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig)
.asMap()
.asScala
.get("labels")
Expand All @@ -510,7 +507,7 @@ class SchemaService(
try {
if (filters.isEmpty) {
val query = "CALL apoc.meta.stats() yield relTypes RETURN relTypes"
val map = session.run(query, sessionTransactionConfig).single()
val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig)
.asMap()
.asScala
.get("relTypes")
Expand Down Expand Up @@ -579,29 +576,32 @@ class SchemaService(
|RETURN count(*) AS count
|""".stripMargin
}
session.run(query, sessionTransactionConfig).single().get("count").asLong()
session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig).get("count").asLong()
}
}

def isGdsProcedure(procName: String): Boolean = {
val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
session.run(
"""
|CALL gds.list() YIELD name, type
|WHERE name = $procName AND type = 'procedure'
|RETURN count(*) = 1
|""".stripMargin,
params,
session.executeRead(
tx =>
tx.run(
"""
|CALL gds.list() YIELD name, type
|WHERE name = $procName AND type = 'procedure'
|RETURN count(*) = 1
|""".stripMargin,
params
).single(),
sessionTransactionConfig
)
.single()
.get(0)
.asBoolean()
}

def validateQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): String =
try {
val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType()
val queryType =
session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType()
if (expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)) {
""
} else {
Expand All @@ -624,7 +624,7 @@ class SchemaService(

def validateQueryCount(query: String): String =
try {
val resultSummary = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume()
val resultSummary = session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig)
val queryType = resultSummary.queryType()
val plan = resultSummary.plan()
val expectedQueryTypes =
Expand All @@ -642,7 +642,8 @@ class SchemaService(

def isValidQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): Boolean =
try {
val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType()
val queryType =
session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType()
expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)
} catch {
case e: Throwable => {
Expand Down Expand Up @@ -694,8 +695,7 @@ class SchemaService(
"labels" -> Seq(label).asJava,
"properties" -> props.asJava
).asJava.asInstanceOf[util.Map[String, AnyRef]]
val isPresent = session.run(queryCheck, params, sessionTransactionConfig)
.single()
val isPresent = session.executeRead(tx => tx.run(queryCheck, params).single(), sessionTransactionConfig)
.get("isPresent")
.asBoolean()

Expand All @@ -704,7 +704,7 @@ class SchemaService(
} else {
val query = s"$queryPrefix $querySuffix"
log.info(s"Performing the following schema query: $query")
session.run(query, sessionTransactionConfig)
session.executeWrite(tx => tx.run(query).consume(), sessionTransactionConfig)
"CREATED"
}
log.info(s"Status for $action named with label $quotedLabel and props $quotedProps is: $status")
Expand All @@ -730,7 +730,7 @@ class SchemaService(
s"spark_${entityType}_${constraintType.replace(s"$entityType ", "")}-CONSTRAINT_${entityIdentifier}_$dashSeparatedProps".quote()
val props = keys.values.map(_.quote()).map("e." + _).mkString(", ")
val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier)
session.executeWrite(
session.executeWriteWithoutResult(
tx => {
tx.run(
s"CREATE CONSTRAINT $constraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE ($props) IS $constraintType"
Expand Down Expand Up @@ -924,12 +924,12 @@ class SchemaService(
def execute(queries: Seq[String]): util.List[util.Map[String, AnyRef]] = {
val queryMap = queries
.map(query => {
(session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType(), query)
(session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType(), query)
})
.groupBy(_._1)
.mapValues(_.map(_._2))
val schemaQueries = queryMap.getOrElse(org.neo4j.driver.summary.QueryType.SCHEMA_WRITE, Seq.empty[String])
schemaQueries.foreach(session.run(_, sessionTransactionConfig))
schemaQueries.foreach(q => session.executeWrite(tx => tx.run(q).consume(), sessionTransactionConfig))
val others = queryMap
.filterKeys(key => key != org.neo4j.driver.summary.QueryType.SCHEMA_WRITE)
.values
Expand Down Expand Up @@ -965,12 +965,14 @@ class SchemaService(

private def lastOffsetForNode(): Long = {
val label = options.nodeMetadata.labels.head
session.run(
s"""MATCH (n:$label)
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin,
session.executeRead(
tx =>
tx.run(
s"""MATCH (n:$label)
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin
).single(),
sessionTransactionConfig
)
.single()
.get(options.streamingOptions.propertyName)
.asLong(-1)
}
Expand All @@ -980,20 +982,22 @@ class SchemaService(
val targetLabel = options.relationshipMetadata.target.labels.head.quote()
val relType = options.relationshipMetadata.relationshipType.quote()

session.run(
s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel)
|RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin,
session.executeRead(
tx =>
tx.run(
s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel)
|RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin
).single(),
sessionTransactionConfig
)
.single()
.get(options.streamingOptions.propertyName)
.asLong(-1)
}

private def lastOffsetForQuery(): Long = session.run(options.streamingOptions.queryOffset, sessionTransactionConfig)
.single()
.get(0)
.asLong(-1)
private def lastOffsetForQuery(): Long =
session.executeRead(tx => tx.run(options.streamingOptions.queryOffset).single(), sessionTransactionConfig)
.get(0)
.asLong(-1)

def lastOffset(): Long = options.query.queryType match {
case QueryType.LABELS => lastOffsetForNode()
Expand Down
87 changes: 46 additions & 41 deletions common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.jetbrains.annotations.TestOnly
import org.neo4j.connectors.authn.AuthenticationToken
import org.neo4j.connectors.authn.AuthenticationTokenSupplierFactory
import org.neo4j.connectors.authn.BearerAuthenticationToken
import org.neo4j.connectors.authn.CustomAuthenticationToken
import org.neo4j.connectors.authn.DisabledAuthenticationToken
import org.neo4j.connectors.authn.ExpiringAuthenticationToken
import org.neo4j.connectors.authn.KerberosAuthenticationToken
import org.neo4j.connectors.authn.UserNameAndPasswordAuthenticationToken
import org.neo4j.driver.Config.TrustStrategy
Expand All @@ -40,8 +37,9 @@ import java.time.Duration
import java.util
import java.util.ServiceLoader
import java.util.UUID
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import java.util.concurrent.TimeUnit
import java.util.function.Supplier

import scala.collection.JavaConverters._
import scala.language.implicitConversions
Expand Down Expand Up @@ -498,47 +496,51 @@ case class Neo4jDriverOptions(
(URI.create(urls.head.trim), resolved)
}

// TODO this is partially intentionally not working as expect, changed to make it compile
// TODO missing here is also the keycloack support that comes from the authn commons
private def createAuthTokenManager: AuthTokenManager = {
if (auth == null || auth.isEmpty) {
throw new IllegalArgumentException(s"Authentication type name is required")
}
val token = createAuthTokenSupplier.get()
token match {
case bearerAuthenticationToken: BearerAuthenticationToken =>
AuthTokenManagers.bearer(() =>
AuthTokens.bearer(
bearerAuthenticationToken.getToken
).expiringAt(bearerAuthenticationToken.getExpiresAt.toEpochMilli)
)
case customAuthenticationToken: CustomAuthenticationToken =>
AuthTokenManagers.basic(() =>
AuthTokens.custom(
customAuthenticationToken.getPrincipal,
customAuthenticationToken.getCredentials,
customAuthenticationToken.getRealm,
customAuthenticationToken.getScheme,
customAuthenticationToken.getParameters
)
)
case disabledAuthenticationToken: DisabledAuthenticationToken =>
AuthTokenManagers.basic(() => AuthTokens.none())
case kerberosAuthenticationToken: KerberosAuthenticationToken =>
AuthTokenManagers.basic(() => AuthTokens.kerberos(kerberosAuthenticationToken.getToken))
case userNameAndPasswordAuthenticationToken: UserNameAndPasswordAuthenticationToken =>
AuthTokenManagers.basic(() =>
AuthTokens.basic(
userNameAndPasswordAuthenticationToken.getUsername,
userNameAndPasswordAuthenticationToken.getPassword,
userNameAndPasswordAuthenticationToken.getRealm
)
)
case _ => throw new IllegalStateException("bam")
val token = createAuthTokenSupplier
val name = token.getName
val username = authParameters.get("username")
val password = authParameters.get("password")
val supplier = token.create(username.orNull, password.orNull, authParameters.asJava)

name match {
case "basic" =>
val token = supplier.get().asInstanceOf[UserNameAndPasswordAuthenticationToken]
new StaticAuthTokenManager(AuthTokens.basic(token.getUsername, token.getPassword))
case "bearer" | "keycloak" =>
AuthTokenManagers.bearer(() => {
val token = supplier.get().asInstanceOf[BearerAuthenticationToken]
val authToken = AuthTokens.bearer(token.getToken)
val exp = token.getExpiresAt
if (exp == null) {
authToken.expiringAt(Long.MaxValue)
} else {
authToken.expiringAt(exp.toEpochMilli)
}
})
case "custom" =>
val token = supplier.get().asInstanceOf[CustomAuthenticationToken]
new StaticAuthTokenManager(AuthTokens.custom(
token.getPrincipal,
token.getCredentials,
token.getRealm,
token.getScheme,
token.getParameters
))
case "kerberos" =>
AuthTokenManagers.basic(() => {
val token = supplier.get().asInstanceOf[KerberosAuthenticationToken]
AuthTokens.kerberos(token.getToken)
})
case "none" =>
new StaticAuthTokenManager(AuthTokens.none())
}
}

private def createAuthTokenSupplier: Supplier[AuthenticationToken] = {
private def createAuthTokenSupplier: AuthenticationTokenSupplierFactory = {
if (auth == null || auth.isEmpty) {
throw new IllegalArgumentException(s"Authentication type name is required")
}
Expand All @@ -562,9 +564,7 @@ case class Neo4jDriverOptions(
)
}

val username = authParameters.get("username")
val password = authParameters.get("password")
filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava)
filteredSupplierFactories.head
}

}
Expand Down Expand Up @@ -722,6 +722,11 @@ object Neo4jOptions {
}
}

class StaticAuthTokenManager(authToken: AuthToken) extends AuthTokenManager {
override def getToken: CompletionStage[AuthToken] = CompletableFuture.completedStage(authToken)
override def handleSecurityException(authToken: AuthToken, exception: exceptions.SecurityException): Boolean = false
}

class CaseInsensitiveEnumeration extends Enumeration {

def withCaseInsensitiveName(s: String): Value = {
Expand Down
Loading