Skip to content

Commit 1513871

Browse files
committed
Make authentication handling work with 6.x Java Driver
The main objective of this update is to adjust authentication handling to make it work with 6.x Java Driver.
1 parent d05a28e commit 1513871

File tree

4 files changed

+72
-56
lines changed

4 files changed

+72
-56
lines changed

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class SchemaService(
730730
s"spark_${entityType}_${constraintType.replace(s"$entityType ", "")}-CONSTRAINT_${entityIdentifier}_$dashSeparatedProps".quote()
731731
val props = keys.values.map(_.quote()).map("e." + _).mkString(", ")
732732
val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier)
733-
session.executeWrite(
733+
session.executeWriteWithoutResult(
734734
tx => {
735735
tx.run(
736736
s"CREATE CONSTRAINT $constraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE ($props) IS $constraintType"

common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ import org.apache.spark.internal.Logging
2020
import org.apache.spark.sql.SaveMode
2121
import org.apache.spark.sql.SparkSession
2222
import org.jetbrains.annotations.TestOnly
23-
import org.neo4j.connectors.authn.AuthenticationToken
2423
import org.neo4j.connectors.authn.AuthenticationTokenSupplierFactory
2524
import org.neo4j.connectors.authn.BearerAuthenticationToken
2625
import org.neo4j.connectors.authn.CustomAuthenticationToken
27-
import org.neo4j.connectors.authn.DisabledAuthenticationToken
28-
import org.neo4j.connectors.authn.ExpiringAuthenticationToken
2926
import org.neo4j.connectors.authn.KerberosAuthenticationToken
3027
import org.neo4j.connectors.authn.UserNameAndPasswordAuthenticationToken
3128
import org.neo4j.driver.Config.TrustStrategy
@@ -40,8 +37,9 @@ import java.time.Duration
4037
import java.util
4138
import java.util.ServiceLoader
4239
import java.util.UUID
40+
import java.util.concurrent.CompletableFuture
41+
import java.util.concurrent.CompletionStage
4342
import java.util.concurrent.TimeUnit
44-
import java.util.function.Supplier
4543

4644
import scala.collection.JavaConverters._
4745
import scala.language.implicitConversions
@@ -498,47 +496,51 @@ case class Neo4jDriverOptions(
498496
(URI.create(urls.head.trim), resolved)
499497
}
500498

501-
// TODO this is partially intentionally not working as expect, changed to make it compile
502-
// TODO missing here is also the keycloack support that comes from the authn commons
503499
private def createAuthTokenManager: AuthTokenManager = {
504500
if (auth == null || auth.isEmpty) {
505501
throw new IllegalArgumentException(s"Authentication type name is required")
506502
}
507-
val token = createAuthTokenSupplier.get()
508-
token match {
509-
case bearerAuthenticationToken: BearerAuthenticationToken =>
510-
AuthTokenManagers.bearer(() =>
511-
AuthTokens.bearer(
512-
bearerAuthenticationToken.getToken
513-
).expiringAt(bearerAuthenticationToken.getExpiresAt.toEpochMilli)
514-
)
515-
case customAuthenticationToken: CustomAuthenticationToken =>
516-
AuthTokenManagers.basic(() =>
517-
AuthTokens.custom(
518-
customAuthenticationToken.getPrincipal,
519-
customAuthenticationToken.getCredentials,
520-
customAuthenticationToken.getRealm,
521-
customAuthenticationToken.getScheme,
522-
customAuthenticationToken.getParameters
523-
)
524-
)
525-
case disabledAuthenticationToken: DisabledAuthenticationToken =>
526-
AuthTokenManagers.basic(() => AuthTokens.none())
527-
case kerberosAuthenticationToken: KerberosAuthenticationToken =>
528-
AuthTokenManagers.basic(() => AuthTokens.kerberos(kerberosAuthenticationToken.getToken))
529-
case userNameAndPasswordAuthenticationToken: UserNameAndPasswordAuthenticationToken =>
530-
AuthTokenManagers.basic(() =>
531-
AuthTokens.basic(
532-
userNameAndPasswordAuthenticationToken.getUsername,
533-
userNameAndPasswordAuthenticationToken.getPassword,
534-
userNameAndPasswordAuthenticationToken.getRealm
535-
)
536-
)
537-
case _ => throw new IllegalStateException("bam")
503+
val token = createAuthTokenSupplier
504+
val name = token.getName
505+
val username = authParameters.get("username")
506+
val password = authParameters.get("password")
507+
val supplier = token.create(username.orNull, password.orNull, authParameters.asJava)
508+
509+
name match {
510+
case "basic" =>
511+
val token = supplier.get().asInstanceOf[UserNameAndPasswordAuthenticationToken]
512+
new StaticAuthTokenManager(AuthTokens.basic(token.getUsername, token.getPassword))
513+
case "bearer" | "keycloak" =>
514+
AuthTokenManagers.bearer(() => {
515+
val token = supplier.get().asInstanceOf[BearerAuthenticationToken]
516+
val authToken = AuthTokens.bearer(token.getToken)
517+
val exp = token.getExpiresAt
518+
if (exp == null) {
519+
authToken.expiringAt(Long.MaxValue)
520+
} else {
521+
authToken.expiringAt(exp.toEpochMilli)
522+
}
523+
})
524+
case "custom" =>
525+
val token = supplier.get().asInstanceOf[CustomAuthenticationToken]
526+
new StaticAuthTokenManager(AuthTokens.custom(
527+
token.getPrincipal,
528+
token.getCredentials,
529+
token.getRealm,
530+
token.getScheme,
531+
token.getParameters
532+
))
533+
case "kerberos" =>
534+
AuthTokenManagers.basic(() => {
535+
val token = supplier.get().asInstanceOf[KerberosAuthenticationToken]
536+
AuthTokens.kerberos(token.getToken)
537+
})
538+
case "none" =>
539+
new StaticAuthTokenManager(AuthTokens.none())
538540
}
539541
}
540542

541-
private def createAuthTokenSupplier: Supplier[AuthenticationToken] = {
543+
private def createAuthTokenSupplier: AuthenticationTokenSupplierFactory = {
542544
if (auth == null || auth.isEmpty) {
543545
throw new IllegalArgumentException(s"Authentication type name is required")
544546
}
@@ -562,9 +564,7 @@ case class Neo4jDriverOptions(
562564
)
563565
}
564566

565-
val username = authParameters.get("username")
566-
val password = authParameters.get("password")
567-
filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava)
567+
filteredSupplierFactories.head
568568
}
569569

570570
}
@@ -722,6 +722,11 @@ object Neo4jOptions {
722722
}
723723
}
724724

725+
class StaticAuthTokenManager(authToken: AuthToken) extends AuthTokenManager {
726+
override def getToken: CompletionStage[AuthToken] = CompletableFuture.completedStage(authToken)
727+
override def handleSecurityException(authToken: AuthToken, exception: exceptions.SecurityException): Boolean = false
728+
}
729+
725730
class CaseInsensitiveEnumeration extends Enumeration {
726731

727732
def withCaseInsensitiveName(s: String): Value = {

common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@ package org.neo4j.spark.service
1818

1919
import org.junit.Test
2020
import org.junit.runner.RunWith
21-
import org.mockito.ArgumentMatchers
21+
import org.mockito.ArgumentCaptor
2222
import org.mockito.ArgumentMatchers._
2323
import org.mockito.Mockito.times
24+
import org.neo4j.driver.AuthTokenManager
25+
import org.neo4j.driver.AuthTokens
2426
import org.neo4j.driver.Config
2527
import org.neo4j.driver.GraphDatabase
26-
import org.neo4j.driver.internal.security.ExpirationBasedAuthTokenManager
2728
import org.neo4j.spark.util.DriverCache
2829
import org.neo4j.spark.util.Neo4jOptions
2930
import org.powermock.api.mockito.PowerMockito
31+
import org.powermock.core.classloader.annotations.PowerMockIgnore
3032
import org.powermock.core.classloader.annotations.PrepareForTest
3133
import org.powermock.modules.junit4.PowerMockRunner
3234
import org.testcontainers.shaded.com.google.common.io.BaseEncoding
@@ -36,6 +38,7 @@ import java.util
3638

3739
@PrepareForTest(Array(classOf[GraphDatabase]))
3840
@RunWith(classOf[PowerMockRunner])
41+
@PowerMockIgnore(Array("javax.management.*"))
3942
class AuthenticationTest {
4043

4144
@Test
@@ -56,8 +59,9 @@ class AuthenticationTest {
5659
driverCache.getOrCreate()
5760

5861
PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
59-
// was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.custom("", token, "", "")), any(classOf[Config]))
60-
GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]())
62+
val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager])
63+
GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]())
64+
assert(AuthTokens.custom("", token, "", "") == managerCaptor.getValue.getToken.toCompletableFuture.join())
6165
}
6266

6367
@Test
@@ -77,7 +81,8 @@ class AuthenticationTest {
7781
driverCache.getOrCreate()
7882

7983
PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
80-
// was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.bearer(token)), any())
81-
GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]())
84+
val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager])
85+
GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]())
86+
assert(AuthTokens.bearer(token) == managerCaptor.getValue.getToken.toCompletableFuture.join())
8287
}
8388
}

spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterTSE.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,11 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
576576
@Test
577577
def `should throw an error because the node already exists`(): Unit = {
578578
SparkConnectorScalaSuiteIT.session()
579-
.executeWrite(tx => tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE"))
579+
.executeWriteWithoutResult(tx =>
580+
tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE")
581+
)
580582
SparkConnectorScalaSuiteIT.session()
581-
.executeWrite(tx => tx.run("CREATE (p:Person{name: 'Andrea', surname: 'Santurbano'})"))
583+
.executeWriteWithoutResult(tx => tx.run("CREATE (p:Person{name: 'Andrea', surname: 'Santurbano'})"))
582584

583585
val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
584586

@@ -602,16 +604,18 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
602604
}
603605
} finally {
604606
SparkConnectorScalaSuiteIT.session()
605-
.executeWrite(tx => tx.run("DROP CONSTRAINT person_surname"))
607+
.executeWriteWithoutResult(tx => tx.run("DROP CONSTRAINT person_surname"))
606608
}
607609
}
608610

609611
@Test
610612
def `should update the node that already exists`(): Unit = {
611613
SparkConnectorScalaSuiteIT.session()
612-
.executeWrite(tx => tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE"))
614+
.executeWriteWithoutResult(tx =>
615+
tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE")
616+
)
613617
SparkConnectorScalaSuiteIT.session()
614-
.executeWrite(tx => tx.run("CREATE (p:Person{name: 'Federico', surname: 'Santurbano'})"))
618+
.executeWriteWithoutResult(tx => tx.run("CREATE (p:Person{name: 'Federico', surname: 'Santurbano'})"))
615619

616620
val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
617621

@@ -635,7 +639,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
635639
assertEquals("Andrea", nodeList.head.get("n").asNode().get("name").asString())
636640

637641
SparkConnectorScalaSuiteIT.session()
638-
.executeWrite(tx => tx.run("DROP CONSTRAINT person_surname"))
642+
.executeWriteWithoutResult(tx => tx.run("DROP CONSTRAINT person_surname"))
639643
}
640644

641645
@Test
@@ -766,7 +770,9 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
766770
@Test
767771
def `should handle unusual column names`(): Unit = {
768772
SparkConnectorScalaSuiteIT.session()
769-
.executeWrite(tx => tx.run("CREATE CONSTRAINT instrument_name FOR (i:Instrument) REQUIRE i.name IS UNIQUE"))
773+
.executeWriteWithoutResult(tx =>
774+
tx.run("CREATE CONSTRAINT instrument_name FOR (i:Instrument) REQUIRE i.name IS UNIQUE")
775+
)
770776

771777
val musicDf = Seq(
772778
(12, "John Bonham", "Drums", "f``````oo"),
@@ -791,7 +797,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
791797
.save()
792798

793799
SparkConnectorScalaSuiteIT.session()
794-
.executeWrite(tx => tx.run("DROP CONSTRAINT instrument_name"))
800+
.executeWriteWithoutResult(tx => tx.run("DROP CONSTRAINT instrument_name"))
795801

796802
val musicDfCheck = ss.read.format(classOf[DataSource].getName)
797803
.option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)

0 commit comments

Comments
 (0)