@@ -20,12 +20,9 @@ import org.apache.spark.internal.Logging
2020import org .apache .spark .sql .SaveMode
2121import org .apache .spark .sql .SparkSession
2222import org .jetbrains .annotations .TestOnly
23- import org .neo4j .connectors .authn .AuthenticationToken
2423import org .neo4j .connectors .authn .AuthenticationTokenSupplierFactory
2524import org .neo4j .connectors .authn .BearerAuthenticationToken
2625import org .neo4j .connectors .authn .CustomAuthenticationToken
27- import org .neo4j .connectors .authn .DisabledAuthenticationToken
28- import org .neo4j .connectors .authn .ExpiringAuthenticationToken
2926import org .neo4j .connectors .authn .KerberosAuthenticationToken
3027import org .neo4j .connectors .authn .UserNameAndPasswordAuthenticationToken
3128import org .neo4j .driver .Config .TrustStrategy
@@ -40,8 +37,9 @@ import java.time.Duration
4037import java .util
4138import java .util .ServiceLoader
4239import java .util .UUID
40+ import java .util .concurrent .CompletableFuture
41+ import java .util .concurrent .CompletionStage
4342import java .util .concurrent .TimeUnit
44- import java .util .function .Supplier
4543
4644import scala .collection .JavaConverters ._
4745import scala .language .implicitConversions
@@ -498,47 +496,51 @@ case class Neo4jDriverOptions(
498496 (URI .create(urls.head.trim), resolved)
499497 }
500498
501- // TODO this is partially intentionally not working as expect, changed to make it compile
502- // TODO missing here is also the keycloack support that comes from the authn commons
503499 private def createAuthTokenManager : AuthTokenManager = {
504500 if (auth == null || auth.isEmpty) {
505501 throw new IllegalArgumentException (s " Authentication type name is required " )
506502 }
507- val token = createAuthTokenSupplier.get()
508- token match {
509- case bearerAuthenticationToken : BearerAuthenticationToken =>
510- AuthTokenManagers .bearer(() =>
511- AuthTokens .bearer(
512- bearerAuthenticationToken.getToken
513- ).expiringAt(bearerAuthenticationToken.getExpiresAt.toEpochMilli)
514- )
515- case customAuthenticationToken : CustomAuthenticationToken =>
516- AuthTokenManagers .basic(() =>
517- AuthTokens .custom(
518- customAuthenticationToken.getPrincipal,
519- customAuthenticationToken.getCredentials,
520- customAuthenticationToken.getRealm,
521- customAuthenticationToken.getScheme,
522- customAuthenticationToken.getParameters
523- )
524- )
525- case disabledAuthenticationToken : DisabledAuthenticationToken =>
526- AuthTokenManagers .basic(() => AuthTokens .none())
527- case kerberosAuthenticationToken : KerberosAuthenticationToken =>
528- AuthTokenManagers .basic(() => AuthTokens .kerberos(kerberosAuthenticationToken.getToken))
529- case userNameAndPasswordAuthenticationToken : UserNameAndPasswordAuthenticationToken =>
530- AuthTokenManagers .basic(() =>
531- AuthTokens .basic(
532- userNameAndPasswordAuthenticationToken.getUsername,
533- userNameAndPasswordAuthenticationToken.getPassword,
534- userNameAndPasswordAuthenticationToken.getRealm
535- )
536- )
537- case _ => throw new IllegalStateException (" bam" )
503+ val token = createAuthTokenSupplier
504+ val name = token.getName
505+ val username = authParameters.get(" username" )
506+ val password = authParameters.get(" password" )
507+ val supplier = token.create(username.orNull, password.orNull, authParameters.asJava)
508+
509+ name match {
510+ case " basic" =>
511+ val token = supplier.get().asInstanceOf [UserNameAndPasswordAuthenticationToken ]
512+ new StaticAuthTokenManager (AuthTokens .basic(token.getUsername, token.getPassword))
513+ case " bearer" | " keycloak" =>
514+ AuthTokenManagers .bearer(() => {
515+ val token = supplier.get().asInstanceOf [BearerAuthenticationToken ]
516+ val authToken = AuthTokens .bearer(token.getToken)
517+ val exp = token.getExpiresAt
518+ if (exp == null ) {
519+ authToken.expiringAt(Long .MaxValue )
520+ } else {
521+ authToken.expiringAt(exp.toEpochMilli)
522+ }
523+ })
524+ case " custom" =>
525+ val token = supplier.get().asInstanceOf [CustomAuthenticationToken ]
526+ new StaticAuthTokenManager (AuthTokens .custom(
527+ token.getPrincipal,
528+ token.getCredentials,
529+ token.getRealm,
530+ token.getScheme,
531+ token.getParameters
532+ ))
533+ case " kerberos" =>
534+ AuthTokenManagers .basic(() => {
535+ val token = supplier.get().asInstanceOf [KerberosAuthenticationToken ]
536+ AuthTokens .kerberos(token.getToken)
537+ })
538+ case " none" =>
539+ new StaticAuthTokenManager (AuthTokens .none())
538540 }
539541 }
540542
541- private def createAuthTokenSupplier : Supplier [ AuthenticationToken ] = {
543+ private def createAuthTokenSupplier : AuthenticationTokenSupplierFactory = {
542544 if (auth == null || auth.isEmpty) {
543545 throw new IllegalArgumentException (s " Authentication type name is required " )
544546 }
@@ -562,9 +564,7 @@ case class Neo4jDriverOptions(
562564 )
563565 }
564566
565- val username = authParameters.get(" username" )
566- val password = authParameters.get(" password" )
567- filteredSupplierFactories.head.create(username.orNull, password.orNull, authParameters.asJava)
567+ filteredSupplierFactories.head
568568 }
569569
570570}
@@ -722,6 +722,11 @@ object Neo4jOptions {
722722 }
723723}
724724
725+ class StaticAuthTokenManager (authToken : AuthToken ) extends AuthTokenManager {
726+ override def getToken : CompletionStage [AuthToken ] = CompletableFuture .completedStage(authToken)
727+ override def handleSecurityException (authToken : AuthToken , exception : exceptions.SecurityException ): Boolean = false
728+ }
729+
725730class CaseInsensitiveEnumeration extends Enumeration {
726731
727732 def withCaseInsensitiveName (s : String ): Value = {
0 commit comments