From 4846efd65fd885a81cdde971c94c8a06a1d42a9d Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 9 Jan 2025 01:16:05 -0500 Subject: [PATCH] GH-81: [Flight] Expose gRPC in Flight client builder Fixes #81. --- .../org/apache/arrow/flight/FlightClient.java | 158 ++---------- .../apache/arrow/flight/FlightGrpcUtils.java | 14 ++ .../arrow/flight/grpc/NettyClientBuilder.java | 232 ++++++++++++++++++ 3 files changed, 267 insertions(+), 137 deletions(-) create mode 100644 flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index b12ae5ec8c..fd6e498d13 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -23,19 +23,13 @@ import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import io.grpc.StatusRuntimeException; -import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.ClientCallStreamObserver; import io.grpc.stub.ClientCalls; import io.grpc.stub.ClientResponseObserver; import io.grpc.stub.StreamObserver; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.ServerChannel; -import io.netty.handler.ssl.SslContextBuilder; -import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.InvocationTargetException; import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -45,7 +39,6 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.function.BooleanSupplier; -import javax.net.ssl.SSLException; import org.apache.arrow.flight.FlightProducer.StreamListener; import org.apache.arrow.flight.auth.BasicClientAuthHandler; import org.apache.arrow.flight.auth.ClientAuthHandler; @@ -57,6 +50,7 @@ import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; import org.apache.arrow.flight.grpc.ClientInterceptorAdapter; import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.grpc.NettyClientBuilder; import org.apache.arrow.flight.grpc.StatusUtils; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.Flight.Empty; @@ -73,12 +67,6 @@ public class FlightClient implements AutoCloseable { private static final int PENDING_REQUESTS = 5; - /** - * The maximum number of trace events to keep on the gRPC Channel. This value disables channel - * tracing. - */ - private static final int MAX_CHANNEL_TRACE_EVENTS = 0; - private final BufferAllocator allocator; private final ManagedChannel channel; @@ -97,11 +85,12 @@ public class FlightClient implements AutoCloseable { List middleware) { this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE); this.channel = channel; - this.middleware = middleware; + // We need a mutable copy (shared between this class and ClientInterceptorAdapter) + this.middleware = new ArrayList<>(middleware); final ClientInterceptor[] interceptors; interceptors = - new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(middleware)}; + new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(this.middleware)}; // Create a channel with interceptors pre-applied for DoGet and DoPut Channel interceptedChannel = ClientInterceptors.intercept(channel, interceptors); @@ -772,176 +761,71 @@ public static Builder builder(BufferAllocator allocator, Location location) { /** A builder for Flight clients. */ public static final class Builder { - private BufferAllocator allocator; - private Location location; - private boolean forceTls = false; - private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE; - private InputStream trustedCertificates = null; - private InputStream clientCertificate = null; - private InputStream clientKey = null; - private String overrideHostname = null; - private List middleware = new ArrayList<>(); - private boolean verifyServer = true; - - private Builder() {} + private final NettyClientBuilder builder; + + private Builder() { + this.builder = new NettyClientBuilder(); + } private Builder(BufferAllocator allocator, Location location) { - this.allocator = Preconditions.checkNotNull(allocator); - this.location = Preconditions.checkNotNull(location); + this.builder = new NettyClientBuilder(allocator, location); } /** Force the client to connect over TLS. */ public Builder useTls() { - this.forceTls = true; + builder.useTls(); return this; } /** Override the hostname checked for TLS. Use with caution in production. */ public Builder overrideHostname(final String hostname) { - this.overrideHostname = hostname; + builder.overrideHostname(hostname); return this; } /** Set the maximum inbound message size. */ public Builder maxInboundMessageSize(int maxSize) { - Preconditions.checkArgument(maxSize > 0); - this.maxInboundMessageSize = maxSize; + builder.maxInboundMessageSize(maxSize); return this; } /** Set the trusted TLS certificates. */ public Builder trustedCertificates(final InputStream stream) { - this.trustedCertificates = Preconditions.checkNotNull(stream); + builder.trustedCertificates(stream); return this; } /** Set the trusted TLS certificates. */ public Builder clientCertificate( final InputStream clientCertificate, final InputStream clientKey) { - Preconditions.checkNotNull(clientKey); - this.clientCertificate = Preconditions.checkNotNull(clientCertificate); - this.clientKey = Preconditions.checkNotNull(clientKey); + builder.clientCertificate(clientCertificate, clientKey); return this; } public Builder allocator(BufferAllocator allocator) { - this.allocator = Preconditions.checkNotNull(allocator); + builder.allocator(allocator); return this; } public Builder location(Location location) { - this.location = Preconditions.checkNotNull(location); + builder.location(location); return this; } public Builder intercept(FlightClientMiddleware.Factory factory) { - middleware.add(factory); + builder.intercept(factory); return this; } public Builder verifyServer(boolean verifyServer) { - this.verifyServer = verifyServer; + builder.verifyServer(verifyServer); return this; } /** Create the client from this builder. */ public FlightClient build() { - final NettyChannelBuilder builder; - - switch (location.getUri().getScheme()) { - case LocationSchemes.GRPC: - case LocationSchemes.GRPC_INSECURE: - case LocationSchemes.GRPC_TLS: - { - builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); - break; - } - case LocationSchemes.GRPC_DOMAIN_SOCKET: - { - // The implementation is platform-specific, so we have to find the classes at runtime - builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); - try { - try { - // Linux - builder.channelType( - Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel") - .asSubclass(ServerChannel.class)); - final EventLoopGroup elg = - Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") - .asSubclass(EventLoopGroup.class) - .getDeclaredConstructor() - .newInstance(); - builder.eventLoopGroup(elg); - } catch (ClassNotFoundException e) { - // BSD - builder.channelType( - Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel") - .asSubclass(ServerChannel.class)); - final EventLoopGroup elg = - Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") - .asSubclass(EventLoopGroup.class) - .getDeclaredConstructor() - .newInstance(); - builder.eventLoopGroup(elg); - } - } catch (ClassNotFoundException - | InstantiationException - | IllegalAccessException - | NoSuchMethodException - | InvocationTargetException e) { - throw new UnsupportedOperationException( - "Could not find suitable Netty native transport implementation for domain socket address."); - } - break; - } - default: - throw new IllegalArgumentException( - "Scheme is not supported: " + location.getUri().getScheme()); - } - - if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { - builder.useTransportSecurity(); - - final boolean hasTrustedCerts = this.trustedCertificates != null; - final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null; - if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) { - throw new IllegalArgumentException( - "FlightClient has been configured to disable server verification, " - + "but certificate options have been specified."); - } - - final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); - - if (!this.verifyServer) { - sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); - } else if (this.trustedCertificates != null - || this.clientCertificate != null - || this.clientKey != null) { - if (this.trustedCertificates != null) { - sslContextBuilder.trustManager(this.trustedCertificates); - } - if (this.clientCertificate != null && this.clientKey != null) { - sslContextBuilder.keyManager(this.clientCertificate, this.clientKey); - } - } - try { - builder.sslContext(sslContextBuilder.build()); - } catch (SSLException e) { - throw new RuntimeException(e); - } - - if (this.overrideHostname != null) { - builder.overrideAuthority(this.overrideHostname); - } - } else { - builder.usePlaintext(); - } - - builder - .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) - .maxInboundMessageSize(maxInboundMessageSize) - .maxInboundMetadataSize(maxInboundMessageSize); - return new FlightClient(allocator, builder.build(), middleware); + final NettyChannelBuilder channelBuilder = builder.build(); + return new FlightClient(builder.allocator(), channelBuilder.build(), builder.middleware()); } } diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java index 13e4f2f215..df5e29741b 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java @@ -23,6 +23,7 @@ import io.grpc.ManagedChannel; import io.grpc.MethodDescriptor; import java.util.Collections; +import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import org.apache.arrow.flight.auth.ServerAuthHandler; @@ -151,6 +152,19 @@ public static FlightClient createFlightClient( return new FlightClient(incomingAllocator, channel, Collections.emptyList()); } + /** + * Creates a Flight client. + * + * @param incomingAllocator Memory allocator + * @param channel provides a connection to a gRPC server. + */ + public static FlightClient createFlightClient( + BufferAllocator incomingAllocator, + ManagedChannel channel, + List middleware) { + return new FlightClient(incomingAllocator, channel, middleware); + } + /** * Creates a Flight client. * diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java new file mode 100644 index 0000000000..42cdaac016 --- /dev/null +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/NettyClientBuilder.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight.grpc; + +import io.grpc.ManagedChannel; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import javax.net.ssl.SSLException; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.LocationSchemes; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; + +/** + * A wrapper around gRPC's Netty builder. + * + *

It is recommended to use the Netty channel builder directly with {@link + * org.apache.arrow.flight.FlightGrpcUtils#createFlightClient(BufferAllocator, ManagedChannel)}. + * However, this class provides an adapter that implements the existing Flight-specific builder + * interface but allows usage of the Netty builder as well. + */ +public class NettyClientBuilder { + /** + * The maximum number of trace events to keep on the gRPC Channel. This value disables channel + * tracing. + */ + private static final int MAX_CHANNEL_TRACE_EVENTS = 0; + + protected BufferAllocator allocator; + protected Location location; + protected boolean forceTls = false; + protected int maxInboundMessageSize = Integer.MAX_VALUE; + protected InputStream trustedCertificates = null; + protected InputStream clientCertificate = null; + protected InputStream clientKey = null; + protected String overrideHostname = null; + protected List middleware = new ArrayList<>(); + protected boolean verifyServer = true; + + public NettyClientBuilder() {} + + public NettyClientBuilder(BufferAllocator allocator, Location location) { + this.allocator = Preconditions.checkNotNull(allocator); + this.location = Preconditions.checkNotNull(location); + } + + /** Force the client to connect over TLS. */ + public NettyClientBuilder useTls() { + this.forceTls = true; + return this; + } + + /** Override the hostname checked for TLS. Use with caution in production. */ + public NettyClientBuilder overrideHostname(final String hostname) { + this.overrideHostname = hostname; + return this; + } + + /** Set the maximum inbound message size. */ + public NettyClientBuilder maxInboundMessageSize(int maxSize) { + Preconditions.checkArgument(maxSize > 0); + this.maxInboundMessageSize = maxSize; + return this; + } + + /** Set the trusted TLS certificates. */ + public NettyClientBuilder trustedCertificates(final InputStream stream) { + this.trustedCertificates = Preconditions.checkNotNull(stream); + return this; + } + + /** Set the trusted TLS certificates. */ + public NettyClientBuilder clientCertificate( + final InputStream clientCertificate, final InputStream clientKey) { + Preconditions.checkNotNull(clientKey); + this.clientCertificate = Preconditions.checkNotNull(clientCertificate); + this.clientKey = Preconditions.checkNotNull(clientKey); + return this; + } + + public BufferAllocator allocator() { + return allocator; + } + + public NettyClientBuilder allocator(BufferAllocator allocator) { + this.allocator = Preconditions.checkNotNull(allocator); + return this; + } + + public NettyClientBuilder location(Location location) { + this.location = Preconditions.checkNotNull(location); + return this; + } + + public List middleware() { + return Collections.unmodifiableList(middleware); + } + + public NettyClientBuilder intercept(FlightClientMiddleware.Factory factory) { + middleware.add(factory); + return this; + } + + public NettyClientBuilder verifyServer(boolean verifyServer) { + this.verifyServer = verifyServer; + return this; + } + + /** Create the client from this builder. */ + public NettyChannelBuilder build() { + final NettyChannelBuilder builder; + + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: + case LocationSchemes.GRPC_TLS: + { + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + break; + } + case LocationSchemes.GRPC_DOMAIN_SOCKET: + { + // The implementation is platform-specific, so we have to find the classes at runtime + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + try { + try { + // Linux + builder.channelType( + Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = + Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .asSubclass(EventLoopGroup.class) + .getDeclaredConstructor() + .newInstance(); + builder.eventLoopGroup(elg); + } catch (ClassNotFoundException e) { + // BSD + builder.channelType( + Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel") + .asSubclass(ServerChannel.class)); + final EventLoopGroup elg = + Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .asSubclass(EventLoopGroup.class) + .getDeclaredConstructor() + .newInstance(); + builder.eventLoopGroup(elg); + } + } catch (ClassNotFoundException + | InstantiationException + | IllegalAccessException + | NoSuchMethodException + | InvocationTargetException e) { + throw new UnsupportedOperationException( + "Could not find suitable Netty native transport implementation for domain socket address."); + } + break; + } + default: + throw new IllegalArgumentException( + "Scheme is not supported: " + location.getUri().getScheme()); + } + + if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { + builder.useTransportSecurity(); + + final boolean hasTrustedCerts = this.trustedCertificates != null; + final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null; + if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) { + throw new IllegalArgumentException( + "FlightClient has been configured to disable server verification, " + + "but certificate options have been specified."); + } + + final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + + if (!this.verifyServer) { + sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } else if (this.trustedCertificates != null + || this.clientCertificate != null + || this.clientKey != null) { + if (this.trustedCertificates != null) { + sslContextBuilder.trustManager(this.trustedCertificates); + } + if (this.clientCertificate != null && this.clientKey != null) { + sslContextBuilder.keyManager(this.clientCertificate, this.clientKey); + } + } + try { + builder.sslContext(sslContextBuilder.build()); + } catch (SSLException e) { + throw new RuntimeException(e); + } + + if (this.overrideHostname != null) { + builder.overrideAuthority(this.overrideHostname); + } + } else { + builder.usePlaintext(); + } + + builder + .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) + .maxInboundMessageSize(maxInboundMessageSize) + .maxInboundMetadataSize(maxInboundMessageSize); + return builder; + } +}