Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions flight/flight-sql-jdbc-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ under the License.
</exclusions>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-api</artifactId>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised we have to define these dependencies explicitly; I thought they came transitively from Arrow Flight.

</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty</artifactId>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-transport</artifactId>
</dependency>

<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.withCatalog(config.getCatalog())
.withConnectTimeout(config.getConnectTimeout())
.build();
} catch (final SQLException e) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
package org.apache.arrow.driver.jdbc.client;

import com.google.common.collect.ImmutableMap;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.ChannelOption;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -36,6 +39,7 @@
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightGrpcUtils;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
Expand All @@ -50,6 +54,7 @@
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;
import org.apache.arrow.flight.sql.util.TableRef;
Expand Down Expand Up @@ -138,12 +143,11 @@ public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
// Clone the builder and then set the new endpoint on it.

// GH-38574: Currently a new FlightClient will be made for each partition that returns a
// non-empty Location
// then disposed of. It may be better to cache clients because a server may report the
// same Locations.
// It would also be good to identify when the reported location is the same as the
// original connection's
// Location and skip creating a FlightClient in that scenario.
// non-empty Location then disposed of. It may be better to cache clients because a server
// may report the same Locations. It would also be good to identify when the reported
// location is the same as the original connection's Location and skip creating a
// FlightClient in that scenario.
List<Exception> exceptions = new ArrayList<>();
CloseableEndpointStreamPair stream = null;
for (Location location : endpoint.getLocations()) {
Expand All @@ -158,7 +162,8 @@ public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
.withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS));
.withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS))
.withConnectTimeout(builder.connectTimeout);

ArrowFlightSqlClientHandler endpointHandler = null;
try {
Expand All @@ -177,6 +182,7 @@ public List<CloseableEndpointStreamPair> getStreams(final FlightInfo flightInfo)
exceptions.add(ex);
continue;
}

break;
}
if (stream != null) {
Expand Down Expand Up @@ -543,6 +549,8 @@ public static final class Builder {

@VisibleForTesting Optional<String> catalog = Optional.empty();

@VisibleForTesting @Nullable Duration connectTimeout;

// These two middleware are for internal use within build() and should not be exposed by builder
// APIs.
// Note that these middleware may not necessarily be registered.
Expand Down Expand Up @@ -825,6 +833,19 @@ public Builder withCatalog(@Nullable final String catalog) {
return this;
}

/**
 * Sets the timeout to apply when establishing the connection.
 *
 * @param connectTimeout the connect timeout, stored as-is on this builder; may be {@code null},
 *     in which case no connect timeout option is applied when the client is built.
 * @return this builder instance, to allow call chaining.
 */
public Builder withConnectTimeout(final Duration connectTimeout) {
  this.connectTimeout = connectTimeout;
  return this;
}

/**
 * Returns the {@link Location} this client will connect to, built from the configured host and
 * port: a TLS gRPC location when encryption is enabled, an insecure gRPC location otherwise.
 */
public Location getLocation() {
  return useEncryption
      ? Location.forGrpcTls(host, port)
      : Location.forGrpcInsecure(host, port);
}

/**
* Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
*
Expand All @@ -845,17 +866,15 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
if (isUsingUserPasswordAuth) {
buildTimeMiddlewareFactories.add(authFactory);
}
final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
final NettyClientBuilder clientBuilder = new NettyClientBuilder();
clientBuilder.allocator(allocator);

buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
Location location;
if (useEncryption) {
location = Location.forGrpcTls(host, port);
clientBuilder.useTls();
} else {
location = Location.forGrpcInsecure(host, port);
}
Location location = getLocation();
clientBuilder.location(location);

if (useEncryption) {
Expand Down Expand Up @@ -883,7 +902,14 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
}
}

client = clientBuilder.build();
NettyChannelBuilder channelBuilder = clientBuilder.build();
if (connectTimeout != null) {
channelBuilder.withOption(
ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis());
}
client =
FlightGrpcUtils.createFlightClient(
allocator, channelBuilder.build(), clientBuilder.middleware());
final ArrayList<CallOption> credentialOptions = new ArrayList<>();
if (isUsingUserPasswordAuth) {
// If the authFactory has already been used for a handshake, use the existing token.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.arrow.driver.jdbc.utils;

import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -163,6 +164,16 @@ public String getCatalog() {
return ArrowFlightConnectionProperty.CATALOG.getString(properties);
}

/**
 * Gets the timeout to use for the initial connection attempt.
 *
 * <p>The value is read, in milliseconds, from the {@code CONNECT_TIMEOUT_MILLIS} connection
 * property; when the lookup yields {@code null}, the property's declared default is used.
 *
 * @return the connect timeout as a {@link Duration}.
 */
public Duration getConnectTimeout() {
  final Integer configuredMillis =
      ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.getInteger(properties);
  // Only consult the declared default when nothing could be read from the properties.
  final int effectiveMillis =
      configuredMillis != null
          ? configuredMillis
          : (int) ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.defaultValue();
  return Duration.ofMillis(effectiveMillis);
}

/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
Expand Down Expand Up @@ -213,7 +224,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false),
CATALOG("catalog", null, Type.STRING, false);
CATALOG("catalog", null, Type.STRING, false),
CONNECT_TIMEOUT_MILLIS("connectTimeoutMs", 10000, Type.NUMBER, false),
;

private final String camelName;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.*;

import com.google.common.collect.ImmutableSet;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -645,6 +640,139 @@ public void testFallbackSecondFlightServer() throws Exception {
}
}

/**
 * Runs the "fallback with unresolvable" query twice over one connection and checks that each run
 * returns the expected single row and takes at least five seconds.
 *
 * <p>The second timing check is intentionally written with {@code assertFalse}; see the TODO
 * below for when it should be flipped.
 */
@Test
public void testFallbackUnresolvableFlightServer() throws Exception {
  final Field intColumn = Field.nullable("int_column", Types.MinorType.INT.getType());
  final Schema schema = new Schema(Collections.singletonList(intColumn));
  try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
      VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
    // The expected result set is a single row holding the value 1.
    resultData.setRowCount(1);
    ((IntVector) resultData.getVector(0)).set(0, 1);

    try (final FallbackFlightSqlProducer rootProducer =
            new FallbackFlightSqlProducer(resultData);
        FlightServer rootServer =
            FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer)
                .build()
                .start();
        Connection newConnection =
            DriverManager.getConnection(
                String.format(
                    "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
                    rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) {
      // First attempt: expected to take a measurable amount of time.
      final long firstStart = System.nanoTime();
      try (Statement newStatement = newConnection.createStatement();
          ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
        final List<Integer> rows = new ArrayList<>();
        while (result.next()) {
          rows.add(result.getInt(1));
        }

        // The fallback must still deliver the expected data.
        assertEquals(resultData.getRowCount(), rows.size());
        assertTrue(rows.contains(((IntVector) resultData.getVector(0)).get(0)));
      }
      final long firstEnd = System.nanoTime();
      final double firstElapsedMs = (firstEnd - firstStart) / 1_000_000.;
      assertTrue(
          firstElapsedMs >= 5000.,
          String.format(
              "Expected first attempt to hit the timeout, but only %f ms elapsed",
              firstElapsedMs));

      // Second attempt: once the client cache is implemented (GH-661) this should be faster,
      // because the failure from the first attempt should be cached.
      final long secondStart = System.nanoTime();
      try (Statement newStatement = newConnection.createStatement();
          ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
        final List<Integer> rows = new ArrayList<>();
        while (result.next()) {
          rows.add(result.getInt(1));
        }

        // The fallback must still deliver the expected data.
        assertEquals(resultData.getRowCount(), rows.size());
        assertTrue(rows.contains(((IntVector) resultData.getVector(0)).get(0)));
      }
      final long secondEnd = System.nanoTime();
      final double secondElapsedMs = (secondEnd - secondStart) / 1_000_000.;
      // TODO(GH-661): this assertion should be flipped to assertTrue.
      assertFalse(
          secondElapsedMs < 5000.,
          String.format(
              "Expected second attempt to be the same, but %f ms elapsed", secondElapsedMs));
    }
  }
}

/**
 * Runs the "fallback with unresolvable" query twice over a connection opened with
 * {@code useClientCache=false} and checks that both runs return the expected single row and each
 * take at least five seconds.
 */
@Test
public void testFallbackUnresolvableFlightServerDisableCache() throws Exception {
  final Field intColumn = Field.nullable("int_column", Types.MinorType.INT.getType());
  final Schema schema = new Schema(Collections.singletonList(intColumn));
  try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
      VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
    // The expected result set is a single row holding the value 1.
    resultData.setRowCount(1);
    ((IntVector) resultData.getVector(0)).set(0, 1);

    try (final FallbackFlightSqlProducer rootProducer =
            new FallbackFlightSqlProducer(resultData);
        FlightServer rootServer =
            FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer)
                .build()
                .start();
        Connection newConnection =
            DriverManager.getConnection(
                String.format(
                    "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false&useClientCache=false",
                    rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) {
      // First attempt: expected to take a measurable amount of time.
      final long firstStart = System.nanoTime();
      try (Statement newStatement = newConnection.createStatement();
          ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
        final List<Integer> rows = new ArrayList<>();
        while (result.next()) {
          rows.add(result.getInt(1));
        }

        // The fallback must still deliver the expected data.
        assertEquals(resultData.getRowCount(), rows.size());
        assertTrue(rows.contains(((IntVector) resultData.getVector(0)).get(0)));
      }
      final long firstEnd = System.nanoTime();
      final double firstElapsedMs = (firstEnd - firstStart) / 1_000_000.;
      assertTrue(
          firstElapsedMs >= 5000.,
          String.format(
              "Expected first attempt to hit the timeout, but only %f ms elapsed",
              firstElapsedMs));

      // Second attempt: should still take a long time, since the cache is disabled.
      final long secondStart = System.nanoTime();
      try (Statement newStatement = newConnection.createStatement();
          ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
        final List<Integer> rows = new ArrayList<>();
        while (result.next()) {
          rows.add(result.getInt(1));
        }

        // The fallback must still deliver the expected data.
        assertEquals(resultData.getRowCount(), rows.size());
        assertTrue(rows.contains(((IntVector) resultData.getVector(0)).get(0)));
      }
      final long secondEnd = System.nanoTime();
      final double secondElapsedMs = (secondEnd - secondStart) / 1_000_000.;
      assertTrue(
          secondElapsedMs >= 5000.,
          String.format(
              "Expected second attempt to hit the timeout, but only %f ms elapsed",
              secondElapsedMs));
    }
  }
}

@Test
public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception {
try (Statement statement = connection.createStatement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ public void testDefaults() {
assertNull(builder.clientCertificatePath);
assertNull(builder.clientKeyPath);
assertEquals(Optional.empty(), builder.catalog);
assertNull(builder.connectTimeout);
}

@Test
Expand Down
Loading
Loading