Skip to content

Commit e603316

Browse files
injectivesali-ince
authored andcommitted
feat(auth): Adjust auth handling for 6.x driver
The main objective of this update is to adjust authentication handling to make it work with 6.x Java Driver.
1 parent 55e0ee2 commit e603316

File tree

3 files changed

+58
-48
lines changed

3 files changed

+58
-48
lines changed

common/src/main/scala/org/neo4j/spark/service/SchemaService.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class SchemaService(
730730
s"spark_${entityType}_${constraintType.replace(s"$entityType ", "")}-CONSTRAINT_${entityIdentifier}_$dashSeparatedProps".quote()
731731
val props = keys.values.map(_.quote()).map("e." + _).mkString(", ")
732732
val asciiRepresentation: String = createCypherPattern(entityType, entityIdentifier)
733-
session.executeWrite(
733+
session.executeWriteWithoutResult(
734734
tx => {
735735
tx.run(
736736
s"CREATE CONSTRAINT $constraintName IF NOT EXISTS FOR $asciiRepresentation REQUIRE ($props) IS $constraintType"

common/src/main/scala/org/neo4j/spark/util/Neo4jOptions.scala

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ import org.apache.spark.internal.Logging
2020
import org.apache.spark.sql.SaveMode
2121
import org.apache.spark.sql.SparkSession
2222
import org.jetbrains.annotations.TestOnly
23-
import org.neo4j.connectors.authn.AuthenticationToken
2423
import org.neo4j.connectors.authn.AuthenticationTokenSupplierFactory
2524
import org.neo4j.connectors.authn.BearerAuthenticationToken
2625
import org.neo4j.connectors.authn.CustomAuthenticationToken
27-
import org.neo4j.connectors.authn.DisabledAuthenticationToken
28-
import org.neo4j.connectors.authn.ExpiringAuthenticationToken
2926
import org.neo4j.connectors.authn.KerberosAuthenticationToken
3027
import org.neo4j.connectors.authn.UserNameAndPasswordAuthenticationToken
3128
import org.neo4j.driver.Config.TrustStrategy
@@ -40,8 +37,9 @@ import java.time.Duration
4037
import java.util
4138
import java.util.ServiceLoader
4239
import java.util.UUID
40+
import java.util.concurrent.CompletableFuture
41+
import java.util.concurrent.CompletionStage
4342
import java.util.concurrent.TimeUnit
44-
import java.util.function.Supplier
4543

4644
import scala.collection.JavaConverters._
4745
import scala.language.implicitConversions
@@ -498,47 +496,51 @@ case class Neo4jDriverOptions(
498496
(URI.create(urls.head.trim), resolved)
499497
}
500498

501-
// TODO this is partially intentionally not working as expect, changed to make it compile
502-
// TODO missing here is also the keycloack support that comes from the authn commons
503499
private def createAuthTokenManager: AuthTokenManager = {
504500
if (auth == null || auth.isEmpty) {
505501
throw new IllegalArgumentException(s"Authentication type name is required")
506502
}
507-
val token = createAuthTokenSupplier.get()
508-
token match {
509-
case bearerAuthenticationToken: BearerAuthenticationToken =>
510-
AuthTokenManagers.bearer(() =>
511-
AuthTokens.bearer(
512-
bearerAuthenticationToken.getToken
513-
).expiringAt(bearerAuthenticationToken.getExpiresAt.toEpochMilli)
514-
)
515-
case customAuthenticationToken: CustomAuthenticationToken =>
516-
AuthTokenManagers.basic(() =>
517-
AuthTokens.custom(
518-
customAuthenticationToken.getPrincipal,
519-
customAuthenticationToken.getCredentials,
520-
customAuthenticationToken.getRealm,
521-
customAuthenticationToken.getScheme,
522-
customAuthenticationToken.getParameters
523-
)
524-
)
525-
case disabledAuthenticationToken: DisabledAuthenticationToken =>
526-
AuthTokenManagers.basic(() => AuthTokens.none())
527-
case kerberosAuthenticationToken: KerberosAuthenticationToken =>
528-
AuthTokenManagers.basic(() => AuthTokens.kerberos(kerberosAuthenticationToken.getToken))
529-
case userNameAndPasswordAuthenticationToken: UserNameAndPasswordAuthenticationToken =>
530-
AuthTokenManagers.basic(() =>
531-
AuthTokens.basic(
532-
userNameAndPasswordAuthenticationToken.getUsername,
533-
userNameAndPasswordAuthenticationToken.getPassword,
534-
userNameAndPasswordAuthenticationToken.getRealm
535-
)
536-
)
537-
case _ => throw new IllegalStateException("bam")
503+
val token = createAuthTokenSupplier
504+
val name = token.getName
505+
val username = authParameters.get("username")
506+
val password = authParameters.get("password")
507+
val supplier = token.create(username.orNull, password.orNull, authParameters.asJava)
508+
509+
name match {
510+
case "basic" =>
511+
val token = supplier.get().asInstanceOf[UserNameAndPasswordAuthenticationToken]
512+
new StaticAuthTokenManager(AuthTokens.basic(token.getUsername, token.getPassword))
513+
case "bearer" | "keycloak" =>
514+
AuthTokenManagers.bearer(() => {
515+
val token = supplier.get().asInstanceOf[BearerAuthenticationToken]
516+
val authToken = AuthTokens.bearer(token.getToken)
517+
val exp = token.getExpiresAt
518+
if (exp == null) {
519+
authToken.expiringAt(Long.MaxValue)
520+
} else {
521+
authToken.expiringAt(exp.toEpochMilli)
522+
}
523+
})
524+
case "custom" =>
525+
val token = supplier.get().asInstanceOf[CustomAuthenticationToken]
526+
new StaticAuthTokenManager(AuthTokens.custom(
527+
token.getPrincipal,
528+
token.getCredentials,
529+
token.getRealm,
530+
token.getScheme,
531+
token.getParameters
532+
))
533+
case "kerberos" =>
534+
AuthTokenManagers.basic(() => {
535+
val token = supplier.get().asInstanceOf[KerberosAuthenticationToken]
536+
AuthTokens.kerberos(token.getToken)
537+
})
538+
case "none" =>
539+
new StaticAuthTokenManager(AuthTokens.none())
538540
}
539541
}
540542

541-
private def createAuthTokenSupplier: Supplier[AuthenticationToken] = {
543+
private def createAuthTokenSupplier: AuthenticationTokenSupplierFactory = {
542544
if (auth == null || auth.isEmpty) {
543545
throw new IllegalArgumentException(s"Authentication type name is required")
544546
}
@@ -562,9 +564,7 @@ case class Neo4jDriverOptions(
562564
)
563565
}
564566

565-
val username = authParameters.get("username")
566-
val password = authParameters.get("password")
567-
filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava)
567+
filteredSupplierFactories.head
568568
}
569569

570570
}
@@ -722,6 +722,11 @@ object Neo4jOptions {
722722
}
723723
}
724724

725+
class StaticAuthTokenManager(authToken: AuthToken) extends AuthTokenManager {
726+
override def getToken: CompletionStage[AuthToken] = CompletableFuture.completedStage(authToken)
727+
override def handleSecurityException(authToken: AuthToken, exception: exceptions.SecurityException): Boolean = false
728+
}
729+
725730
class CaseInsensitiveEnumeration extends Enumeration {
726731

727732
def withCaseInsensitiveName(s: String): Value = {

common/src/test/scala/org/neo4j/spark/service/AuthenticationTest.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@ package org.neo4j.spark.service
1818

1919
import org.junit.Test
2020
import org.junit.runner.RunWith
21-
import org.mockito.ArgumentMatchers
21+
import org.mockito.ArgumentCaptor
2222
import org.mockito.ArgumentMatchers._
2323
import org.mockito.Mockito.times
24+
import org.neo4j.driver.AuthTokenManager
25+
import org.neo4j.driver.AuthTokens
2426
import org.neo4j.driver.Config
2527
import org.neo4j.driver.GraphDatabase
26-
import org.neo4j.driver.internal.security.ExpirationBasedAuthTokenManager
2728
import org.neo4j.spark.util.DriverCache
2829
import org.neo4j.spark.util.Neo4jOptions
2930
import org.powermock.api.mockito.PowerMockito
31+
import org.powermock.core.classloader.annotations.PowerMockIgnore
3032
import org.powermock.core.classloader.annotations.PrepareForTest
3133
import org.powermock.modules.junit4.PowerMockRunner
3234
import org.testcontainers.shaded.com.google.common.io.BaseEncoding
@@ -36,6 +38,7 @@ import java.util
3638

3739
@PrepareForTest(Array(classOf[GraphDatabase]))
3840
@RunWith(classOf[PowerMockRunner])
41+
@PowerMockIgnore(Array("javax.management.*"))
3942
class AuthenticationTest {
4043

4144
@Test
@@ -56,8 +59,9 @@ class AuthenticationTest {
5659
driverCache.getOrCreate()
5760

5861
PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
59-
// was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.custom("", token, "", "")), any(classOf[Config]))
60-
GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]())
62+
val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager])
63+
GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]())
64+
assert(AuthTokens.custom("", token, "", "") == managerCaptor.getValue.getToken.toCompletableFuture.join())
6165
}
6266

6367
@Test
@@ -77,7 +81,8 @@ class AuthenticationTest {
7781
driverCache.getOrCreate()
7882

7983
PowerMockito.verifyStatic(classOf[GraphDatabase], times(1))
80-
// was GraphDatabase.driver(any[URI](), ArgumentMatchers.eq(AuthTokens.bearer(token)), any())
81-
GraphDatabase.driver(any[URI](), ArgumentMatchers.any[ExpirationBasedAuthTokenManager], any[Config]())
84+
val managerCaptor = ArgumentCaptor.forClass(classOf[AuthTokenManager])
85+
GraphDatabase.driver(any[URI](), managerCaptor.capture(), any[Config]())
86+
assert(AuthTokens.bearer(token) == managerCaptor.getValue.getToken.toCompletableFuture.join())
8287
}
8388
}

0 commit comments

Comments
 (0)