Skip to content

Commit a4eff37

Browse files
committed
fix: use transaction functions
1 parent e603316 commit a4eff37

File tree

3 files changed

+46
-43
lines changed

3 files changed

+46
-43
lines changed

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,7 @@ class SchemaService(
114114
query: String,
115115
params: java.util.Map[String, AnyRef]
116116
): mutable.Buffer[StructField] = {
117-
val fields = session.run(query, params, sessionTransactionConfig)
118-
.list
117+
val fields = session.executeRead(tx => tx.run(query, params).list(), sessionTransactionConfig)
119118
.asScala
120119
.filter(record => !record.get("propertyName").isNull && !record.get("propertyName").isEmpty)
121120
.map(record => {
@@ -149,7 +148,7 @@ class SchemaService(
149148
params: java.util.Map[String, AnyRef],
150149
extractFunction: Record => Map[String, AnyRef]
151150
): mutable.Buffer[StructField] = {
152-
session.run(query, params, sessionTransactionConfig).list.asScala
151+
session.executeRead(tx => tx.run(query, params).list.asScala, sessionTransactionConfig)
153152
.flatMap(extractFunction)
154153
.groupBy(_._1)
155154
.mapValues(_.map(_._2))
@@ -340,7 +339,7 @@ class SchemaService(
340339
|RETURN *
341340
|""".stripMargin
342341
val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> options.query.value).asJava
343-
val fields = session.run(query, map, sessionTransactionConfig).list.asScala
342+
val fields = session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig).asScala
344343
.map(r => r.get("field").asList((t: Value) => t.asString()).asScala)
345344
.map(r =>
346345
(
@@ -399,16 +398,15 @@ class SchemaService(
399398
|RETURN *
400399
|""".stripMargin
401400
val map: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
402-
session.run(query, map, sessionTransactionConfig)
403-
.list
401+
session.executeRead(tx => tx.run(query, map).list(), sessionTransactionConfig)
404402
.asScala
405403
.map(r => (r.get("fieldName").asString(), r.get("optional").asBoolean()))
406404
.toSeq
407405
}
408406

409407
private def getReturnedColumns(query: String): Array[String] =
410-
session.run("EXPLAIN " + query, sessionTransactionConfig)
411-
.keys().asScala.toArray
408+
session.executeRead(tx => tx.run("EXPLAIN " + query).keys(), sessionTransactionConfig)
409+
.asScala.toArray
412410

413411
def struct(): StructType = {
414412
val struct = options.query.queryType match {
@@ -469,8 +467,7 @@ class SchemaService(
469467
queryReadStrategy.createStatementForRelationshipCount(options)
470468
}
471469
log.info(s"Executing the following counting query on Neo4j: $query")
472-
session.run(query, sessionTransactionConfig)
473-
.list()
470+
session.executeRead(tx => tx.run(query).list(), sessionTransactionConfig)
474471
.asScala
475472
.map(_.get("count"))
476473
.map(count => if (count.isNull) 0L else count.asLong())
@@ -487,7 +484,7 @@ class SchemaService(
487484
*/
488485
if (filters.isEmpty) {
489486
val query = "CALL apoc.meta.stats() yield labels RETURN labels"
490-
val map = session.run(query, sessionTransactionConfig).single()
487+
val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig)
491488
.asMap()
492489
.asScala
493490
.get("labels")
@@ -510,7 +507,7 @@ class SchemaService(
510507
try {
511508
if (filters.isEmpty) {
512509
val query = "CALL apoc.meta.stats() yield relTypes RETURN relTypes"
513-
val map = session.run(query, sessionTransactionConfig).single()
510+
val map = session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig)
514511
.asMap()
515512
.asScala
516513
.get("relTypes")
@@ -579,29 +576,32 @@ class SchemaService(
579576
|RETURN count(*) AS count
580577
|""".stripMargin
581578
}
582-
session.run(query, sessionTransactionConfig).single().get("count").asLong()
579+
session.executeRead(tx => tx.run(query).single(), sessionTransactionConfig).get("count").asLong()
583580
}
584581
}
585582

586583
def isGdsProcedure(procName: String): Boolean = {
587584
val params: util.Map[String, AnyRef] = Map[String, AnyRef]("procName" -> procName).asJava
588-
session.run(
589-
"""
590-
|CALL gds.list() YIELD name, type
591-
|WHERE name = $procName AND type = 'procedure'
592-
|RETURN count(*) = 1
593-
|""".stripMargin,
594-
params,
585+
session.executeRead(
586+
tx =>
587+
tx.run(
588+
"""
589+
|CALL gds.list() YIELD name, type
590+
|WHERE name = $procName AND type = 'procedure'
591+
|RETURN count(*) = 1
592+
|""".stripMargin,
593+
params
594+
).single(),
595595
sessionTransactionConfig
596596
)
597-
.single()
598597
.get(0)
599598
.asBoolean()
600599
}
601600

602601
def validateQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): String =
603602
try {
604-
val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType()
603+
val queryType =
604+
session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType()
605605
if (expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)) {
606606
""
607607
} else {
@@ -624,7 +624,7 @@ class SchemaService(
624624

625625
def validateQueryCount(query: String): String =
626626
try {
627-
val resultSummary = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume()
627+
val resultSummary = session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig)
628628
val queryType = resultSummary.queryType()
629629
val plan = resultSummary.plan()
630630
val expectedQueryTypes =
@@ -642,7 +642,8 @@ class SchemaService(
642642

643643
def isValidQuery(query: String, expectedQueryTypes: org.neo4j.driver.summary.QueryType*): Boolean =
644644
try {
645-
val queryType = session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType()
645+
val queryType =
646+
session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType()
646647
expectedQueryTypes.isEmpty || expectedQueryTypes.contains(queryType)
647648
} catch {
648649
case e: Throwable => {
@@ -694,8 +695,7 @@ class SchemaService(
694695
"labels" -> Seq(label).asJava,
695696
"properties" -> props.asJava
696697
).asJava.asInstanceOf[util.Map[String, AnyRef]]
697-
val isPresent = session.run(queryCheck, params, sessionTransactionConfig)
698-
.single()
698+
val isPresent = session.executeRead(tx => tx.run(queryCheck, params).single(), sessionTransactionConfig)
699699
.get("isPresent")
700700
.asBoolean()
701701

@@ -704,7 +704,7 @@ class SchemaService(
704704
} else {
705705
val query = s"$queryPrefix $querySuffix"
706706
log.info(s"Performing the following schema query: $query")
707-
session.run(query, sessionTransactionConfig)
707+
session.executeWrite(tx => tx.run(query).consume(), sessionTransactionConfig)
708708
"CREATED"
709709
}
710710
log.info(s"Status for $action named with label $quotedLabel and props $quotedProps is: $status")
@@ -924,12 +924,12 @@ class SchemaService(
924924
def execute(queries: Seq[String]): util.List[util.Map[String, AnyRef]] = {
925925
val queryMap = queries
926926
.map(query => {
927-
(session.run(s"EXPLAIN $query", sessionTransactionConfig).consume().queryType(), query)
927+
(session.executeRead(tx => tx.run(s"EXPLAIN $query").consume(), sessionTransactionConfig).queryType(), query)
928928
})
929929
.groupBy(_._1)
930930
.mapValues(_.map(_._2))
931931
val schemaQueries = queryMap.getOrElse(org.neo4j.driver.summary.QueryType.SCHEMA_WRITE, Seq.empty[String])
932-
schemaQueries.foreach(session.run(_, sessionTransactionConfig))
932+
schemaQueries.foreach(q => session.executeWrite(tx => tx.run(q).consume(), sessionTransactionConfig))
933933
val others = queryMap
934934
.filterKeys(key => key != org.neo4j.driver.summary.QueryType.SCHEMA_WRITE)
935935
.values
@@ -965,12 +965,14 @@ class SchemaService(
965965

966966
private def lastOffsetForNode(): Long = {
967967
val label = options.nodeMetadata.labels.head
968-
session.run(
969-
s"""MATCH (n:$label)
970-
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin,
968+
session.executeRead(
969+
tx =>
970+
tx.run(
971+
s"""MATCH (n:$label)
972+
|RETURN max(n.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin
973+
).single(),
971974
sessionTransactionConfig
972975
)
973-
.single()
974976
.get(options.streamingOptions.propertyName)
975977
.asLong(-1)
976978
}
@@ -980,20 +982,22 @@ class SchemaService(
980982
val targetLabel = options.relationshipMetadata.target.labels.head.quote()
981983
val relType = options.relationshipMetadata.relationshipType.quote()
982984

983-
session.run(
984-
s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel)
985-
|RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin,
985+
session.executeRead(
986+
tx =>
987+
tx.run(
988+
s"""MATCH (s:$sourceLabel)-[r:$relType]->(t:$targetLabel)
989+
|RETURN max(r.${options.streamingOptions.propertyName}) AS ${options.streamingOptions.propertyName}""".stripMargin
990+
).single(),
986991
sessionTransactionConfig
987992
)
988-
.single()
989993
.get(options.streamingOptions.propertyName)
990994
.asLong(-1)
991995
}
992996

993-
private def lastOffsetForQuery(): Long = session.run(options.streamingOptions.queryOffset, sessionTransactionConfig)
994-
.single()
995-
.get(0)
996-
.asLong(-1)
997+
private def lastOffsetForQuery(): Long =
998+
session.executeRead(tx => tx.run(options.streamingOptions.queryOffset).single(), sessionTransactionConfig)
999+
.get(0)
1000+
.asLong(-1)
9971001

9981002
def lastOffset(): Long = options.query.queryType match {
9991003
case QueryType.LABELS => lastOffsetForNode()

spark/src/main/scala/org/neo4j/spark/writer/Neo4jBatchWriter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class Neo4jBatchWriter(
4545

4646
if (neo4jOptions.indexAwait > 0) {
4747
val session = driverCache.getOrCreate().session(neo4jOptions.session.toNeo4jSession())
48-
session.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume()
48+
session.executeRead(tx => tx.run(s"CALL db.awaitIndexes(${neo4jOptions.indexAwait})").consume())
4949
}
5050

5151
new Neo4jDataWriterFactory(

spark/src/test/scala/org/neo4j/spark/ReauthenticationIT.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ object ReauthenticationIT {
113113
class ReauthenticationIT extends SparkConnectorScalaSuiteIT {
114114

115115
@Test
116-
@Ignore("Ignored temporarily")
117116
def createAnInstanceOfReAuthDriver(): Unit = {
118117
val options = Map(
119118
"url" -> NEO4J.getBoltUrl,

0 commit comments

Comments
 (0)