Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 45 additions & 41 deletions common/src/main/scala/org/neo4j/spark/service/SchemaService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ class SchemaService(
query: String,
params: java.util.Map[String, AnyRef]
): mutable.Buffer[StructField] = {
val fields = session.run(query, params, sessionTransactionConfig)
.list
val fields = session.executeRead(tx => tx.run(query, params).list(), sessionTransactionConfig)
.asScala
.filter(record => !record.get("propertyName").isNull && !record.get("propertyName").isEmpty)
.map(record => {
Expand Down Expand Up @@ -149,7 +148,7 @@ class SchemaService(
params: java.util.Map[String, AnyRef],
extractFunction: Record => Map[String, AnyRef]
): mutable.Buffer[StructField] = {
session.run(query, params, sessionTransactionConfig).list.asScala
session.executeRead(tx => tx.run(query, params).list.asScala, sessionTransactionConfig)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a very small thing, but just wanted to note that asScala is called outside the function in other similar places.

.flatMap(extractFunction)
.groupBy(_._1)
.mapValues(_.map(_._2))
Expand Down Expand Up @@ -340,7 +339,7 @@ class SchemaService(
|RETURN *
|""".stripMargin
val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> options.query.value).asJava
val fields = session.run(query, map, sessionTransactionConfig).list.asScala
val fields = session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig).asScala
.map(r => r.get("field").asList((t: Value) => t.asString()).asScala)
.map(r =>
(
Expand Down Expand Up @@ -399,16 +398,15 @@ class SchemaService(
|RETURN *
|""".stripMargin
val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
session.run(query, map, sessionTransactionConfig)
.list
session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig)
.asScala
.map(r => (r.get("fieldName").asString(), r.get("optional").asBoolean()))
.toSeq
}

private def getReturnedColumns(query: String): Array[String] =
session.run("EXPLAIN " + query, sessionTransactionConfig)
.keys().asScala.toArray
session.executeRead(tx => tx.run("EXPLAIN " + query).keys(), sessionTransactionConfig)
.asScala.toArray

def struct(): StructType = {
val struct = options.query.queryType match {
Expand Down Expand Up @@ -469,8 +467,7 @@ class SchemaService(
queryReadStrategy.createStatementForRelationshipCount(options)
}
log.info(s"Executing the following counting query on Neo4j: $query")
session.run(query, sessionTransactionConfig)
.list()
session.executeRead(tx => tx.run(query).list(), sessionTransactionConfig)
.asScala
.map(_.get("count"))
.map(count => if (count.isNull) 0L else count.asLong())
Expand All @@ -487,7 +484,7 @@ class SchemaService(
*/
if (filters.isEmpty) {
val query = "CALL apoc.meta.stats() yield labels RETURN labels"
val map = session.run(query, sessionTransactionConfig).single()
val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig)
.asMap()
.asScala
.get("labels")
Expand All @@ -510,7 +507,7 @@ class SchemaService(
try {
if (filters.isEmpty) {
val query = "CALL apoc.meta.stats() yield relTypes RETURN relTypes"
val map = session.run(query, sessionTransactionConfig).single()
val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig)
.asMap()
.asScala
.get("relTypes")
Expand Down Expand Up @@ -579,29 +576,32 @@ class SchemaService(
|RETURN count(*) AS count
|""".stripMargin
}
session.run(query, sessionTransactionConfig).single().get("count").asLong()
session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig).get("count").asLong()
}
}

def isGdsProcedure(procName: String): Boolean = {
val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
session.run(
"""
|CALL gds.list() YIELD name, type
|WHERE name = $procName AND type = 'procedure'
|RETURN count(*) = 1
|""".stripMargin,
params,
session.executeRead(
tx =>
tx.run(
"""
|CALL gds.list() YIELD name, type
|WHERE name = $procName AND type = 'procedure'
|RETURN count(*) = 1
|""".stripMargin,
params
).single(),
sessionTransactionConfig
)
.single()
.get(0)
.asBoolean()
}

def validateQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): String =
try {
val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType()
val queryType =
session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType()
if (expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)) {
""
} else {
Expand All @@ -624,7 +624,7 @@ class SchemaService(

def validateQueryCount(query: String): String =
try {
val resultSummary = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume()
val resultSummary = session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig)
val queryType = resultSummary.queryType()
val plan = resultSummary.plan()
val expectedQueryTypes =
Expand All @@ -642,7 +642,8 @@ class SchemaService(

def isValidQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): Boolean =
try {
val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType()
val queryType =
session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType()
expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)
} catch {
case e: Throwable => {
Expand Down Expand Up @@ -694,8 +695,7 @@ class SchemaService(
"labels" -> Seq(label).asJava,
"properties" -> props.asJava
).asJava.asInstanceOf[util.Map[String, AnyRef]]
val isPresent = session.run(queryCheck, params, sessionTransactionConfig)
.single()
val isPresent = session.executeRead(tx => tx.run(queryCheck, params).single(), sessionTransactionConfig)
.get("isPresent")
.asBoolean()

Expand All @@ -704,7 +704,7 @@ class SchemaService(
} else {
val query = s"$queryPrefix $querySuffix"
log.info(s"Performing the following schema query: $query")
session.run(query, sessionTransactionConfig)
session.executeWrite(tx => tx.run(query).consume(), sessionTransactionConfig)
"CREATED"
}
log.info(s"Status for $action named with label $quotedLabel and props $quotedProps is: $status")
Expand Down Expand Up @@ -924,12 +924,12 @@ class SchemaService(
def execute(queries: Seq[String]): util.List[util.Map[String, AnyRef]] = {
val queryMap = queries
.map(query => {
(session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType(), query)
(session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType(), query)
})
.groupBy(_._1)
.mapValues(_.map(_._2))
val schemaQueries = queryMap.getOrElse(org.neo4j.driver.summary.QueryType.SCHEMA_WRITE, Seq.empty[String])
schemaQueries.foreach(session.run(_, sessionTransactionConfig))
schemaQueries.foreach(q => session.executeWrite(tx => tx.run(q).consume(), sessionTransactionConfig))
val others = queryMap
.filterKeys(key => key != org.neo4j.driver.summary.QueryType.SCHEMA_WRITE)
.values
Expand Down Expand Up @@ -965,12 +965,14 @@ class SchemaService(

private def lastOffsetForNode(): Long = {
val label = options.nodeMetadata.labels.head
session.run(
s"""MATCH (n:$label)
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin,
session.executeRead(
tx =>
tx.run(
s"""MATCH (n:$label)
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin
).single(),
sessionTransactionConfig
)
.single()
.get(options.streamingOptions.propertyName)
.asLong(-1)
}
Expand All @@ -980,20 +982,22 @@ class SchemaService(
val targetLabel = options.relationshipMetadata.target.labels.head.quote()
val relType = options.relationshipMetadata.relationshipType.quote()

session.run(
s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel)
|RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin,
session.executeRead(
tx =>
tx.run(
s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel)
|RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin
).single(),
sessionTransactionConfig
)
.single()
.get(options.streamingOptions.propertyName)
.asLong(-1)
}

private def lastOffsetForQuery(): Long = session.run(options.streamingOptions.queryOffset, sessionTransactionConfig)
.single()
.get(0)
.asLong(-1)
private def lastOffsetForQuery(): Long =
session.executeRead(tx => tx.run(options.streamingOptions.queryOffset).single(), sessionTransactionConfig)
.get(0)
.asLong(-1)

def lastOffset(): Long = options.query.queryType match {
case QueryType.LABELS => lastOffsetForNode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Neo4jBatchWriter(

if (neo4jOptions.indexAwait > 0) {
val session = driverCache.getOrCreate().session(neo4jOptions.session.toNeo4jSession())
session.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume()
session.executeRead(tx => tx.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume())
}

new Neo4jDataWriterFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ object ReauthenticationIT {
class ReauthenticationIT extends SparkConnectorScalaSuiteIT {

@Test
@Ignore("Ignored temporarily")
def createAnInstanceOfReAuthDriver(): Unit = {
val options = Map(
"url" -> NEO4J.getBoltUrl,
Expand Down