Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.arrow.util.Preconditions;
import org.apache.calcite.avatica.AvaticaConnection;
import org.apache.calcite.avatica.AvaticaFactory;
import org.apache.calcite.avatica.DriverVersion;

/** Connection to the Arrow Flight server. */
public final class ArrowFlightConnection extends AvaticaConnection {
Expand Down Expand Up @@ -86,13 +87,16 @@ static ArrowFlightConnection createNewConnection(
throws SQLException {
url = replaceSemiColons(url);
final ArrowFlightConnectionConfigImpl config = new ArrowFlightConnectionConfigImpl(properties);
final ArrowFlightSqlClientHandler clientHandler = createNewClientHandler(config, allocator);
final ArrowFlightSqlClientHandler clientHandler =
createNewClientHandler(config, allocator, driver.getDriverVersion());
return new ArrowFlightConnection(
driver, factory, url, properties, config, allocator, clientHandler);
}

private static ArrowFlightSqlClientHandler createNewClientHandler(
final ArrowFlightConnectionConfigImpl config, final BufferAllocator allocator)
final ArrowFlightConnectionConfigImpl config,
final BufferAllocator allocator,
final DriverVersion driverVersion)
throws SQLException {
try {
return new ArrowFlightSqlClientHandler.Builder()
Expand All @@ -116,6 +120,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withCatalog(config.getCatalog())
.withClientCache(config.useClientCache() ? new FlightClientCache() : null)
.withConnectTimeout(config.getConnectTimeout())
.withDriverVersion(driverVersion)
.build();
} catch (final SQLException e) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.DriverVersion;
import org.apache.calcite.avatica.Meta.StatementType;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
Expand Down Expand Up @@ -548,6 +549,9 @@ public FlightInfo getCrossReference(

/** Builder for {@link ArrowFlightSqlClientHandler}. */
public static final class Builder {
static final String USER_AGENT_TEMPLATE = "JDBC Flight SQL Driver %s";
static final String DEFAULT_VERSION = "(unknown or development build)";

private final Set<FlightClientMiddleware.Factory> middlewareFactories = new HashSet<>();
private final Set<CallOption> options = new HashSet<>();
private String host;
Expand Down Expand Up @@ -597,6 +601,8 @@ public static final class Builder {
@VisibleForTesting
ClientCookieMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory();

DriverVersion driverVersion;

public Builder() {}

/**
Expand Down Expand Up @@ -631,6 +637,8 @@ public Builder() {}
if (original.retainAuth) {
this.authFactory = original.authFactory;
}

this.driverVersion = original.driverVersion;
}

/**
Expand Down Expand Up @@ -879,6 +887,17 @@ public Builder withConnectTimeout(Duration connectTimeout) {
return this;
}

/**
* Sets the driver version for this handler.
*
* @param driverVersion the driver version to set
* @return this builder instance
*/
public Builder withDriverVersion(DriverVersion driverVersion) {
this.driverVersion = driverVersion;
return this;
}

public String getCacheKey() {
return getLocation().toString();
}
Expand Down Expand Up @@ -914,6 +933,11 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
final NettyClientBuilder clientBuilder = new NettyClientBuilder();
clientBuilder.allocator(allocator);

String userAgent = String.format(USER_AGENT_TEMPLATE, DEFAULT_VERSION);
if (driverVersion != null && driverVersion.versionString != null) {
userAgent = String.format(USER_AGENT_TEMPLATE, driverVersion.versionString);
}

buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
if (useEncryption) {
Expand Down Expand Up @@ -948,6 +972,9 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
}

NettyChannelBuilder channelBuilder = clientBuilder.build();

channelBuilder.userAgent(userAgent);

if (connectTimeout != null) {
channelBuilder.withOption(
ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ public void testCookies() throws SQLException {
Statement statement = connection.createStatement()) {

// Expect client didn't receive cookies before any operation
assertNull(FLIGHT_SERVER_TEST_EXTENSION.getMiddlewareCookieFactory().getCookie());
assertNull(FLIGHT_SERVER_TEST_EXTENSION.getInterceptorFactory().getCookie());

// Run another action for check if the cookies was sent by the server.
statement.execute(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD);
assertEquals("k=v", FLIGHT_SERVER_TEST_EXTENSION.getMiddlewareCookieFactory().getCookie());
assertEquals("k=v", FLIGHT_SERVER_TEST_EXTENSION.getInterceptorFactory().getCookie());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
Expand Down Expand Up @@ -576,4 +577,49 @@ public void testPasswordConnectionPropertyIntegerCorrectCastUrlWithDriverManager
assertTrue(connection.isValid(0));
}
}

/**
* Test that the JDBC driver properly integrates driver version into client handler.
*
* @throws Exception on error.
*/
@Test
public void testJdbcDriverVersionIntegration() throws Exception {
final Properties properties = new Properties();
properties.put(
ArrowFlightConnectionProperty.HOST.camelName(), FLIGHT_SERVER_TEST_EXTENSION.getHost());
properties.put(
ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_EXTENSION.getPort());
properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest);
properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest);
properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), false);

// Create a driver instance and connect
ArrowFlightJdbcDriver driverVersion = new ArrowFlightJdbcDriver();

try (Connection connection =
ArrowFlightConnection.createNewConnection(
driverVersion,
new ArrowFlightJdbcFactory(),
"jdbc:arrow-flight-sql://localhost:" + FLIGHT_SERVER_TEST_EXTENSION.getPort(),
properties,
allocator)) {

assertTrue(connection.isValid(0));

var actualUserAgent =
FLIGHT_SERVER_TEST_EXTENSION
.getInterceptorFactory()
.getHeader(FlightMethod.HANDSHAKE, "user-agent");

var expectedUserAgent =
"JDBC Flight SQL Driver " + driverVersion.getDriverVersion().versionString;
// Driver appends version to grpc user-agent header. Assert the header starts with the
// expected
// value and ignored grpc version.
assertTrue(
actualUserAgent.startsWith(expectedUserAgent),
"Expected: " + expectedUserAgent + " but found: " + actualUserAgent);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.sql.SQLException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import org.apache.arrow.driver.jdbc.authentication.Authentication;
import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication;
Expand All @@ -33,6 +35,7 @@
import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.Location;
Expand Down Expand Up @@ -67,7 +70,8 @@ public class FlightServerTestExtension
private final CertKeyPair certKeyPair;
private final File mTlsCACert;

private final MiddlewareCookie.Factory middlewareCookieFactory = new MiddlewareCookie.Factory();
private final InterceptorMiddleware.Factory interceptorFactory =
new InterceptorMiddleware.Factory();

private FlightServerTestExtension(
final Properties properties,
Expand Down Expand Up @@ -130,8 +134,8 @@ private void setUseEncryption(boolean useEncryption) {
properties.put("useEncryption", useEncryption);
}

public MiddlewareCookie.Factory getMiddlewareCookieFactory() {
return middlewareCookieFactory;
public InterceptorMiddleware.Factory getInterceptorFactory() {
return interceptorFactory;
}

@FunctionalInterface
Expand All @@ -143,7 +147,7 @@ private FlightServer initiateServer(Location location) throws IOException {
FlightServer.Builder builder =
FlightServer.builder(allocator, location, producer)
.headerAuthenticator(authentication.authenticate())
.middleware(FlightServerMiddleware.Key.of("KEY"), middlewareCookieFactory);
.middleware(FlightServerMiddleware.Key.of("KEY"), interceptorFactory);
if (certKeyPair != null) {
builder.useTls(certKeyPair.cert, certKeyPair.key);
}
Expand Down Expand Up @@ -301,11 +305,11 @@ public FlightServerTestExtension build() {
* A middleware to handle with the cookies in the server. It is used to test if cookies are being
* sent properly.
*/
static class MiddlewareCookie implements FlightServerMiddleware {
static class InterceptorMiddleware implements FlightServerMiddleware {

private final Factory factory;

public MiddlewareCookie(Factory factory) {
public InterceptorMiddleware(Factory factory) {
this.factory = factory;
}

Expand All @@ -323,22 +327,33 @@ public void onCallCompleted(CallStatus callStatus) {}
public void onCallErrored(Throwable throwable) {}

/** A factory for the MiddlewareCookie. */
static class Factory implements FlightServerMiddleware.Factory<MiddlewareCookie> {
static class Factory implements FlightServerMiddleware.Factory<InterceptorMiddleware> {

private final Map<FlightMethod, CallHeaders> receivedCallHeaders = new HashMap<>();
private boolean receivedCookieHeader = false;
private String cookie;

@Override
public MiddlewareCookie onCallStarted(
public InterceptorMiddleware onCallStarted(
CallInfo callInfo, CallHeaders callHeaders, RequestContext requestContext) {
cookie = callHeaders.get("Cookie");
receivedCookieHeader = null != cookie;
return new MiddlewareCookie(this);

receivedCallHeaders.put(callInfo.method(), callHeaders);
return new InterceptorMiddleware(this);
}

public String getCookie() {
return cookie;
}

public String getHeader(FlightMethod method, String key) {
CallHeaders headers = receivedCallHeaders.get(method);
if (headers == null) {
return null;
}
return headers.get(key);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void testDefaults() {
assertEquals(Optional.empty(), builder.catalog);
assertNull(builder.flightClientCache);
assertNull(builder.connectTimeout);
assertNull(builder.driverVersion);
}

@Test
Expand Down
Loading