Skip to content

Commit 8d3fe93

Browse files
authored
GH-81: [Flight] Expose gRPC in Flight client builder (#660)
## What's Changed Expose the internal gRPC channel builder in the FlightClient so that applications can build on top of it without having to replicate all the special Flight logic. This makes it much easier to apply some gRPC tweaks without duplicating Arrow code. Closes #81.
1 parent d304da5 commit 8d3fe93

File tree

3 files changed

+267
-137
lines changed

3 files changed

+267
-137
lines changed

flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java

Lines changed: 21 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,13 @@
2323
import io.grpc.ManagedChannel;
2424
import io.grpc.MethodDescriptor;
2525
import io.grpc.StatusRuntimeException;
26-
import io.grpc.netty.GrpcSslContexts;
2726
import io.grpc.netty.NettyChannelBuilder;
2827
import io.grpc.stub.ClientCallStreamObserver;
2928
import io.grpc.stub.ClientCalls;
3029
import io.grpc.stub.ClientResponseObserver;
3130
import io.grpc.stub.StreamObserver;
32-
import io.netty.channel.EventLoopGroup;
33-
import io.netty.channel.ServerChannel;
34-
import io.netty.handler.ssl.SslContextBuilder;
35-
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
3631
import java.io.IOException;
3732
import java.io.InputStream;
38-
import java.lang.reflect.InvocationTargetException;
3933
import java.net.URISyntaxException;
4034
import java.nio.ByteBuffer;
4135
import java.util.ArrayList;
@@ -45,7 +39,6 @@
4539
import java.util.concurrent.ExecutionException;
4640
import java.util.concurrent.TimeUnit;
4741
import java.util.function.BooleanSupplier;
48-
import javax.net.ssl.SSLException;
4942
import org.apache.arrow.flight.FlightProducer.StreamListener;
5043
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
5144
import org.apache.arrow.flight.auth.ClientAuthHandler;
@@ -57,6 +50,7 @@
5750
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
5851
import org.apache.arrow.flight.grpc.ClientInterceptorAdapter;
5952
import org.apache.arrow.flight.grpc.CredentialCallOption;
53+
import org.apache.arrow.flight.grpc.NettyClientBuilder;
6054
import org.apache.arrow.flight.grpc.StatusUtils;
6155
import org.apache.arrow.flight.impl.Flight;
6256
import org.apache.arrow.flight.impl.Flight.Empty;
@@ -73,12 +67,6 @@
7367
public class FlightClient implements AutoCloseable {
7468
private static final int PENDING_REQUESTS = 5;
7569

76-
/**
77-
* The maximum number of trace events to keep on the gRPC Channel. This value disables channel
78-
* tracing.
79-
*/
80-
private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
81-
8270
private final BufferAllocator allocator;
8371
private final ManagedChannel channel;
8472

@@ -97,11 +85,12 @@ public class FlightClient implements AutoCloseable {
9785
List<FlightClientMiddleware.Factory> middleware) {
9886
this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
9987
this.channel = channel;
100-
this.middleware = middleware;
88+
// We need a mutable copy (shared between this class and ClientInterceptorAdapter)
89+
this.middleware = new ArrayList<>(middleware);
10190

10291
final ClientInterceptor[] interceptors;
10392
interceptors =
104-
new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(middleware)};
93+
new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(this.middleware)};
10594

10695
// Create a channel with interceptors pre-applied for DoGet and DoPut
10796
Channel interceptedChannel = ClientInterceptors.intercept(channel, interceptors);
@@ -772,176 +761,71 @@ public static Builder builder(BufferAllocator allocator, Location location) {
772761

773762
/** A builder for Flight clients. */
774763
public static final class Builder {
775-
private BufferAllocator allocator;
776-
private Location location;
777-
private boolean forceTls = false;
778-
private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
779-
private InputStream trustedCertificates = null;
780-
private InputStream clientCertificate = null;
781-
private InputStream clientKey = null;
782-
private String overrideHostname = null;
783-
private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
784-
private boolean verifyServer = true;
785-
786-
private Builder() {}
764+
private final NettyClientBuilder builder;
765+
766+
private Builder() {
767+
this.builder = new NettyClientBuilder();
768+
}
787769

788770
private Builder(BufferAllocator allocator, Location location) {
789-
this.allocator = Preconditions.checkNotNull(allocator);
790-
this.location = Preconditions.checkNotNull(location);
771+
this.builder = new NettyClientBuilder(allocator, location);
791772
}
792773

793774
/** Force the client to connect over TLS. */
794775
public Builder useTls() {
795-
this.forceTls = true;
776+
builder.useTls();
796777
return this;
797778
}
798779

799780
/** Override the hostname checked for TLS. Use with caution in production. */
800781
public Builder overrideHostname(final String hostname) {
801-
this.overrideHostname = hostname;
782+
builder.overrideHostname(hostname);
802783
return this;
803784
}
804785

805786
/** Set the maximum inbound message size. */
806787
public Builder maxInboundMessageSize(int maxSize) {
807-
Preconditions.checkArgument(maxSize > 0);
808-
this.maxInboundMessageSize = maxSize;
788+
builder.maxInboundMessageSize(maxSize);
809789
return this;
810790
}
811791

812792
/** Set the trusted TLS certificates. */
813793
public Builder trustedCertificates(final InputStream stream) {
814-
this.trustedCertificates = Preconditions.checkNotNull(stream);
794+
builder.trustedCertificates(stream);
815795
return this;
816796
}
817797

818798
/** Set the trusted TLS certificates. */
819799
public Builder clientCertificate(
820800
final InputStream clientCertificate, final InputStream clientKey) {
821-
Preconditions.checkNotNull(clientKey);
822-
this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
823-
this.clientKey = Preconditions.checkNotNull(clientKey);
801+
builder.clientCertificate(clientCertificate, clientKey);
824802
return this;
825803
}
826804

827805
public Builder allocator(BufferAllocator allocator) {
828-
this.allocator = Preconditions.checkNotNull(allocator);
806+
builder.allocator(allocator);
829807
return this;
830808
}
831809

832810
public Builder location(Location location) {
833-
this.location = Preconditions.checkNotNull(location);
811+
builder.location(location);
834812
return this;
835813
}
836814

837815
public Builder intercept(FlightClientMiddleware.Factory factory) {
838-
middleware.add(factory);
816+
builder.intercept(factory);
839817
return this;
840818
}
841819

842820
public Builder verifyServer(boolean verifyServer) {
843-
this.verifyServer = verifyServer;
821+
builder.verifyServer(verifyServer);
844822
return this;
845823
}
846824

847825
/** Create the client from this builder. */
848826
public FlightClient build() {
849-
final NettyChannelBuilder builder;
850-
851-
switch (location.getUri().getScheme()) {
852-
case LocationSchemes.GRPC:
853-
case LocationSchemes.GRPC_INSECURE:
854-
case LocationSchemes.GRPC_TLS:
855-
{
856-
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
857-
break;
858-
}
859-
case LocationSchemes.GRPC_DOMAIN_SOCKET:
860-
{
861-
// The implementation is platform-specific, so we have to find the classes at runtime
862-
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
863-
try {
864-
try {
865-
// Linux
866-
builder.channelType(
867-
Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")
868-
.asSubclass(ServerChannel.class));
869-
final EventLoopGroup elg =
870-
Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
871-
.asSubclass(EventLoopGroup.class)
872-
.getDeclaredConstructor()
873-
.newInstance();
874-
builder.eventLoopGroup(elg);
875-
} catch (ClassNotFoundException e) {
876-
// BSD
877-
builder.channelType(
878-
Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")
879-
.asSubclass(ServerChannel.class));
880-
final EventLoopGroup elg =
881-
Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
882-
.asSubclass(EventLoopGroup.class)
883-
.getDeclaredConstructor()
884-
.newInstance();
885-
builder.eventLoopGroup(elg);
886-
}
887-
} catch (ClassNotFoundException
888-
| InstantiationException
889-
| IllegalAccessException
890-
| NoSuchMethodException
891-
| InvocationTargetException e) {
892-
throw new UnsupportedOperationException(
893-
"Could not find suitable Netty native transport implementation for domain socket address.");
894-
}
895-
break;
896-
}
897-
default:
898-
throw new IllegalArgumentException(
899-
"Scheme is not supported: " + location.getUri().getScheme());
900-
}
901-
902-
if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
903-
builder.useTransportSecurity();
904-
905-
final boolean hasTrustedCerts = this.trustedCertificates != null;
906-
final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
907-
if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
908-
throw new IllegalArgumentException(
909-
"FlightClient has been configured to disable server verification, "
910-
+ "but certificate options have been specified.");
911-
}
912-
913-
final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
914-
915-
if (!this.verifyServer) {
916-
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
917-
} else if (this.trustedCertificates != null
918-
|| this.clientCertificate != null
919-
|| this.clientKey != null) {
920-
if (this.trustedCertificates != null) {
921-
sslContextBuilder.trustManager(this.trustedCertificates);
922-
}
923-
if (this.clientCertificate != null && this.clientKey != null) {
924-
sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
925-
}
926-
}
927-
try {
928-
builder.sslContext(sslContextBuilder.build());
929-
} catch (SSLException e) {
930-
throw new RuntimeException(e);
931-
}
932-
933-
if (this.overrideHostname != null) {
934-
builder.overrideAuthority(this.overrideHostname);
935-
}
936-
} else {
937-
builder.usePlaintext();
938-
}
939-
940-
builder
941-
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
942-
.maxInboundMessageSize(maxInboundMessageSize)
943-
.maxInboundMetadataSize(maxInboundMessageSize);
944-
return new FlightClient(allocator, builder.build(), middleware);
827+
final NettyChannelBuilder channelBuilder = builder.build();
828+
return new FlightClient(builder.allocator(), channelBuilder.build(), builder.middleware());
945829
}
946830
}
947831

flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.grpc.ManagedChannel;
2424
import io.grpc.MethodDescriptor;
2525
import java.util.Collections;
26+
import java.util.List;
2627
import java.util.concurrent.ExecutorService;
2728
import java.util.concurrent.TimeUnit;
2829
import org.apache.arrow.flight.auth.ServerAuthHandler;
@@ -151,6 +152,19 @@ public static FlightClient createFlightClient(
151152
return new FlightClient(incomingAllocator, channel, Collections.emptyList());
152153
}
153154

155+
/**
156+
* Creates a Flight client.
157+
*
158+
* @param incomingAllocator Memory allocator
159+
* @param channel provides a connection to a gRPC server.
160+
*/
161+
public static FlightClient createFlightClient(
162+
BufferAllocator incomingAllocator,
163+
ManagedChannel channel,
164+
List<FlightClientMiddleware.Factory> middleware) {
165+
return new FlightClient(incomingAllocator, channel, middleware);
166+
}
167+
154168
/**
155169
* Creates a Flight client.
156170
*

0 commit comments

Comments
 (0)