Skip to content

Commit 9e44238

Browse files
committed
feat(auth): Adjust authentication handling for driver update
The main objective of this update is to adjust authentication handling to make it work with 6.x Java Driver.
1 parent d05a28e commit 9e44238

File tree

4 files changed

+72
-56
lines changed

4 files changed

+72
-56
lines changed

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class SchemaService(
730730
s"spark_${entityType}_${constraintType.replace(s"$entityType ", "")}-CONSTRAINT_${entityIdentifier}_$dashSeparatedProps".quote()
731731
val props = keys.values.map(_.quote()).map("e." + _).mkString(", ")
732732
val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier)
733-
session.executeWrite(
733+
session.executeWriteWithoutResult(
734734
tx => {
735735
tx.run(
736736
s"CREATE CONSTRAINT $constraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE ($props) IS $constraintType"

common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ import org.apache.spark.internal.Logging
2020
import org.apache.spark.sql.SaveMode
2121
import org.apache.spark.sql.SparkSession
2222
import org.jetbrains.annotations.TestOnly
23-
import org.neo4j.connectors.authn.AuthenticationToken
2423
import org.neo4j.connectors.authn.AuthenticationTokenSupplierFactory
2524
import org.neo4j.connectors.authn.BearerAuthenticationToken
2625
import org.neo4j.connectors.authn.CustomAuthenticationToken
27-
import org.neo4j.connectors.authn.DisabledAuthenticationToken
28-
import org.neo4j.connectors.authn.ExpiringAuthenticationToken
2926
import org.neo4j.connectors.authn.KerberosAuthenticationToken
3027
import org.neo4j.connectors.authn.UserNameAndPasswordAuthenticationToken
3128
import org.neo4j.driver.Config.TrustStrategy
@@ -40,8 +37,9 @@ import java.time.Duration
4037
import java.util
4138
import java.util.ServiceLoader
4239
import java.util.UUID
40+
import java.util.concurrent.CompletableFuture
41+
import java.util.concurrent.CompletionStage
4342
import java.util.concurrent.TimeUnit
44-
import java.util.function.Supplier
4543

4644
import scala.collection.JavaConverters._
4745
import scala.language.implicitConversions
@@ -498,47 +496,51 @@ case class Neo4jDriverOptions(
498496
(URI.create(urls.head.trim), resolved)
499497
}
500498

501-
// TODO this is partially intentionally not working as expect, changed to make it compile
502-
// TODO missing here is also the keycloack support that comes from the authn commons
503499
private def createAuthTokenManager: AuthTokenManager = {
504500
if (auth == null || auth.isEmpty) {
505501
throw new IllegalArgumentException(s"Authentication type name is required")
506502
}
507-
val token = createAuthTokenSupplier.get()
508-
token match {
509-
case bearerAuthenticationToken: BearerAuthenticationToken =>
510-
AuthTokenManagers.bearer(() =>
511-
AuthTokens.bearer(
512-
bearerAuthenticationToken.getToken
513-
).expiringAt(bearerAuthenticationToken.getExpiresAt.toEpochMilli)
514-
)
515-
case customAuthenticationToken: CustomAuthenticationToken =>
516-
AuthTokenManagers.basic(() =>
517-
AuthTokens.custom(
518-
customAuthenticationToken.getPrincipal,
519-
customAuthenticationToken.getCredentials,
520-
customAuthenticationToken.getRealm,
521-
customAuthenticationToken.getScheme,
522-
customAuthenticationToken.getParameters
523-
)
524-
)
525-
case disabledAuthenticationToken: DisabledAuthenticationToken =>
526-
AuthTokenManagers.basic(() => AuthTokens.none())
527-
case kerberosAuthenticationToken: KerberosAuthenticationToken =>
528-
AuthTokenManagers.basic(() => AuthTokens.kerberos(kerberosAuthenticationToken.getToken))
529-
case userNameAndPasswordAuthenticationToken: UserNameAndPasswordAuthenticationToken =>
530-
AuthTokenManagers.basic(() =>
531-
AuthTokens.basic(
532-
userNameAndPasswordAuthenticationToken.getUsername,
533-
userNameAndPasswordAuthenticationToken.getPassword,
534-
userNameAndPasswordAuthenticationToken.getRealm
535-
)
536-
)
537-
case _ => throw new IllegalStateException("bam")
503+
val token = createAuthTokenSupplier
504+
val name = token.getName
505+
val username = authParameters.get("username")
506+
val password = authParameters.get("password")
507+
val supplier = token.create(username.orNull, password.orNull, authParameters.asJava)
508+
509+
name match {
510+
case "basic" =>
511+
val token = supplier.get().asInstanceOf[UserNameAndPasswordAuthenticationToken]
512+
new StaticAuthTokenManager(AuthTokens.basic(token.getUsername, token.getPassword))
513+
case "bearer" | "keycloak" =>
514+
AuthTokenManagers.bearer(() => {
515+
val token = supplier.get().asInstanceOf[BearerAuthenticationToken]
516+
val authToken = AuthTokens.bearer(token.getToken)
517+
val exp = token.getExpiresAt
518+
if (exp == null) {
519+
authToken.expiringAt(Long.MaxValue)
520+
} else {
521+
authToken.expiringAt(exp.toEpochMilli)
522+
}
523+
})
524+
case "custom" =>
525+
val token = supplier.get().asInstanceOf[CustomAuthenticationToken]
526+
new StaticAuthTokenManager(AuthTokens.custom(
527+
token.getPrincipal,
528+
token.getCredentials,
529+
token.getRealm,
530+
token.getScheme,
531+
token.getParameters
532+
))
533+
case "kerberos" =>
534+
AuthTokenManagers.basic(() => {
535+
val token = supplier.get().asInstanceOf[KerberosAuthenticationToken]
536+
AuthTokens.kerberos(token.getToken)
537+
})
538+
case "none" =>
539+
new StaticAuthTokenManager(AuthTokens.none())
538540
}
539541
}
540542

541-
private def createAuthTokenSupplier: Supplier[AuthenticationToken] = {
543+
private def createAuthTokenSupplier: AuthenticationTokenSupplierFactory = {
542544
if (auth == null || auth.isEmpty) {
543545
throw new IllegalArgumentException(s"Authentication type name is required")
544546
}
@@ -562,9 +564,7 @@ case class Neo4jDriverOptions(
562564
)
563565
}
564566

565-
val username = authParameters.get("username")
566-
val password = authParameters.get("password")
567-
filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava)
567+
filteredSupplierFactories.head
568568
}
569569

570570
}
@@ -722,6 +722,11 @@ object Neo4jOptions {
722722
}
723723
}
724724

725+
class StaticAuthTokenManager(authToken: AuthToken) extends AuthTokenManager {
726+
override def getToken: CompletionStage[AuthToken] = CompletableFuture.completedStage(authToken)
727+
override def handleSecurityException(authToken: AuthToken, exception: exceptions.SecurityException): Boolean = false
728+
}
729+
725730
class CaseInsensitiveEnumeration extends Enumeration {
726731

727732
def withCaseInsensitiveName(s: String): Value = {

common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@ package org.neo4j.spark.service
1818

1919
import org.junit.Test
2020
import org.junit.runner.RunWith
21-
import org.mockito.ArgumentMatchers
21+
import org.mockito.ArgumentCaptor
2222
import org.mockito.ArgumentMatchers._
2323
import org.mockito.Mockito.times
24+
import org.neo4j.driver.AuthTokenManager
25+
import org.neo4j.driver.AuthTokens
2426
import org.neo4j.driver.Config
2527
import org.neo4j.driver.GraphDatabase
26-
import org.neo4j.driver.internal.security.ExpirationBasedAuthTokenManager
2728
import org.neo4j.spark.util.DriverCache
2829
import org.neo4j.spark.util.Neo4jOptions
2930
import org.powermock.api.mockito.PowerMockito
31+
import org.powermock.core.classloader.annotations.PowerMockIgnore
3032
import org.powermock.core.classloader.annotations.PrepareForTest
3133
import org.powermock.modules.junit4.PowerMockRunner
3234
import org.testcontainers.shaded.com.google.common.io.BaseEncoding
@@ -36,6 +38,7 @@ import java.util
3638

3739
@PrepareForTest(Array(classOf[GraphDatabase]))
3840
@RunWith(classOf[PowerMockRunner])
41+
@PowerMockIgnore(Array("javax.management.*"))
3942
class AuthenticationTest {
4043

4144
@Test
@@ -56,8 +59,9 @@ class AuthenticationTest {
5659
driverCache.getOrCreate()
5760

5861
PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
59-
// was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.custom("", token, "", "")), any(classOf[Config]))
60-
GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]())
62+
val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager])
63+
GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]())
64+
assert(AuthTokens.custom("", token, "", "") == managerCaptor.getValue.getToken.toCompletableFuture.join())
6165
}
6266

6367
@Test
@@ -77,7 +81,8 @@ class AuthenticationTest {
7781
driverCache.getOrCreate()
7882

7983
PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
80-
// was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.bearer(token)), any())
81-
GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]())
84+
val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager])
85+
GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]())
86+
assert(AuthTokens.bearer(token) == managerCaptor.getValue.getToken.toCompletableFuture.join())
8287
}
8388
}

spark-3/src/test/scala/org/neo4j/spark/DataSourceWriterTSE.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,11 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
576576
@Test
577577
def `should throw an error because the node already exists`(): Unit = {
578578
SparkConnectorScalaSuiteIT.session()
579-
.executeWrite(tx => tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE"))
579+
.executeWriteWithoutResult(tx =>
580+
tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE")
581+
)
580582
SparkConnectorScalaSuiteIT.session()
581-
.executeWrite(tx => tx.run("CREATE (p:Person{name: 'Andrea', surname: 'Santurbano'})"))
583+
.executeWriteWithoutResult(tx => tx.run("CREATE (p:Person{name: 'Andrea', surname: 'Santurbano'})"))
582584

583585
val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
584586

@@ -602,16 +604,18 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
602604
}
603605
} finally {
604606
SparkConnectorScalaSuiteIT.session()
605-
.executeWrite(tx => tx.run("DROP CONSTRAINT person_surname"))
607+
.executeWriteWithoutResult(tx => tx.run("DROP CONSTRAINT person_surname"))
606608
}
607609
}
608610

609611
@Test
610612
def `should update the node that already exists`(): Unit = {
611613
SparkConnectorScalaSuiteIT.session()
612-
.executeWrite(tx => tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE"))
614+
.executeWriteWithoutResult(tx =>
615+
tx.run("CREATE CONSTRAINT person_surname FOR (p:Person) REQUIRE p.surname IS UNIQUE")
616+
)
613617
SparkConnectorScalaSuiteIT.session()
614-
.executeWrite(tx => tx.run("CREATE (p:Person{name: 'Federico', surname: 'Santurbano'})"))
618+
.executeWriteWithoutResult(tx => tx.run("CREATE (p:Person{name: 'Federico', surname: 'Santurbano'})"))
615619

616620
val ds = Seq(SimplePerson("Andrea", "Santurbano")).toDS()
617621

@@ -635,7 +639,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
635639
assertEquals("Andrea", nodeList.head.get("n").asNode().get("name").asString())
636640

637641
SparkConnectorScalaSuiteIT.session()
638-
.executeWrite(tx => tx.run("DROP CONSTRAINT person_surname"))
642+
.executeWriteWithoutResult(tx => tx.run("DROP CONSTRAINT person_surname"))
639643
}
640644

641645
@Test
@@ -766,7 +770,9 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
766770
@Test
767771
def `should handle unusual column names`(): Unit = {
768772
SparkConnectorScalaSuiteIT.session()
769-
.executeWrite(tx => tx.run("CREATE CONSTRAINT instrument_name FOR (i:Instrument) REQUIRE i.name IS UNIQUE"))
773+
.executeWriteWithoutResult(tx =>
774+
tx.run("CREATE CONSTRAINT instrument_name FOR (i:Instrument) REQUIRE i.name IS UNIQUE")
775+
)
770776

771777
val musicDf = Seq(
772778
(12, "John Bonham", "Drums", "f``````oo"),
@@ -791,7 +797,7 @@ class DataSourceWriterTSE extends SparkConnectorScalaBaseTSE {
791797
.save()
792798

793799
SparkConnectorScalaSuiteIT.session()
794-
.executeWrite(tx => tx.run("DROP CONSTRAINT instrument_name"))
800+
.executeWriteWithoutResult(tx => tx.run("DROP CONSTRAINT instrument_name"))
795801

796802
val musicDfCheck = ss.read.format(classOf[DataSource].getName)
797803
.option("url", SparkConnectorScalaSuiteIT.server.getBoltUrl)

0 commit comments

Comments
 (0)