diff --git a/flight/flight-sql-jdbc-core/pom.xml b/flight/flight-sql-jdbc-core/pom.xml index 237d450eec..02fef3dc1f 100644 --- a/flight/flight-sql-jdbc-core/pom.xml +++ b/flight/flight-sql-jdbc-core/pom.xml @@ -47,6 +47,21 @@ under the License. + + io.grpc + grpc-api + + + + io.grpc + grpc-netty + + + + io.netty + netty-transport + + org.apache.arrow arrow-memory-core diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index c1b1c8f8e6..cf9804d68b 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -113,6 +113,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler( .withRetainCookies(config.retainCookies()) .withRetainAuth(config.retainAuth()) .withCatalog(config.getCatalog()) + .withConnectTimeout(config.getConnectTimeout()) .build(); } catch (final SQLException e) { try { diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 0e9c79a090..cbbe223eb8 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -17,10 +17,13 @@ package org.apache.arrow.driver.jdbc.client; import com.google.common.collect.ImmutableMap; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.ChannelOption; import java.io.IOException; import java.net.URI; import java.security.GeneralSecurityException; import java.sql.SQLException; +import java.time.Duration; import 
java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -36,6 +39,7 @@ import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightClientMiddleware; import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightGrpcUtils; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; @@ -50,6 +54,7 @@ import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; import org.apache.arrow.flight.client.ClientCookieMiddleware; import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.grpc.NettyClientBuilder; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; import org.apache.arrow.flight.sql.util.TableRef; @@ -138,12 +143,11 @@ public List getStreams(final FlightInfo flightInfo) // Clone the builder and then set the new endpoint on it. // GH-38574: Currently a new FlightClient will be made for each partition that returns a - // non-empty Location - // then disposed of. It may be better to cache clients because a server may report the - // same Locations. - // It would also be good to identify when the reported location is the same as the - // original connection's - // Location and skip creating a FlightClient in that scenario. + // non-empty Location, and then disposed of. + // It may be better to cache clients, because a server may report the same Locations. + // It would also be good to identify when the reported location is the same as + // the original connection's Location, and to skip creating a FlightClient in + // that scenario. 
List exceptions = new ArrayList<>(); CloseableEndpointStreamPair stream = null; for (Location location : endpoint.getLocations()) { @@ -158,7 +162,8 @@ public List getStreams(final FlightInfo flightInfo) new Builder(ArrowFlightSqlClientHandler.this.builder) .withHost(endpointUri.getHost()) .withPort(endpointUri.getPort()) - .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS)); + .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS)) + .withConnectTimeout(builder.connectTimeout); ArrowFlightSqlClientHandler endpointHandler = null; try { @@ -177,6 +182,7 @@ public List getStreams(final FlightInfo flightInfo) exceptions.add(ex); continue; } + break; } if (stream != null) { @@ -543,6 +549,8 @@ public static final class Builder { @VisibleForTesting Optional catalog = Optional.empty(); + @VisibleForTesting @Nullable Duration connectTimeout; + // These two middleware are for internal use within build() and should not be exposed by builder // APIs. // Note that these middleware may not necessarily be registered. @@ -825,6 +833,19 @@ public Builder withCatalog(@Nullable final String catalog) { return this; } + public Builder withConnectTimeout(@Nullable Duration connectTimeout) { + this.connectTimeout = connectTimeout; + return this; + } + + /** Get the location that this client will connect to. */ + public Location getLocation() { + if (useEncryption) { + return Location.forGrpcTls(host, port); + } + return Location.forGrpcInsecure(host, port); + } + /** * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields. 
* @@ -845,17 +866,15 @@ public ArrowFlightSqlClientHandler build() throws SQLException { if (isUsingUserPasswordAuth) { buildTimeMiddlewareFactories.add(authFactory); } - final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); + final NettyClientBuilder clientBuilder = new NettyClientBuilder(); + clientBuilder.allocator(allocator); buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory()); buildTimeMiddlewareFactories.forEach(clientBuilder::intercept); - Location location; if (useEncryption) { - location = Location.forGrpcTls(host, port); clientBuilder.useTls(); - } else { - location = Location.forGrpcInsecure(host, port); } + Location location = getLocation(); clientBuilder.location(location); if (useEncryption) { @@ -883,7 +902,16 @@ public ArrowFlightSqlClientHandler build() throws SQLException { } } - client = clientBuilder.build(); + NettyChannelBuilder channelBuilder = clientBuilder.build(); + if (connectTimeout != null) { + // Clamp to Integer.MAX_VALUE: a bare (int) cast would wrap for durations over ~24.8 days. + channelBuilder.withOption( + ChannelOption.CONNECT_TIMEOUT_MILLIS, + (int) Math.min(connectTimeout.toMillis(), Integer.MAX_VALUE)); + } + client = + FlightGrpcUtils.createFlightClient( + allocator, channelBuilder.build(), clientBuilder.middleware()); final ArrayList credentialOptions = new ArrayList<>(); if (isUsingUserPasswordAuth) { // If the authFactory has already been used for a handshake, use the existing token. 
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java index e8bae2a207..ab6a5898b7 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -16,6 +16,7 @@ */ package org.apache.arrow.driver.jdbc.utils; +import java.time.Duration; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -163,6 +164,16 @@ public String getCatalog() { return ArrowFlightConnectionProperty.CATALOG.getString(properties); } + /** Gets the connect timeout, using the property's default when no value is available. */ + public Duration getConnectTimeout() { + Integer timeout = ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.getInteger(properties); + if (timeout == null) { + return Duration.ofMillis( + (int) ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.defaultValue()); + } + return Duration.ofMillis(timeout); + } + /** * Gets the {@link CallOption}s from this {@link ConnectionConfig}. 
* @@ -213,7 +224,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty { TOKEN("token", null, Type.STRING, false), RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false), RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false), - CATALOG("catalog", null, Type.STRING, false); + CATALOG("catalog", null, Type.STRING, false), + CONNECT_TIMEOUT_MILLIS("connectTimeoutMs", 10000, Type.NUMBER, false), + ; private final String camelName; private final Object defaultValue; diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java index a8d04dfc83..cd47408f52 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -25,12 +25,13 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import com.google.common.collect.ImmutableSet; import java.nio.charset.StandardCharsets; @@ -645,6 +646,139 @@ public void testFallbackSecondFlightServer() throws Exception { } } + @Test + public void testFallbackUnresolvableFlightServer() throws Exception { + final Schema schema = + new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = 
VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = + new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = + FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = + DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) { + // This first attempt should take a measurable amount of time. + long start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + long attempt1 = System.nanoTime(); + double elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs >= 5000., + String.format( + "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs)); + + // Once the client cache is implemented (GH-661), this second attempt should take less time, + // since the failure from before should be cached. 
+ start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + attempt1 = System.nanoTime(); + elapsedMs = (attempt1 - start) / 1_000_000.; + // TODO(GH-661): this assertion should be flipped to assertTrue. + assertFalse( + elapsedMs < 5000., + String.format("Expected second attempt to be the same, but %f ms elapsed", elapsedMs)); + } + } + } + + @Test + public void testFallbackUnresolvableFlightServerDisableCache() throws Exception { + final Schema schema = + new Schema( + Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType()))); + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) { + resultData.setRowCount(1); + ((IntVector) resultData.getVector(0)).set(0, 1); + + try (final FallbackFlightSqlProducer rootProducer = + new FallbackFlightSqlProducer(resultData); + FlightServer rootServer = + FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer) + .build() + .start(); + Connection newConnection = + DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false&useClientCache=false", + rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) { + // This first attempt should take a measurable amount of time. 
+ long start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + long attempt1 = System.nanoTime(); + double elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs >= 5000., + String.format( + "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs)); + + // This second attempt should take a long time still, since we disabled the cache. + start = System.nanoTime(); + try (Statement newStatement = newConnection.createStatement()) { + try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) { + List actualData = new ArrayList<>(); + while (result.next()) { + actualData.add(result.getInt(1)); + } + + // Assert + assertEquals(resultData.getRowCount(), actualData.size()); + assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0))); + } + } + attempt1 = System.nanoTime(); + elapsedMs = (attempt1 - start) / 1_000_000.; + assertTrue( + elapsedMs >= 5000., + String.format( + "Expected second attempt to hit the timeout, but only %f ms elapsed", elapsedMs)); + } + } + } + @Test public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception { try (Statement statement = connection.createStatement(); diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java index 6beaba8236..7b416638e1 100644 --- 
a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java @@ -147,6 +147,7 @@ public void testDefaults() { assertNull(builder.clientCertificatePath); assertNull(builder.clientKeyPath); assertEquals(Optional.empty(), builder.catalog); + assertNull(builder.connectTimeout); } @Test diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java index 4a46b5f5be..c780d53fab 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java @@ -18,6 +18,7 @@ import static java.lang.Runtime.getRuntime; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG; +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD; import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT; @@ -27,6 +28,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import java.time.Duration; import java.util.Properties; import java.util.Random; import java.util.function.Function; @@ -59,49 +61,67 @@ public void setUp() { public void testGetProperty( 
ArrowFlightConnectionProperty property, Object value, + Object expected, Function configFunction) { properties.put(property.camelName(), value); arrowFlightConnectionConfigFunction = configFunction; - assertThat(configFunction.apply(arrowFlightConnectionConfig), is(value)); - assertThat(arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(value)); + assertThat(configFunction.apply(arrowFlightConnectionConfig), is(expected)); + assertThat( + arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(expected)); } public static Stream provideParameters() { + int port = RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)); + boolean useEncryption = RANDOM.nextBoolean(); + int threadPoolSize = RANDOM.nextInt(getRuntime().availableProcessors()); return Stream.of( Arguments.of( HOST, "host", + "host", (Function) ArrowFlightConnectionConfigImpl::getHost), Arguments.of( PORT, - RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)), + port, + port, (Function) ArrowFlightConnectionConfigImpl::getPort), Arguments.of( USER, "user", + "user", (Function) ArrowFlightConnectionConfigImpl::getUser), Arguments.of( PASSWORD, "password", + "password", (Function) ArrowFlightConnectionConfigImpl::getPassword), Arguments.of( USE_ENCRYPTION, - RANDOM.nextBoolean(), + useEncryption, + useEncryption, (Function) ArrowFlightConnectionConfigImpl::useEncryption), Arguments.of( THREAD_POOL_SIZE, - RANDOM.nextInt(getRuntime().availableProcessors()), + threadPoolSize, + threadPoolSize, (Function) ArrowFlightConnectionConfigImpl::threadPoolSize), Arguments.of( CATALOG, "catalog", + "catalog", + (Function) + ArrowFlightConnectionConfigImpl::getCatalog), + Arguments.of( + CONNECT_TIMEOUT_MILLIS, + 5000, + Duration.ofMillis(5000), (Function) - ArrowFlightConnectionConfigImpl::getCatalog)); + ArrowFlightConnectionConfigImpl::getConnectTimeout)); } } diff --git 
a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java index 9aa257172c..670b9e3be0 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java @@ -109,6 +109,16 @@ private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) { Location.forGrpcInsecure("localhost", 9999), Location.reuseConnection()) .build()); + } else if (query.equals("fallback with unresolvable")) { + endpoints = + Collections.singletonList( + FlightEndpoint.builder( + ticket, + // Inaccessible IP (TEST-NET-3, RFC 5737) + // https://stackoverflow.com/questions/10456044/what-is-a-good-invalid-ip-address-to-use-for-unit-tests + Location.forGrpcInsecure("203.0.113.0", 9999), + Location.reuseConnection()) + .build()); } else { throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException(); }