From 7610af42099d5835dab4bb966c88d73a872a576f Mon Sep 17 00:00:00 2001 From: HackPoint Date: Tue, 15 Apr 2025 19:11:48 +0300 Subject: [PATCH 01/10] feat(middleware): add comprehensive tests for CookieMiddleware capturing headers and cookie values - Implemented tests verifying that CookieMiddleware correctly captures and persists values from both headers and cookies. - Added CapturingCookieMiddleware and CapturingCookieMiddlewareFactory for validating middleware lifecycle. - Ensured middleware correctly processes cookie strings and validates persistence throughout the request lifecycle. --- .../Middleware/Extensions/CookieExtensions.cs | 51 +++++++ .../Middleware/Grpc/FlightMethodParser.cs | 64 ++++++++ .../Middleware/Grpc/MetadataAdapter.cs | 102 +++++++++++++ .../Middleware/Grpc/StatusUtils.cs | 58 ++++++++ .../Interceptors/ClientInterceptorAdapter.cs | 131 +++++++++++++++++ .../Middleware/Interfaces/ICallHeaders.cs | 34 +++++ .../Interfaces/IFlightClientMiddleware.cs | 30 ++++ .../Middleware/ClientCookieMiddleware.cs | 129 +++++++++++++++++ .../Middleware/Models/CallInfo.cs | 30 ++++ .../Middleware/Models/CallStatus.cs | 35 +++++ .../Middleware/Models/FlightMethod.cs | 31 ++++ .../Middleware/Models/FlightStatusCode.cs | 37 +++++ .../Apache.Arrow.Flight.Sql.Tests.csproj | 3 +- .../CallHeadersTests.cs | 134 +++++++++++++++++ .../ClientCookieMiddlewareTests.cs | 137 ++++++++++++++++++ .../ClientInterceptorAdapterTests.cs | 101 +++++++++++++ .../Stubs/CapturingMiddleware.cs | 58 ++++++++ .../Stubs/CapturingMiddlewareFactory.cs | 26 ++++ .../Stubs/ClientCookieMiddlewareMock.cs | 64 ++++++++ .../Stubs/InMemoryCallHeaders.cs | 66 +++++++++ .../Stubs/InMemoryFlightStore.cs | 49 +++++++ 21 files changed, 1369 insertions(+), 1 deletion(-) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryFlightStore.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs new file mode 100644 index 00000000000..aba6d0f1f71 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; + +namespace Apache.Arrow.Flight.Sql.Middleware.Extensions; + +// TODO: Add tests to cover: CookieExtensions +internal static class CookieExtensions +{ + public static IEnumerable ParseHeader(this string headers) + { + var cookies = new List(); + var segments = headers.Split(';', StringSplitOptions.RemoveEmptyEntries); + + if (segments.Length == 0) return cookies; + + var nameValue = segments[0].Split('=', 2); + if (nameValue.Length == 2) + { + var cookie = new Cookie(nameValue[0], nameValue[1]); + foreach (var segment in segments.Skip(1)) + { + if (segment.StartsWith("Expires=", StringComparison.OrdinalIgnoreCase)) + { + if (DateTimeOffset.TryParse(segment["Expires=".Length..], out var expires)) + cookie.Expires = expires.UtcDateTime; + } + } + + cookies.Add(cookie); + } + + return cookies; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs new file mode 100644 index 00000000000..a73921b182f --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.Sql.Middleware.Models; + +namespace Apache.Arrow.Flight.Sql.Middleware.Gprc; + +// TODO: Add tests to cover: FlightMethodParser +public static class FlightMethodParser +{ + /// + /// Parses the gRPC full method name (e.g., "/arrow.flight.protocol.FlightService/DoGet") + /// and maps it to a known FlightMethod. + /// + /// gRPC method name + /// Parsed FlightMethod + public static FlightMethod Parse(string fullMethodName) + { + if (string.IsNullOrWhiteSpace(fullMethodName)) + return FlightMethod.Unknown; + + var parts = fullMethodName.Split('/'); + if (parts.Length < 2) + return FlightMethod.Unknown; + + var methodName = parts[^1]; + + return methodName switch + { + "Handshake" => FlightMethod.Handshake, + "ListFlights" => FlightMethod.ListFlights, + "GetFlightInfo" => FlightMethod.GetFlightInfo, + "GetSchema" => FlightMethod.GetSchema, + "DoGet" => FlightMethod.DoGet, + "DoPut" => FlightMethod.DoPut, + "DoExchange" => FlightMethod.DoExchange, + "DoAction" => FlightMethod.DoAction, + "ListActions" => FlightMethod.ListActions, + "CancelFlightInfo" => FlightMethod.CancelFlightInfo, + _ => FlightMethod.Unknown + }; + } + + public static string ParseMethodName(string fullMethodName) + { + if (string.IsNullOrWhiteSpace(fullMethodName)) + return "Unknown"; + + var parts = fullMethodName.Split('/'); + return parts.Length >= 2 ? parts[^1] : "Unknown"; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs new file mode 100644 index 00000000000..d81ef73ba99 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql.Middleware.Grpc; + +public class MetadataAdapter : ICallHeaders +{ + private readonly Metadata _metadata; + + public MetadataAdapter(Metadata metadata) + { + _metadata = metadata ?? throw new ArgumentNullException(nameof(metadata)); + } + + public string? this[string key] => Get(key); + + public string? Get(string key) + { + return _metadata.FirstOrDefault(e => + !e.IsBinary && e.Key.Equals(key, StringComparison.OrdinalIgnoreCase))?.Value; + } + + public byte[]? GetBytes(string key) + { + return _metadata.FirstOrDefault(e => + e.IsBinary && e.Key.Equals(NormalizeBinaryKey(key), StringComparison.OrdinalIgnoreCase))?.ValueBytes; + } + + public IEnumerable GetAll(string key) + { + return _metadata + .Where(e => !e.IsBinary && e.Key.Equals(key, StringComparison.OrdinalIgnoreCase)) + .Select(e => e.Value); + } + + public IEnumerable GetAllBytes(string key) + { + var binaryKey = NormalizeBinaryKey(key); + return _metadata + .Where(e => e.IsBinary && e.Key.Equals(binaryKey, StringComparison.OrdinalIgnoreCase)) + .Select(e => e.ValueBytes); + } + + public void Insert(string key, string value) + { + _metadata.Add(key, value); + } + + public void Insert(string key, byte[] value) + { + _metadata.Add(NormalizeBinaryKey(key), value); + } + + public ISet Keys => + new HashSet(_metadata.Select(e => + e.IsBinary ? DenormalizeBinaryKey(e.Key) : e.Key), + StringComparer.OrdinalIgnoreCase); + + public bool ContainsKey(string key) + { + return _metadata.Any(e => + e.Key.Equals(key, StringComparison.OrdinalIgnoreCase) || + e.Key.Equals(NormalizeBinaryKey(key), StringComparison.OrdinalIgnoreCase)); + } + + private static string NormalizeBinaryKey(string key) + => key.EndsWith(Metadata.BinaryHeaderSuffix, StringComparison.OrdinalIgnoreCase) + ? key + : key + Metadata.BinaryHeaderSuffix; + + private static string DenormalizeBinaryKey(string key) + => key.EndsWith(Metadata.BinaryHeaderSuffix, StringComparison.OrdinalIgnoreCase) + ? key[..^Metadata.BinaryHeaderSuffix.Length] + : key; +} + +public static class MetadataAdapterExtensions +{ + public static bool TryGet(this ICallHeaders headers, string key, out string? value) + { + value = headers.Get(key); + return value is not null; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs new file mode 100644 index 00000000000..9485167f76a --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.Sql.Middleware.Models; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql.Middleware.Gprc; + +public static class StatusUtils +{ + public static CallStatus FromGrpcStatusAndTrailers(Status status, Metadata trailers) + { + var code = FromGrpcStatusCode(status.StatusCode); + return new CallStatus( + code, + status.StatusCode != StatusCode.OK ? new RpcException(status, trailers) : null, + status.Detail, + trailers + ); + } + + public static FlightStatusCode FromGrpcStatusCode(StatusCode grpcCode) + { + return grpcCode switch + { + StatusCode.OK => FlightStatusCode.Ok, + StatusCode.Cancelled => FlightStatusCode.Cancelled, + StatusCode.Unknown => FlightStatusCode.Unknown, + StatusCode.InvalidArgument => FlightStatusCode.InvalidArgument, + StatusCode.DeadlineExceeded => FlightStatusCode.DeadlineExceeded, + StatusCode.NotFound => FlightStatusCode.NotFound, + StatusCode.AlreadyExists => FlightStatusCode.AlreadyExists, + StatusCode.PermissionDenied => FlightStatusCode.PermissionDenied, + StatusCode.Unauthenticated => FlightStatusCode.Unauthenticated, + StatusCode.ResourceExhausted => FlightStatusCode.ResourceExhausted, + StatusCode.FailedPrecondition => FlightStatusCode.FailedPrecondition, + StatusCode.Aborted => FlightStatusCode.Aborted, + StatusCode.OutOfRange => FlightStatusCode.OutOfRange, + StatusCode.Unimplemented => FlightStatusCode.Unimplemented, + StatusCode.Internal => FlightStatusCode.Internal, + StatusCode.Unavailable => FlightStatusCode.Unavailable, + StatusCode.DataLoss => FlightStatusCode.DataLoss, + _ => FlightStatusCode.Unknown + }; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs new file mode 100644 index 00000000000..576f3f72add --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Sql.Middleware.Gprc; +using Apache.Arrow.Flight.Sql.Middleware.Grpc; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Sql.Middleware.Models; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Sql.Middleware.Interceptors +{ + public class ClientInterceptorAdapter : Interceptor + { + private readonly IList _factories; + + public ClientInterceptorAdapter(IEnumerable factories) + { + _factories = factories.ToList(); + } + + public override AsyncUnaryCall AsyncUnaryCall( + TRequest request, + ClientInterceptorContext context, + AsyncUnaryCallContinuation continuation) + where TRequest : class + where TResponse : class + { + var middleware = new List(); + var callInfo = new CallInfo( + context.Host ?? "unknown", + FlightMethodParser.ParseMethodName(context.Method.FullName)); + + try + { + middleware.AddRange(_factories.Select(factory => factory.OnCallStarted(callInfo))); + } + catch (Exception e) + { + throw new RpcException(new Status(StatusCode.Internal, "Middleware creation failed"), e.Message); + } + + // Apply middleware headers + var middlewareHeaders = new Metadata(); + var headerAdapter = new MetadataAdapter(middlewareHeaders); + foreach (var m in middleware) + { + m.OnBeforeSendingHeaders(headerAdapter); + } + + // Merge original headers with middleware headers + var mergedHeaders = new Metadata(); + if (context.Options.Headers != null) + { + foreach (var entry in context.Options.Headers) + { + mergedHeaders.Add(entry); + } + } + + foreach (var entry in middlewareHeaders) + { + mergedHeaders.Add(entry); + } + + var updatedContext = new ClientInterceptorContext( + context.Method, + context.Host, + context.Options.WithHeaders(mergedHeaders) + ); + + var headersReceived = false; + var call = continuation(request, updatedContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception is null) + { + var metadataAdapter = new MetadataAdapter(task.Result); + middleware.ForEach(m => m.OnHeadersReceived(metadataAdapter)); + headersReceived = true; + } + return task.Result; + }); + + var responseTask = call.ResponseAsync.ContinueWith(response => + { + // If headers were never received, simulate with trailers + if (!headersReceived) + { + var trailersAdapter = new MetadataAdapter(call.GetTrailers()); + foreach (var m in middleware) + m.OnHeadersReceived(trailersAdapter); + } + + var status = call.GetStatus(); + var trailers = call.GetTrailers(); + var flightStatus = StatusUtils.FromGrpcStatusAndTrailers(status, trailers); + + middleware.ForEach(m => m.OnCallCompleted(flightStatus)); + + if (response.IsFaulted && response.Exception != null) + throw response.Exception; + + return response.Result; + }); + + return new AsyncUnaryCall( + responseTask, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs new file mode 100644 index 00000000000..1290f185e8d --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; + +namespace Apache.Arrow.Flight.Sql.Middleware.Interfaces; + +public interface ICallHeaders +{ + string? this[string key] { get; } + + string? Get(string key); + byte[]? GetBytes(string key); + IEnumerable GetAll(string key); + IEnumerable GetAllBytes(string key); + + void Insert(string key, string value); + void Insert(string key, byte[] value); + + ISet Keys { get; } + bool ContainsKey(string key); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs new file mode 100644 index 00000000000..75caf9d56d0 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.Sql.Middleware.Models; + +namespace Apache.Arrow.Flight.Sql.Middleware.Interfaces; + +public interface IFlightClientMiddleware +{ + void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders); + void OnHeadersReceived(ICallHeaders incomingHeaders); + void OnCallCompleted(CallStatus status); +} + +public interface IFlightClientMiddlewareFactory +{ + IFlightClientMiddleware OnCallStarted(CallInfo callInfo); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs new file mode 100644 index 00000000000..d56fa61d85b --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Flight.Sql.Middleware.Extensions; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Sql.Middleware.Models; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Sql.Middleware.Middleware; + +public class ClientCookieMiddleware : IFlightClientMiddleware +{ + private readonly ClientCookieMiddlewareFactory _factory; + private readonly ILogger _logger; + private const string SET_COOKIE_HEADER = "Set-cookie"; + private const string COOKIE_HEADER = "Cookie"; + + private readonly ConcurrentDictionary _cookies = new(); + + public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, + ILogger logger) + { + _factory = factory; + _logger = logger; + } + + public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) + { + var cookieValue = GetValidCookiesAsString(); + if (!string.IsNullOrEmpty(cookieValue)) + { + outgoingHeaders.Insert(COOKIE_HEADER, cookieValue); + } + + _logger.LogInformation("Sending Headers: " + string.Join(", ", outgoingHeaders.Keys)); + } + + public void OnHeadersReceived(ICallHeaders incomingHeaders) + { + var setCookieHeaders = incomingHeaders.GetAll(SET_COOKIE_HEADER); + _factory.UpdateCookies(setCookieHeaders); + _logger.LogInformation("Received Headers: " + string.Join(", ", incomingHeaders.Keys)); + } + + public void OnCallCompleted(CallStatus status) + { + _logger.LogInformation($"Call completed with: {status.Code} ({status.Description})"); + } + + private string GetValidCookiesAsString() + { + var cookieList = new List(); + foreach (var entry in _factory.Cookies) + { + if (entry.Value.Expired) + { + _factory.Cookies.TryRemove(entry.Key, out _); + } + else + { + cookieList.Add(entry.Value.ToString()); + } + } + + return string.Join("; ", cookieList); + } + + public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory + { + public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); + private readonly ILoggerFactory _loggerFactory; + + public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) + { + var logger = _loggerFactory.CreateLogger(); + return new ClientCookieMiddleware(this, logger); + } + + internal void UpdateCookies(IEnumerable newCookieHeaderValues) + { + foreach (var headerValue in newCookieHeaderValues) + { + try + { + var parsedCookies = headerValue.ParseHeader(); + foreach (var parsedCookie in parsedCookies) + { + var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); + if (parsedCookie.Expired) + { + Cookies.TryRemove(nameLc, out _); + } + else + { + Cookies[nameLc] = parsedCookie; + } + } + } + catch (FormatException ex) + { + var logger = _loggerFactory.CreateLogger(); + logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); + } + } + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs new file mode 100644 index 00000000000..ee20109b973 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; + +namespace Apache.Arrow.Flight.Sql.Middleware.Models; + +public sealed class CallInfo +{ + public string Endpoint { get; } + public string MethodName { get; } + + public CallInfo(string endpoint, string methodName) + { + Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); + MethodName = methodName ?? throw new ArgumentNullException(nameof(methodName)); + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs new file mode 100644 index 00000000000..4294b1d855e --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql.Middleware.Models; + +public sealed class CallStatus +{ + public FlightStatusCode Code { get; } + public Exception? Cause { get; } + public string? Description { get; } + public Metadata? Trailers { get; } + + public CallStatus(FlightStatusCode code, Exception? cause, string? description, Metadata? trailers) + { + Code = code; + Cause = cause; + Description = description; + Trailers = trailers; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs new file mode 100644 index 00000000000..c53a1e668a2 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Flight.Sql.Middleware.Models; + +public enum FlightMethod +{ + Unknown, + Handshake, + ListFlights, + GetFlightInfo, + GetSchema, + DoGet, + DoPut, + DoExchange, + DoAction, + ListActions, + CancelFlightInfo +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs new file mode 100644 index 00000000000..65221c1e192 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Flight.Sql.Middleware.Models; + +public enum FlightStatusCode +{ + Ok, + Cancelled, + Unknown, + InvalidArgument, + DeadlineExceeded, + NotFound, + AlreadyExists, + PermissionDenied, + Unauthenticated, + ResourceExhausted, + FailedPrecondition, + Aborted, + OutOfRange, + Unimplemented, + Internal, + Unavailable, + DataLoss +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 945c4d1e384..58bd732a11d 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -13,7 +13,8 @@ - + + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs new file mode 100644 index 00000000000..00210557e20 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Linq; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class CallHeadersTests +{ + private readonly InMemoryCallHeaders _headers = new(); + + [Fact] + public void InsertAndGetStringValue() + { + _headers.Insert("Auth", "Bearer 123"); + Assert.Equal("Bearer 123", _headers.Get("Auth")); + Assert.Equal("Bearer 123", _headers["Auth"]); + } + + [Fact] + public void InsertAndGetByteArrayValue() + { + var bytes = new byte[] { 1, 2, 3, 4, 5 }; + _headers.Insert("Data", bytes); + Assert.Equal(bytes, _headers.GetBytes("Data")); + } + + [Fact] + public void InsertMultipleValuesAndGetLast() + { + _headers.Insert("User", "Alice"); + _headers.Insert("User", "Bob"); + Assert.Equal("Bob", _headers.Get("User")); + } + + [Fact] + public void GetAllShouldReturnAllStringValues() + { + _headers.Insert("Header", "v1"); + _headers.Insert("Header", "v2"); + var all = _headers.GetAll("Header").ToList(); + Assert.Contains("v1", all); + Assert.Contains("v2", all); + Assert.Equal(2, all.Count); + } + + [Fact] + public void GetAllBytesShouldReturnAllByteArrayValues() + { + var a = new byte[] { 1 }; + var b = new byte[] { 2 }; + _headers.Insert("Binary", a); + _headers.Insert("Binary", b); + var all = _headers.GetAllBytes("Binary").ToList(); + Assert.Contains(a, all); + Assert.Contains(b, all); + Assert.Equal(2, all.Count); + } + + [Fact] + public void KeysShouldReturnAllKeys() + { + _headers.Insert("A", "x"); + _headers.Insert("B", "y"); + Assert.Contains("A", _headers.Keys); + Assert.Contains("B", _headers.Keys); + } + + [Fact] + public void ContainsKeyShouldWork() + { + _headers.Insert("Check", "yes"); + Assert.True(_headers.ContainsKey("Check")); + Assert.False(_headers.ContainsKey("Missing")); + } + + [Fact] + public void GetNonExistentKeyShouldReturnNull() + { + Assert.Null(_headers.Get("MissingKey")); + Assert.Null(_headers.GetBytes("MissingKey")); + Assert.Empty(_headers.GetAll("MissingKey")); + Assert.Empty(_headers.GetAllBytes("MissingKey")); + } + + [Fact] + public void ContainsKeyShouldBeFalseForMissingKey() + { + Assert.False(_headers.ContainsKey("DefinitelyMissing")); + } + + [Fact] + public void KeysShouldBeEmptyWhenNoHeaders() + { + Assert.Empty(_headers.Keys); + } + + [Fact] + public void IndexerShouldReturnNullForMissingKey() + { + string value = _headers["nonexistent"]; + Assert.Null(value); + } + + [Fact] + public void InsertEmptyStringsShouldStillStore() + { + _headers.Insert("Empty", ""); + Assert.Equal("", _headers.Get("Empty")); + Assert.Single(_headers.GetAll("Empty")); + } + + [Fact] + public void InsertEmptyByteArrayShouldStillStore() + { + var empty = System.Array.Empty(); + _headers.Insert("BinaryEmpty", empty); + Assert.Equal(empty, _headers.GetBytes("BinaryEmpty")); + Assert.Single(_headers.GetAllBytes("BinaryEmpty")); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs new file mode 100644 index 00000000000..7e04e153349 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Sql.Middleware.Middleware; +using Apache.Arrow.Flight.Sql.Tests.Stubs; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class ClientCookieMiddlewareTests +{ + private readonly ClientCookieMiddlewareMock _middlewareMock = new(); + + [Fact] + public void NoCookiesReturnsEmptyString() + { + var factory = _middlewareMock.CreateFactory(); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + Assert.Empty(headers.GetAll("Cookie")); + } + + [Fact] + public void OnlyExpiredCookiesRemovesAll() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["expired"] = + _middlewareMock.CreateCookie("expired", "value", DateTimeOffset.UtcNow.AddMinutes(-5)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + Assert.Empty(headers.GetAll("Cookie")); + Assert.Empty(factory.Cookies); + } + + [Fact] + public void OnlyValidCookiesReturnsCookieHeader() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["valid"] = _middlewareMock.CreateCookie("valid", "abc", DateTimeOffset.UtcNow.AddMinutes(10)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + var header = headers.GetAll("Cookie").FirstOrDefault(); + Assert.NotNull(header); + Assert.Contains("valid=abc", header); + } + + [Fact] + public void MixedCookiesRemovesExpiredOnly() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["expired"] = + _middlewareMock.CreateCookie("expired", "x", DateTimeOffset.UtcNow.AddMinutes(-10)); + factory.Cookies["valid"] = _middlewareMock.CreateCookie("valid", "y", DateTimeOffset.UtcNow.AddMinutes(10)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + var header = headers.GetAll("Cookie").FirstOrDefault(); + Assert.NotNull(header); + Assert.Contains("valid=y", header); + Assert.DoesNotContain("expired=x", header); + Assert.Single(factory.Cookies); + } + + [Fact] + public void DuplicateCookieKeysLastValidRemains() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["token"] = _middlewareMock.CreateCookie("token", "old", DateTimeOffset.UtcNow.AddMinutes(-5)); + factory.Cookies["token"] = _middlewareMock.CreateCookie("token", "new", DateTimeOffset.UtcNow.AddMinutes(5)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + var header = headers.GetAll("Cookie").FirstOrDefault(); + Assert.NotNull(header); + Assert.Contains("token=new", header); + } + + [Fact] + public void FalsePositiveValidDateButMarkedExpired() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["wrong"] = + _middlewareMock.CreateCookie("wrong", "v", DateTimeOffset.UtcNow.AddMinutes(10), expiredOverride: true); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + Assert.Empty(headers.GetAll("Cookie")); + } + + [Fact] + public async Task ConcurrentInsertRemoveDoesNotCorrupt() + { + var factory = _middlewareMock.CreateFactory(); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + + for (int i = 0; i < 100; i++) + factory.Cookies[$"cookie{i}"] = + _middlewareMock.CreateCookie($"cookie{i}", $"{i}", DateTimeOffset.UtcNow.AddMinutes(5)); + + var tasks = Enumerable.Range(0, 20).Select(_ => Task.Run(() => + { + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + foreach (var key in factory.Cookies.Keys) + factory.Cookies.TryRemove(key, out Cookie _); + })); + + await Task.WhenAll(tasks); + Assert.True(factory.Cookies.Count >= 0); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs new file mode 100644 index 00000000000..cb653a5c6ff --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Middleware.Interceptors; +using Apache.Arrow.Flight.Sql.Tests.Stubs; +using Apache.Arrow.Flight.Tests; +using Grpc.Core; +using Grpc.Core.Interceptors; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class ClientInterceptorAdapterTests +{ + private readonly TestWebFactory _testWebFactory; + private readonly FlightClient _client; + private readonly CapturingMiddlewareFactory _middlewareFactory; + + public ClientInterceptorAdapterTests() + { + _testWebFactory = new TestWebFactory(new InMemoryFlightStore()); + + _middlewareFactory = new CapturingMiddlewareFactory(); + var interceptor = new ClientInterceptorAdapter([_middlewareFactory]); + + _client = new FlightClient(_testWebFactory.GetChannel().Intercept(interceptor)); + } + + [Fact] + public async Task MiddlewareFlowIsCalledCorrectly() + { + // Arrange + var descriptor = FlightDescriptor.CreatePathDescriptor("test"); + + // Act + var info = await _client.GetInfo(descriptor); + var middleware = _middlewareFactory.Instance; + + // Assert + Assert.NotNull(info); + Assert.True(middleware.BeforeHeadersCalled, "BeforeHeaders not called"); + Assert.True(middleware.HeadersReceivedCalled, "HeadersReceived not called"); + Assert.True(middleware.CallCompletedCalled, "CallCompleted not called"); + } + + [Fact] + public async Task CookieAndHeaderValuesArePersistedThroughMiddleware() + { + // Arrange + var descriptor = FlightDescriptor.CreatePathDescriptor("test"); + + // Act + try + { + await _client.GetInfo(descriptor); + } + catch (RpcException) + { + // Expected: Flight not found, but middleware should have run + } + + // Assert Middleware captured the headers and cookies correctly + var middleware = _middlewareFactory.Instance; + + Assert.True(middleware.BeforeHeadersCalled, "OnBeforeSendingHeaders not called"); + Assert.True(middleware.HeadersReceivedCalled, "OnHeadersReceived not called"); + Assert.True(middleware.CallCompletedCalled, "OnCallCompleted not called"); + + // Validate Cookies captured correctly + Assert.True(middleware.CapturedHeaders.ContainsKey("cookie")); + var cookies = ParseCookies(middleware.CapturedHeaders["cookie"]); + + Assert.Equal("abc123", cookies["sessionId"]); + Assert.Equal("xyz789", cookies["token"]); + } + + private static IDictionary ParseCookies(string cookieHeader) + { + return cookieHeader.Split(';') + .Select(pair => pair.Split('=')) + .ToDictionary(parts => parts[0].Trim(), parts => parts[1].Trim()); + } +} + + \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs new file mode 100644 index 00000000000..bdd3d9e6c61 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Sql.Middleware.Models; + +namespace Apache.Arrow.Flight.Sql.Tests.Stubs; + +public class CapturingMiddleware : IFlightClientMiddleware +{ + public Dictionary CapturedHeaders { get; } = new(); + + public bool BeforeHeadersCalled { get; private set; } + public bool HeadersReceivedCalled { get; private set; } + public bool CallCompletedCalled { get; private set; } + public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) + { + BeforeHeadersCalled = true; + outgoingHeaders.Insert("x-test-header", "test-value"); + outgoingHeaders.Insert("cookie", "sessionId=abc123; token=xyz789"); + CaptureHeaders(outgoingHeaders); + } + public void OnHeadersReceived(ICallHeaders incomingHeaders) + { + HeadersReceivedCalled = true; + CaptureHeaders(incomingHeaders); + } + + public void OnCallCompleted(CallStatus status) + { + CallCompletedCalled = true; + } + + private void CaptureHeaders(ICallHeaders headers) + { + foreach (var key in headers.Keys) + { + var value = headers.Get(key); + if (value != null) + { + CapturedHeaders[key] = value; + } + } + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs new file mode 100644 index 00000000000..a3aa652b81f --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Sql.Middleware.Models; + +namespace Apache.Arrow.Flight.Sql.Tests.Stubs; + +public class CapturingMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public CapturingMiddleware Instance { get; } = new(); + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) => Instance; +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs new file mode 100644 index 00000000000..8b86b57d21b --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Net; +using Apache.Arrow.Flight.Sql.Middleware.Middleware; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Sql.Tests.Stubs; + +internal class ClientCookieMiddlewareMock +{ + public Cookie CreateCookie(string name, string value, DateTimeOffset? expires = null, bool? expiredOverride = null) + { + return new Cookie + { + Name = name, + Value = value, + Expires = expires!.Value.UtcDateTime, + Expired = expiredOverride ?? (expires.HasValue && expires.Value < DateTimeOffset.UtcNow) + }; + } + + public ClientCookieMiddleware.ClientCookieMiddlewareFactory CreateFactory() + { + return new ClientCookieMiddleware.ClientCookieMiddlewareFactory(new TestLoggerFactory()); + } + + public class TestLogger : ILogger + { + public IDisposable BeginScope(TState state) => null; + public bool IsEnabled(LogLevel logLevel) => false; + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, + Func formatter) + { + } + } + + internal class TestLoggerFactory : ILoggerFactory + { + public void AddProvider(ILoggerProvider provider) + { + } + + public ILogger CreateLogger(string categoryName) => new TestLogger(); + + public void Dispose() + { + } + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs new file mode 100644 index 00000000000..2077cb08263 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class InMemoryCallHeaders : ICallHeaders +{ + private readonly Dictionary> _stringHeaders = new(); + private readonly Dictionary> _byteHeaders = new(); + + public string this[string key] => Get(key); + + public string Get(string key) => + _stringHeaders.TryGetValue(key, out var values) ? values.LastOrDefault() : null; + + public byte[] GetBytes(string key) => + _byteHeaders.TryGetValue(key, out var values) + ? values.LastOrDefault() + : null; + + public IEnumerable GetAll(string key) => + _stringHeaders.TryGetValue(key, out var values) + ? values + : Enumerable.Empty(); + + public IEnumerable GetAllBytes(string key) => + _byteHeaders.TryGetValue(key, out var values) + ? values + : Enumerable.Empty(); + + public void Insert(string key, string value) + { + if (!_stringHeaders.TryGetValue(key, out var list)) + _stringHeaders[key] = list = new List(); + list.Add(value); + } + + public void Insert(string key, byte[] value) + { + if (!_byteHeaders.TryGetValue(key, out var list)) + _byteHeaders[key] = list = new List(); + list.Add(value); + } + + public ISet Keys => + (HashSet) [.._stringHeaders.Keys.Concat(_byteHeaders.Keys)]; + + public bool ContainsKey(string key) => + _stringHeaders.ContainsKey(key) || _byteHeaders.ContainsKey(key); +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryFlightStore.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryFlightStore.cs new file mode 100644 index 00000000000..619b4aa3b0f --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryFlightStore.cs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.TestWeb; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Flight.Sql.Tests.Stubs; + +public class InMemoryFlightStore : FlightStore +{ + public InMemoryFlightStore() + { + // Pre-register a dummy flight so GetFlightInfo can resolve it + var descriptor = FlightDescriptor.CreatePathDescriptor("test"); + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + + var recordBatch = new RecordBatch(schema, new Array[] + { + new Int32Array.Builder().Append(1).Build(), + new StringArray.Builder().Append("John Doe").Build() + }, 1); + + var location = new FlightLocation("grpc+tcp://localhost:50051"); + + var flightHolder = new FlightHolder(descriptor, schema, location.Uri); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + Flights.Add(descriptor, flightHolder); + } + + public override string ToString() + { + return $"InMemoryFlightStore(Flights={Flights.Count})"; + } +} \ No newline at end of file From 26fa098f9b7f3d403c059b9174658555b331d995 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 27 Apr 2025 17:33:35 +0300 Subject: [PATCH 02/10] feat: headers + cookie --- csharp/Directory.Build.props | 2 +- .../Client/FlightSqlClient.cs | 959 ++++++++++++++++++ .../FlightCallOptions.cs | 38 + .../FlightExtensions.cs | 41 + .../FlightSqlServer.cs | 1 + .../ClientCookieMiddleware.cs | 59 +- .../ClientCookieMiddlewareFactory.cs | 61 ++ .../Middleware/CookieManager.cs | 159 +++ .../Middleware/Grpc/FlightMethodParser.cs | 2 +- .../Middleware/Grpc/StatusUtils.cs | 2 +- .../Interceptors/ClientInterceptorAdapter.cs | 40 +- .../PreparedStatement.cs | 382 +++++++ .../SchemaExtensions.cs | 50 + .../src/Apache.Arrow.Flight.Sql/SqlActions.cs | 5 + .../src/Apache.Arrow.Flight.Sql/TableRef.cs | 38 + .../Apache.Arrow.Flight.Sql/Transaction.cs | 58 ++ .../FlightInfoCancelRequest.cs | 53 + .../FlightInfoCancelResult.cs | 52 + .../Apache.Arrow.Flight.Sql.Tests.csproj | 6 +- .../ClientCookieMiddlewareTests.cs | 2 +- .../FlightSqlClientTests.cs | 855 ++++++++++++++++ .../FlightSqlPreparedStatementTests.cs | 226 +++++ .../FlightSqlTestExtensions.cs | 240 +++++ .../FlightSqlTestUtils.cs | 63 ++ .../Apache.Arrow.Flight.Sql.Tests/Startup.cs | 57 ++ .../Stubs/ClientCookieMiddlewareMock.cs | 6 +- .../TestFlightSqlWebFactory.cs | 82 ++ .../Apache.Arrow.Flight.TestWeb.csproj | 1 + .../FlightHolder.cs | 19 +- .../TestFlightSqlServer.cs | 162 +++ 30 files changed, 3644 insertions(+), 77 deletions(-) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs rename csharp/src/Apache.Arrow.Flight.Sql/Middleware/{Middleware => }/ClientCookieMiddleware.cs (55%) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddlewareFactory.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Middleware/CookieManager.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs create mode 100644 csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs diff --git a/csharp/Directory.Build.props b/csharp/Directory.Build.props index 474c5773460..2bfe93bdceb 100644 --- a/csharp/Directory.Build.props +++ b/csharp/Directory.Build.props @@ -29,7 +29,7 @@ Apache Arrow library Copyright 2016-2024 The Apache Software Foundation The Apache Software Foundation - 20.0.0-SNAPSHOT + 20.0.11-SNAPSHOT diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs new file mode 100644 index 00000000000..20c53660299 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -0,0 +1,959 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql.Client; + +public class FlightSqlClient +{ + private readonly FlightClient _client; + + public FlightSqlClient(FlightClient client) + { + _client = client; + } + + /// + /// Execute a SQL query on the server. + /// + /// The UTF8-encoded SQL query to be executed. + /// A transaction to associate this query with. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task ExecuteAsync(string query, Transaction transaction = default, FlightCallOptions? options = null) + { + if (transaction == default) + { + transaction = Transaction.NoTransaction; + } + + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); + } + + try + { + var commandQuery = new CommandStatementQuery { Query = query }; + + if (transaction.IsValid()) + { + commandQuery.TransactionId = transaction.TransactionId; + } + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandQuery.PackAndSerialize()); + return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to execute query", ex); + } + } + + /// + /// Executes an update SQL command and returns the number of affected rows. + /// + /// The UTF8-encoded SQL query to be executed. + /// A transaction to associate this query with. Defaults to no transaction if not provided. + /// RPC-layer hints for this call. + /// The number of rows affected by the operation. + public async Task ExecuteUpdateAsync(string query, Transaction transaction = default, FlightCallOptions? options = null) + { + if (transaction == default) + { + transaction = Transaction.NoTransaction; + } + + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentException("Query cannot be null or empty", nameof(query)); + } + + try + { + var updateRequestCommand = + new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; + byte[] serializedUpdateRequestCommand = updateRequestCommand.PackAndSerialize(); + var action = new FlightAction(SqlAction.CreateRequest, serializedUpdateRequestCommand); + var call = DoActionAsync(action, options); + long affectedRows = 0; + + await foreach (var result in call.ConfigureAwait(false)) + { + var preparedStatementResponse = result.Body.ParseAndUnpack(); + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options); + var doGetResult = DoGetAsync(flightInfo.Endpoints[0].Ticket, options); + + await foreach (var recordBatch in doGetResult.ConfigureAwait(false)) + { + affectedRows += recordBatch.ExtractRowCount(); + } + } + + return affectedRows; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to execute update query", ex); + } + } + + /// + /// Asynchronously retrieves flight information for a given flight descriptor. + /// + /// The descriptor of the dataset request, whether a named dataset or a command. + /// RPC-layer hints for this call. + /// A task that represents the asynchronous operation. The task result contains the FlightInfo describing where to access the dataset. + public async Task GetFlightInfoAsync(FlightDescriptor descriptor, FlightCallOptions? options = default) + { + if (descriptor is null) + { + throw new ArgumentNullException(nameof(descriptor)); + } + + try + { + var flightInfoCall = _client.GetInfo(descriptor, options?.Headers); + var flightInfo = await flightInfoCall.ResponseAsync.ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get flight info", ex); + } + } + + /// + /// Perform the indicated action, returning an iterator to the stream of results, if any. + /// + /// The action to be performed + /// Per-RPC options + /// An async enumerable of results + public async IAsyncEnumerable DoActionAsync(FlightAction action, FlightCallOptions? options = default) + { + if (action is null) + throw new ArgumentNullException(nameof(action)); + + var call = _client.DoAction(action, options?.Headers); + + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + yield return result; + } + } + + /// + /// Get the result set schema from the server for the given query. + /// + /// The UTF8-encoded SQL query + /// A transaction to associate this query with + /// Per-RPC options + /// The SchemaResult describing the schema of the result set + public async Task GetExecuteSchemaAsync(string query, Transaction transaction = default, FlightCallOptions? options = null) + { + if (transaction == default) + { + transaction = Transaction.NoTransaction; + } + + if (string.IsNullOrEmpty(query)) + throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); + try + { + var prepareStatementRequest = + new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; + var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); + var call = _client.DoAction(action, options?.Headers); + + var preparedStatementResponse = await ReadPreparedStatementAsync(call).ConfigureAwait(false); + + if (preparedStatementResponse.PreparedStatementHandle.IsEmpty) + throw new InvalidOperationException("Received an empty or invalid PreparedStatementHandle."); + var commandSqlCall = new CommandPreparedStatementQuery + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCall.PackAndSerialize()); + var schemaResult = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return schemaResult.Schema; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get execute schema", ex); + } + } + + /// + /// Request a list of catalogs. + /// + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetCatalogsAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetCatalogs(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var catalogsInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return catalogsInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get catalogs", ex); + } + } + + /// + /// Get the catalogs schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the catalogs. + public async Task GetCatalogsSchemaAsync(FlightCallOptions? options = default) + { + try + { + var commandGetCatalogsSchema = new CommandGetCatalogs(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCatalogsSchema.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get catalogs schema", ex); + } + } + + /// + /// Asynchronously retrieves schema information for a given flight descriptor. + /// + /// The descriptor of the dataset request, whether a named dataset or a command. + /// RPC-layer hints for this call. + /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. + public virtual async Task GetSchemaAsync(FlightDescriptor descriptor, FlightCallOptions? options = default) + { + if (descriptor is null) + { + throw new ArgumentNullException(nameof(descriptor)); + } + + try + { + var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); + var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get schema", ex); + } + } + + /// + /// Request a list of database schemas. + /// + /// RPC-layer hints for this call. + /// The catalog. + /// The schema filter pattern. + /// The FlightInfo describing where to access the dataset. + public async Task GetDbSchemasAsync(string? catalog = null, string? dbSchemaFilterPattern = null, FlightCallOptions? options = default) + { + try + { + var command = new CommandGetDbSchemas(); + + if (catalog != null) + { + command.Catalog = catalog; + } + + if (dbSchemaFilterPattern != null) + { + command.DbSchemaFilterPattern = dbSchemaFilterPattern; + } + + byte[] serializedAndPackedCommand = command.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(serializedAndPackedCommand); + var flightInfoCall = GetFlightInfoAsync(descriptor, options); + var flightInfo = await flightInfoCall.ConfigureAwait(false); + + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get database schemas", ex); + } + } + + /// + /// Get the database schemas schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the database schemas. + public async Task GetDbSchemasSchemaAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetDbSchemas(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get database schemas schema", ex); + } + } + + /// + /// Given a flight ticket and schema, request to be sent the stream. Returns record batch stream reader. + /// + /// The flight ticket to use + /// Per-RPC options + /// The returned RecordBatchReader + public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, FlightCallOptions? options = default) + { + if (ticket == null) + { + throw new ArgumentNullException(nameof(ticket)); + } + + var call = _client.GetStream(ticket, options?.Headers); + await foreach (var recordBatch in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + yield return recordBatch; + } + } + + /// + /// Upload data to a Flight described by the given descriptor. The caller must call Close() on the returned stream + /// once they are done writing. + /// + /// The descriptor of the stream. + /// The record for the data to upload. + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. + public async Task DoPutAsync(FlightDescriptor descriptor, RecordBatch recordBatch, FlightCallOptions? options = default) + { + if (descriptor is null) + throw new ArgumentNullException(nameof(descriptor)); + + if (recordBatch is null) + throw new ArgumentNullException(nameof(recordBatch)); + try + { + var doPutResult = _client.StartPut(descriptor, options?.Headers); + var writer = doPutResult.RequestStream; + var reader = doPutResult.ResponseStream; + + if (recordBatch == null || recordBatch.Length == 0) + throw new InvalidOperationException("RecordBatch is empty or improperly initialized."); + + await writer.WriteAsync(recordBatch).ConfigureAwait(false); + await writer.CompleteAsync().ConfigureAwait(false); + + if (await reader.MoveNext().ConfigureAwait(false)) + { + var putResult = reader.Current; + return new FlightPutResult(putResult.ApplicationMetadata); + } + return FlightPutResult.Empty; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to perform DoPut operation", ex); + } + } + + /// + /// Request the primary keys for a table. + /// + /// The table reference. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetPrimaryKeysAsync(TableRef tableRef, FlightCallOptions? options = default) + { + if (tableRef == null) + throw new ArgumentNullException(nameof(tableRef)); + + try + { + var getPrimaryKeysRequest = new CommandGetPrimaryKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(getPrimaryKeysRequest.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get primary keys", ex); + } + } + + /// + /// Request a list of tables. + /// + /// The catalog. + /// The schema filter pattern. + /// The table filter pattern. + /// True to include the schema upon return, false to not include the schema. + /// The table types to include. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task> GetTablesAsync(string? catalog = null, string? dbSchemaFilterPattern = null, string? tableFilterPattern = null, bool includeSchema = false, IEnumerable? tableTypes = null, FlightCallOptions? options = default) + { + var command = new CommandGetTables + { + Catalog = catalog ?? string.Empty, + DbSchemaFilterPattern = dbSchemaFilterPattern ?? string.Empty, + TableNameFilterPattern = tableFilterPattern ?? string.Empty, + IncludeSchema = includeSchema + }; + + if (tableTypes != null) + { + command.TableTypes.AddRange(tableTypes); + } + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfoCall = GetFlightInfoAsync(descriptor, options); + var flightInfo = await flightInfoCall.ConfigureAwait(false); + var flightInfos = new List { flightInfo }; + + return flightInfos; + } + + + /// + /// Retrieves a description about the foreign key columns that reference the primary key columns of the given table. + /// + /// The table reference. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetExportedKeysAsync(TableRef tableRef, FlightCallOptions? options = default) + { + if (tableRef == null) + throw new ArgumentNullException(nameof(tableRef)); + + try + { + var getExportedKeysRequest = new CommandGetExportedKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(getExportedKeysRequest.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get exported keys", ex); + } + } + + /// + /// Get the exported keys schema from the server. + /// + /// The table reference. + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the exported keys. + public async Task GetExportedKeysSchemaAsync(TableRef tableRef, FlightCallOptions? options = default) + { + try + { + var commandGetExportedKeysSchema = new CommandGetExportedKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetExportedKeysSchema.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get exported keys schema", ex); + } + } + + /// + /// Retrieves the foreign key columns for the given table. + /// + /// The table reference. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetImportedKeysAsync(TableRef tableRef, FlightCallOptions? options = default) + { + if (tableRef == null) + throw new ArgumentNullException(nameof(tableRef)); + + try + { + var getImportedKeysRequest = new CommandGetImportedKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(getImportedKeysRequest.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get imported keys", ex); + } + } + + /// + /// Get the imported keys schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the imported keys. + public async Task GetImportedKeysSchemaAsync(FlightCallOptions? options = default) + { + try + { + var commandGetImportedKeysSchema = new CommandGetImportedKeys(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetImportedKeysSchema.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get imported keys schema", ex); + } + } + + /// + /// Retrieves a description of the foreign key columns in the given foreign key table that reference the primary key or the columns representing a unique constraint of the parent table. + /// + /// The table reference that exports the key. + /// The table reference that imports the key. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetCrossReferenceAsync(TableRef pkTableRef, TableRef fkTableRef, FlightCallOptions? options = default) + { + if (pkTableRef == null) + throw new ArgumentNullException(nameof(pkTableRef)); + + if (fkTableRef == null) + throw new ArgumentNullException(nameof(fkTableRef)); + + try + { + var commandGetCrossReference = new CommandGetCrossReference + { + PkCatalog = pkTableRef.Catalog ?? string.Empty, + PkDbSchema = pkTableRef.DbSchema, + PkTable = pkTableRef.Table, + FkCatalog = fkTableRef.Catalog ?? string.Empty, + FkDbSchema = fkTableRef.DbSchema, + FkTable = fkTableRef.Table + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCrossReference.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get cross reference", ex); + } + } + + /// + /// Get the cross-reference schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the cross-reference. + public async Task GetCrossReferenceSchemaAsync(FlightCallOptions? options = default) + { + try + { + var commandGetCrossReferenceSchema = new CommandGetCrossReference(); + var descriptor = + FlightDescriptor.CreateCommandDescriptor(commandGetCrossReferenceSchema.PackAndSerialize()); + var schemaResultCall = GetSchemaAsync(descriptor, options); + var schemaResult = await schemaResultCall.ConfigureAwait(false); + + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get cross-reference schema", ex); + } + } + + /// + /// Request a list of table types. + /// + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetTableTypesAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetTableTypes(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get table types", ex); + } + } + + /// + /// Get the table types schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the table types. + public async Task GetTableTypesSchemaAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetTableTypes(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get table types schema", ex); + } + } + + /// + /// Request the information about all the data types supported with filtering by data type. + /// + /// The data type to search for as filtering. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetXdbcTypeInfoAsync(int dataType, FlightCallOptions? options = default) + { + try + { + var command = new CommandGetXdbcTypeInfo { DataType = dataType }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get XDBC type info", ex); + } + } + + /// + /// Request the information about all the data types supported. + /// + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetXdbcTypeInfoAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetXdbcTypeInfo(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get XDBC type info", ex); + } + } + + /// + /// Get the type info schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the type info. + public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetXdbcTypeInfo(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get XDBC type info schema", ex); + } + } + + /// + /// Request a list of SQL information. + /// + /// The SQL info required. + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetSqlInfoAsync(List? sqlInfo = default, FlightCallOptions? options = default) + { + sqlInfo ??= new List(); + try + { + var command = new CommandGetSqlInfo(); + command.Info.AddRange(sqlInfo.ConvertAll(item => (uint)item)); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get SQL info", ex); + } + } + + /// + /// Get the SQL information schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the SQL information. + public async Task GetSqlInfoSchemaAsync(FlightCallOptions? options = default) + { + try + { + var command = new CommandGetSqlInfo(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); + var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); + + return schemaResult; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to get SQL info schema", ex); + } + } + + /// + /// Explicitly cancel a FlightInfo. + /// + /// The CancelFlightInfoRequest. + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. The task result contains the CancelFlightInfoResult describing the canceled result. + public async Task CancelFlightInfoAsync(FlightInfoCancelRequest request, FlightCallOptions? options = default) + { + if (request == null) throw new ArgumentNullException(nameof(request)); + + try + { + var action = new FlightAction(SqlAction.CancelFlightInfoRequest, request.PackAndSerialize()); + var call = _client.DoAction(action, options?.Headers); + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + if (Any.Parser.ParseFrom(result.Body) is Any anyResult && + anyResult.TryUnpack(out FlightInfoCancelResult cancelResult)) + { + return cancelResult; + } + } + + throw new InvalidOperationException("No response received for the CancelFlightInfo request."); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to cancel flight info", ex); + } + } + + /// + /// Explicitly cancel a query. + /// + /// The FlightInfo of the query to cancel. + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. + public async Task CancelQueryAsync(FlightInfo info, FlightCallOptions? options = default) + { + if (info == null) + throw new ArgumentNullException(nameof(info)); + + try + { + var cancelQueryRequest = new FlightInfoCancelRequest(info); + var cancelQueryAction = + new FlightAction(SqlAction.CancelFlightInfoRequest, cancelQueryRequest.PackAndSerialize()); + var cancelQueryCall = _client.DoAction(cancelQueryAction, options?.Headers); + + await foreach (var result in cancelQueryCall.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + if (Any.Parser.ParseFrom(result.Body) is Any anyResult && + anyResult.TryUnpack(out FlightInfoCancelResult cancelResult)) + { + return cancelResult; + } + } + throw new InvalidOperationException("Failed to cancel query: No response received."); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to cancel query", ex); + } + } + + /// + /// Begin a new transaction. + /// + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. The task result contains the Transaction object representing the new transaction. + public async Task BeginTransactionAsync(FlightCallOptions? options = default) + { + try + { + var actionBeginTransaction = new ActionBeginTransactionRequest(); + var action = new FlightAction(SqlAction.BeginTransactionRequest, actionBeginTransaction.PackAndSerialize()); + var responseStream = _client.DoAction(action, options?.Headers); + await foreach (var result in responseStream.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + string? beginTransactionResult = result.Body.ToStringUtf8(); + return new Transaction(beginTransactionResult); + } + + throw new InvalidOperationException("Failed to begin transaction: No response received."); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to begin transaction", ex); + } + } + + /// + /// Commit a transaction. + /// After this, the transaction and all associated savepoints will be invalidated. + /// + /// The transaction. + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. + public AsyncServerStreamingCall CommitAsync(Transaction transaction, FlightCallOptions? options = default) + { + if (transaction == null) + throw new ArgumentNullException(nameof(transaction)); + + try + { + var actionCommit = new FlightAction(SqlAction.CommitRequest, transaction.TransactionId); + return _client.DoAction(actionCommit, options?.Headers); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to commit transaction", ex); + } + } + + /// + /// Rollback a transaction. + /// After this, the transaction and all associated savepoints will be invalidated. + /// + /// The transaction to rollback. + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. + public AsyncServerStreamingCall RollbackAsync(Transaction transaction, FlightCallOptions? options = default) + { + if (transaction == null) + { + throw new ArgumentNullException(nameof(transaction)); + } + + try + { + var actionRollback = new FlightAction(SqlAction.RollbackRequest, transaction.TransactionId); + return _client.DoAction(actionRollback, options?.Headers); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to rollback transaction", ex); + } + } + + /// + /// Create a prepared statement object. + /// + /// The query that will be executed. + /// A transaction to associate this query with. + /// RPC-layer hints for this call. + /// The created prepared statement. + public async Task PrepareAsync(string query, Transaction transaction = default, FlightCallOptions? options = null) + { + if (string.IsNullOrEmpty(query)) + throw new ArgumentException("Query cannot be null or empty", nameof(query)); + + if (transaction == default) + { + transaction = Transaction.NoTransaction; + } + + try + { + var command = new ActionCreatePreparedStatementRequest + { + Query = query + }; + + if (transaction.IsValid()) + { + command.TransactionId = transaction.TransactionId; + } + + var action = new FlightAction(SqlAction.CreateRequest, command.PackAndSerialize()); + var call = _client.DoAction(action, options?.Headers); + var preparedStatementResponse = await ReadPreparedStatementAsync(call).ConfigureAwait(false); + + + return new PreparedStatement(this, + preparedStatementResponse.PreparedStatementHandle.ToStringUtf8(), + SchemaExtensions.DeserializeSchema(preparedStatementResponse.DatasetSchema.ToByteArray()), + SchemaExtensions.DeserializeSchema(preparedStatementResponse.ParameterSchema.ToByteArray()) + ); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to prepare statement", ex); + } + } + + private static async Task ReadPreparedStatementAsync( + AsyncServerStreamingCall call) + { + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + var response = Any.Parser.ParseFrom(result.Body); + if (response.Is(ActionCreatePreparedStatementResult.Descriptor)) + { + return response.Unpack(); + } + } + throw new InvalidOperationException("Server did not return a valid prepared statement response."); + } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs new file mode 100644 index 00000000000..17541b26e73 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Buffers; +using System.Threading; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +public class FlightCallOptions +{ + public FlightCallOptions() + { + Timeout = TimeSpan.FromSeconds(-1); + } + + // Implement any necessary options for RPC calls + public Metadata Headers { get; set; } = new(); + + /// + /// Gets or sets the optional timeout for this call. + /// Negative durations mean an implementation-defined default behavior will be used instead. + /// + public TimeSpan Timeout { get; set; } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs new file mode 100644 index 00000000000..a0936eb7f34 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Apache.Arrow.Flight.Sql; + +internal static class FlightExtensions +{ + public static byte[] PackAndSerialize(this IMessage command) => Any.Pack(command).ToByteArray(); + + public static T ParseAndUnpack(this ByteString source) where T : IMessage, new() => + Any.Parser.ParseFrom(source).Unpack(); + + public static int ExtractRowCount(this RecordBatch batch) + { + if (batch.ColumnCount == 0) return 0; + int length = batch.Column(0).Length; + foreach (var column in batch.Arrays) + { + if (column.Length != length) + throw new InvalidOperationException("Inconsistent column lengths in RecordBatch."); + } + + return length; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs index cbccc9dab34..227ea4812a6 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs @@ -285,6 +285,7 @@ public override Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWri public override Task DoAction(FlightAction action, IAsyncStreamWriter responseStream, ServerCallContext context) { Logger?.LogTrace("Executing Flight SQL DoAction: {ActionType}", action.Type); + switch (action.Type) { case SqlAction.CreateRequest: diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddleware.cs similarity index 55% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs rename to csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddleware.cs index d56fa61d85b..a65005d0a7a 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddleware.cs @@ -13,27 +13,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -using System; -using System.Collections.Concurrent; using System.Collections.Generic; -using System.Globalization; -using System.Net; -using Apache.Arrow.Flight.Sql.Middleware.Extensions; +using System.Linq; using Apache.Arrow.Flight.Sql.Middleware.Interfaces; using Apache.Arrow.Flight.Sql.Middleware.Models; using Microsoft.Extensions.Logging; -namespace Apache.Arrow.Flight.Sql.Middleware.Middleware; +namespace Apache.Arrow.Flight.Sql.Middleware; public class ClientCookieMiddleware : IFlightClientMiddleware { private readonly ClientCookieMiddlewareFactory _factory; private readonly ILogger _logger; - private const string SET_COOKIE_HEADER = "Set-cookie"; + private const string SET_COOKIE_HEADER = "Set-Cookie"; private const string COOKIE_HEADER = "Cookie"; - private readonly ConcurrentDictionary _cookies = new(); - public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, ILogger logger) { @@ -55,7 +49,8 @@ public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) public void OnHeadersReceived(ICallHeaders incomingHeaders) { var setCookieHeaders = incomingHeaders.GetAll(SET_COOKIE_HEADER); - _factory.UpdateCookies(setCookieHeaders); + var cookieHeaders = incomingHeaders.GetAll(COOKIE_HEADER); + _factory.UpdateCookies(setCookieHeaders.Concat(cookieHeaders)); _logger.LogInformation("Received Headers: " + string.Join(", ", incomingHeaders.Keys)); } @@ -82,48 +77,4 @@ private string GetValidCookiesAsString() return string.Join("; ", cookieList); } - public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory - { - public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); - private readonly ILoggerFactory _loggerFactory; - - public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) - { - _loggerFactory = loggerFactory; - } - - public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) - { - var logger = _loggerFactory.CreateLogger(); - return new ClientCookieMiddleware(this, logger); - } - - internal void UpdateCookies(IEnumerable newCookieHeaderValues) - { - foreach (var headerValue in newCookieHeaderValues) - { - try - { - var parsedCookies = headerValue.ParseHeader(); - foreach (var parsedCookie in parsedCookies) - { - var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); - if (parsedCookie.Expired) - { - Cookies.TryRemove(nameLc, out _); - } - else - { - Cookies[nameLc] = parsedCookie; - } - } - } - catch (FormatException ex) - { - var logger = _loggerFactory.CreateLogger(); - logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); - } - } - } - } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddlewareFactory.cs new file mode 100644 index 00000000000..3ed1731e4fc --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/ClientCookieMiddlewareFactory.cs @@ -0,0 +1,61 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using System.Threading; +using Apache.Arrow.Flight.Sql.Middleware.Extensions; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Sql.Middleware.Models; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Sql.Middleware; + +public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); + private readonly ILoggerFactory _loggerFactory; + private int _createdInstances = 0; + public int CreatedInstances => _createdInstances; + + public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) + { + int callNumber = Interlocked.Increment(ref _createdInstances); + var logger = _loggerFactory.CreateLogger(); + logger.LogInformation($"Creating ClientCookieMiddleware #{callNumber} for {callInfo.MethodName}"); + return new ClientCookieMiddleware(this, logger); + } + + public void UpdateCookies(IEnumerable newCookieHeaderValues) + { + foreach (var headerValue in newCookieHeaderValues) + { + try + { + var parsedCookies = headerValue.ParseHeader(); + foreach (var parsedCookie in parsedCookies) + { + var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); + if (parsedCookie.Expired) + { + Cookies.TryRemove(nameLc, out _); + } + else + { + Cookies[nameLc] = parsedCookie; + } + } + } + catch (FormatException ex) + { + var logger = _loggerFactory.CreateLogger(); + logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); + } + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/CookieManager.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/CookieManager.cs new file mode 100644 index 00000000000..90b68fe185d --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/CookieManager.cs @@ -0,0 +1,159 @@ +/*using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Net; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Middleware.Extensions; +using Apache.Arrow.Flight.Sql.Middleware.Interceptors; +using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Sql.Middleware.Models; +using Grpc.Core; +using Grpc.Core.Interceptors; +using Grpc.Net.Client; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Sql.Middleware; + +public class CookieManager +{ + private readonly ILogger _logger; + public ConcurrentDictionary Cookies { get; } = new(StringComparer.OrdinalIgnoreCase); + + public CookieManager(ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + } + + public void UpdateCookies(IEnumerable cookieHeaders) + { + foreach (var header in cookieHeaders) + { + try + { + var cookies = header.ParseHeader(); + foreach (var cookie in cookies) + { + if (cookie.Expired) + Cookies.TryRemove(cookie.Name.ToLowerInvariant(), out _); + else + Cookies[cookie.Name.ToLowerInvariant()] = cookie; + } + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{Header}'", header); + } + } + } + + public string GetCookieHeader() + { + var validCookies = + Cookies.Values.Where(c => !c.Expired).Select(c => $"{c.Name}={c.Value}"); + return string.Join("; ", validCookies); + } +} + +public class ClientCookieMiddleware : IFlightClientMiddleware +{ + private readonly CookieManager _cookieManager; + private readonly ILogger _logger; + + public ClientCookieMiddleware(CookieManager cookieManager, ILogger logger) + { + _cookieManager = cookieManager; + _logger = logger; + } + + public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) + { + var cookieHeader = _cookieManager.GetCookieHeader(); + if (!string.IsNullOrEmpty(cookieHeader)) + { + outgoingHeaders.Insert("cookie", cookieHeader); + } + + _logger.LogInformation("Sending cookies: {CookieHeader}", cookieHeader); + } + + public void OnHeadersReceived(ICallHeaders incomingHeaders) + { + var setCookie = incomingHeaders.GetAll("set-cookie"); + var xCookie = incomingHeaders.GetAll("x-cookie"); + + _cookieManager.UpdateCookies(setCookie.Concat(xCookie)); + + _logger.LogInformation("Received Headers: {Keys}", string.Join(", ", incomingHeaders.Keys)); + } + + public void OnCallCompleted(CallStatus status) + { + _logger.LogInformation("Call completed: {Status}", status); + } +} + +public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory +{ + private readonly CookieManager _cookieManager; + private readonly ILoggerFactory _loggerFactory; + + public ClientCookieMiddlewareFactory(CookieManager cookieManager, ILoggerFactory loggerFactory) + { + _cookieManager = cookieManager; + _loggerFactory = loggerFactory; + } + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) + { + return new ClientCookieMiddleware(_cookieManager, _loggerFactory.CreateLogger()); + } + + public void UpdateCookies(IEnumerable newCookieHeaderValues) + { + foreach (var headerValue in newCookieHeaderValues) + { + try + { + var parsedCookies = headerValue.ParseHeader(); + foreach (var parsedCookie in parsedCookies) + { + var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); + if (parsedCookie.Expired) + { + _cookieManager.Cookies.TryRemove(nameLc, out _); + } + else + { + _cookieManager.Cookies[nameLc] = parsedCookie; + } + } + } + catch (FormatException ex) + { + var logger = _loggerFactory.CreateLogger(); + logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); + } + } + } +} + +public static class FlightClientFactory +{ + public static FlightClient Create(string address, CookieManager cookieManager, ILoggerFactory loggerFactory) + { + var channel = GrpcChannel.ForAddress(address, new GrpcChannelOptions + { + Credentials = ChannelCredentials.Insecure, + MaxReceiveMessageSize = 100 * 1024 * 1024 + }); + + var invoker = channel.Intercept( + new ClientInterceptorAdapter([ + new ClientCookieMiddlewareFactory(cookieManager, loggerFactory) + ])); + + return new FlightClient(invoker); + } +}*/ \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs index a73921b182f..0771ecc821c 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs @@ -15,7 +15,7 @@ using Apache.Arrow.Flight.Sql.Middleware.Models; -namespace Apache.Arrow.Flight.Sql.Middleware.Gprc; +namespace Apache.Arrow.Flight.Sql.Middleware.Grpc; // TODO: Add tests to cover: FlightMethodParser public static class FlightMethodParser diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs index 9485167f76a..e4390e0c0a0 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs @@ -16,7 +16,7 @@ using Apache.Arrow.Flight.Sql.Middleware.Models; using Grpc.Core; -namespace Apache.Arrow.Flight.Sql.Middleware.Gprc; +namespace Apache.Arrow.Flight.Sql.Middleware.Grpc; public static class StatusUtils { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs index 576f3f72add..6aca028513e 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; using System.Linq; -using Apache.Arrow.Flight.Sql.Middleware.Gprc; using Apache.Arrow.Flight.Sql.Middleware.Grpc; using Apache.Arrow.Flight.Sql.Middleware.Interfaces; using Apache.Arrow.Flight.Sql.Middleware.Models; @@ -55,12 +54,14 @@ public override AsyncUnaryCall AsyncUnaryCall( throw new RpcException(new Status(StatusCode.Internal, "Middleware creation failed"), e.Message); } + AddCallerMetadata(ref context); + // Apply middleware headers - var middlewareHeaders = new Metadata(); - var headerAdapter = new MetadataAdapter(middlewareHeaders); + var headers = context.Options.Headers ?? new Metadata(); + var adapter = new MetadataAdapter(headers); foreach (var m in middleware) { - m.OnBeforeSendingHeaders(headerAdapter); + m.OnBeforeSendingHeaders(adapter); } // Merge original headers with middleware headers @@ -73,11 +74,6 @@ public override AsyncUnaryCall AsyncUnaryCall( } } - foreach (var entry in middlewareHeaders) - { - mergedHeaders.Add(entry); - } - var updatedContext = new ClientInterceptorContext( context.Method, context.Host, @@ -95,6 +91,7 @@ public override AsyncUnaryCall AsyncUnaryCall( middleware.ForEach(m => m.OnHeadersReceived(metadataAdapter)); headersReceived = true; } + return task.Result; }); @@ -107,7 +104,7 @@ public override AsyncUnaryCall AsyncUnaryCall( foreach (var m in middleware) m.OnHeadersReceived(trailersAdapter); } - + var status = call.GetStatus(); var trailers = call.GetTrailers(); var flightStatus = StatusUtils.FromGrpcStatusAndTrailers(status, trailers); @@ -119,7 +116,7 @@ public override AsyncUnaryCall AsyncUnaryCall( return response.Result; }); - + return new AsyncUnaryCall( responseTask, responseHeadersTask, @@ -127,5 +124,26 @@ public override AsyncUnaryCall AsyncUnaryCall( call.GetTrailers, call.Dispose); } + + private void AddCallerMetadata(ref ClientInterceptorContext context) + where TRequest : class + where TResponse : class + { + var headers = context.Options.Headers; + + // Call doesn't have a headers collection to add to. + // Need to create a new context with headers for the call. + if (headers == null) + { + headers = new Metadata(); + var options = context.Options.WithHeaders(headers); + context = new ClientInterceptorContext(context.Method, context.Host, options); + } + + // Add caller metadata to call headers + headers.Add("caller-user", Environment.UserName); + headers.Add("caller-machine", Environment.MachineName); + headers.Add("caller-os", Environment.OSVersion.ToString()); + } } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs new file mode 100644 index 00000000000..673c74e4d51 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -0,0 +1,382 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Sql.Client; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +public class PreparedStatement : IDisposable +{ + private readonly FlightSqlClient _client; + private readonly string _handle; + private Schema _datasetSchema; + private Schema _parameterSchema; + private RecordBatch? _recordsBatch; + private bool _isClosed; + public bool IsClosed => _isClosed; + public string Handle => _handle; + public RecordBatch? ParametersBatch => _recordsBatch; + + /// + /// Initializes a new instance of the class. + /// + /// The Flight SQL client used for executing SQL operations. + /// The handle representing the prepared statement. + /// The schema of the result dataset. + /// The schema of the parameters for this prepared statement. + public PreparedStatement(FlightSqlClient client, string handle, Schema datasetSchema, Schema parameterSchema) + { + _client = client ?? throw new ArgumentNullException(nameof(client)); + _handle = handle ?? throw new ArgumentNullException(nameof(handle)); + _datasetSchema = datasetSchema ?? throw new ArgumentNullException(nameof(datasetSchema)); + _parameterSchema = parameterSchema ?? throw new ArgumentNullException(nameof(parameterSchema)); + _isClosed = false; + } + + /// + /// Retrieves the schema associated with the prepared statement asynchronously. + /// + /// The options used to configure the Flight call. + /// A task representing the asynchronous operation, which returns the schema of the result set. + /// Thrown when the schema is empty or invalid. + public async Task GetSchemaAsync(FlightCallOptions? options = default) + { + EnsureStatementIsNotClosed(); + + try + { + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8) + }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schema = await _client.GetSchemaAsync(descriptor, options).ConfigureAwait(false); + if (schema == null || !schema.FieldsList.Any()) + { + throw new InvalidOperationException("Schema is empty or invalid."); + } + return schema; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to retrieve the schema for the prepared statement", ex); + } + } + + /// + /// Closes the prepared statement asynchronously. + /// + /// The options used to configure the Flight call. + /// A task representing the asynchronous operation. + /// Thrown if closing the prepared statement fails. + public async Task CloseAsync(FlightCallOptions? options = default) + { + EnsureStatementIsNotClosed(); + try + { + var closeRequest = new ActionClosePreparedStatementRequest + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8) + }; + + var action = new FlightAction(SqlAction.CloseRequest, closeRequest.PackAndSerialize()); + await foreach (var result in _client.DoActionAsync(action, options).ConfigureAwait(false)) + { + // Just drain the results to complete the operation + } + + _isClosed = true; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to close the prepared statement", ex); + } + } + + /// + /// Reads the result from an asynchronous stream of FlightData and populates the provided Protobuf message. + /// + /// The asynchronous stream of objects. + /// The Protobuf message to populate with the data from the stream. + /// A task that represents the asynchronous read operation. + /// Thrown if or is null. + /// Thrown if parsing the data fails. + public async Task ReadResultAsync(IAsyncEnumerable results, IMessage message) + { + if (results == null) throw new ArgumentNullException(nameof(results)); + if (message == null) throw new ArgumentNullException(nameof(message)); + + await foreach (var flightData in results.ConfigureAwait(false)) + { + if (flightData.DataBody == null || flightData.DataBody.Length == 0) + continue; + + try + { + message.MergeFrom(message.PackAndSerialize()); + } + catch (InvalidProtocolBufferException ex) + { + throw new InvalidOperationException("Failed to parse the received FlightData into the specified message.", ex); + } + } + } + + /// + /// Parses the response of a prepared statement execution from the FlightData stream. + /// + /// The Flight SQL client. + /// The asynchronous stream of objects. + /// A task representing the asynchronous operation, which returns the populated . + /// Thrown if or is null. + /// Thrown if the prepared statement handle or data is invalid. + public async Task ParseResponseAsync(FlightSqlClient client, IAsyncEnumerable results) + { + if (client == null) + throw new ArgumentNullException(nameof(client)); + + if (results == null) + throw new ArgumentNullException(nameof(results)); + + var preparedStatementResult = new ActionCreatePreparedStatementResult(); + await foreach (var flightData in results.ConfigureAwait(false)) + { + if (flightData.DataBody == null || flightData.DataBody.Length == 0) + { + continue; + } + + try + { + preparedStatementResult.MergeFrom(flightData.DataBody.ToByteArray()); + } + catch (InvalidProtocolBufferException ex) + { + throw new InvalidOperationException("Failed to parse FlightData into ActionCreatePreparedStatementResult.", ex); + } + } + + if (preparedStatementResult.PreparedStatementHandle.Length == 0) + { + throw new InvalidOperationException("Received an empty or invalid PreparedStatementHandle."); + } + + Schema datasetSchema = null!; + Schema parameterSchema = null!; + + if (preparedStatementResult.DatasetSchema.Length > 0) + { + datasetSchema = SchemaExtensions.DeserializeSchema(preparedStatementResult.DatasetSchema.ToByteArray()); + } + + if (preparedStatementResult.ParameterSchema.Length > 0) + { + parameterSchema = SchemaExtensions.DeserializeSchema(preparedStatementResult.ParameterSchema.ToByteArray()); + } + + // Create and return the PreparedStatement object + return new PreparedStatement(client, preparedStatementResult.PreparedStatementHandle.ToStringUtf8(), + datasetSchema, parameterSchema); + } + + /// + /// Binds the specified parameter batch to the prepared statement and returns the status. + /// + /// The containing parameters to bind to the statement. + /// A cancellation token for the binding operation. + /// A indicating success or failure. + /// Thrown if is null. + public void SetParameters(RecordBatch parameterBatch) + { + _recordsBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); + } + + /// + /// Executes the prepared statement asynchronously and retrieves the query results as . + /// + /// Optional to observe while waiting for the task to complete. The task will be canceled if the token is canceled. + /// Optional The for the operation, which may include timeouts, headers, and other options for the call. + /// A representing the asynchronous operation. The task result contains the describing the executed query results. + /// Thrown if the prepared statement is closed or if there is an error during execution. + /// Thrown if the operation is canceled by the . + public async Task ExecuteAsync(CancellationToken cancellationToken = default, FlightCallOptions? options = default) + { + EnsureStatementIsNotClosed(); + + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8), + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + cancellationToken.ThrowIfCancellationRequested(); + + if (_recordsBatch != null) + { + await BindParametersAsync(descriptor, _recordsBatch, options).ConfigureAwait(false); + } + cancellationToken.ThrowIfCancellationRequested(); + return await _client.GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + } + + /// + /// Executes a prepared update statement asynchronously with the provided parameter batch. + /// + /// + /// This method executes an update operation using a prepared statement. The provided + /// is bound to the statement, and the operation is sent to the server. The server processes the update and returns + /// metadata indicating the number of affected rows. + /// + /// This operation is asynchronous and can be canceled via the provided . + /// + /// + /// A containing the parameters to be bound to the update statement. + /// This batch should match the schema expected by the prepared statement. + /// + /// The for this execution, containing headers and other options. + /// + /// A representing the asynchronous operation. + /// The task result contains the number of rows affected by the update. + /// + /// + /// Thrown if is null, as a valid parameter batch is required for execution. + /// + /// + /// Thrown if the update operation fails for any reason, including when the server returns invalid or empty metadata, + /// or if the operation is canceled via the . + /// + /// + /// The following example demonstrates how to use the method to execute an update operation: + /// + /// var parameterBatch = CreateParameterBatch(); + /// var affectedRows = await preparedStatement.ExecuteUpdateAsync(new FlightCallOptions(), parameterBatch); + /// Console.WriteLine($"Rows affected: {affectedRows}"); + /// + /// + public async Task ExecuteUpdateAsync(RecordBatch parameterBatch, FlightCallOptions? options = default) + { + if (parameterBatch == null) + { + throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); + } + + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8), + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var metadata = await BindParametersAsync(descriptor, parameterBatch, options).ConfigureAwait(false); + + try + { + return ParseAffectedRows(metadata); + } + catch (OperationCanceledException) + { + throw new InvalidOperationException("Update operation was canceled."); + } + catch (Exception ex) + { + throw new InvalidOperationException("Failed to execute the prepared update statement.", ex); + } + } + + private long ParseAffectedRows(ByteString metadata) + { + if (metadata == null || metadata.Length == 0) + { + throw new InvalidOperationException("Server returned empty metadata, unable to determine affected row count."); + } + + var updateResult = new DoPutUpdateResult(); + updateResult.MergeFrom(metadata); + return updateResult.RecordCount; + } + + /// + /// Binds parameters to the prepared statement by streaming the given RecordBatch to the server asynchronously. + /// + /// The that identifies the statement or command being executed. + /// The containing the parameters to bind to the prepared statement. + /// The for the operation, which may include timeouts, headers, and other options for the call. + /// A that represents the asynchronous operation. The task result contains the metadata from the server after binding the parameters. + /// Thrown when is null. + /// Thrown if the operation is canceled or if there is an error during the DoPut operation. + public async Task BindParametersAsync(FlightDescriptor descriptor, RecordBatch parameterBatch, FlightCallOptions? options = default) + { + if (parameterBatch == null) + { + throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); + } + var putResult = await _client.DoPutAsync(descriptor, parameterBatch, options).ConfigureAwait(false); + try + { + var metadata = putResult.ApplicationMetadata; + return metadata; + } + catch (OperationCanceledException) + { + throw new InvalidOperationException("Parameter binding was canceled."); + } + catch (Exception ex) + { + throw new InvalidOperationException("Failed to bind parameters to the prepared statement.", ex); + } + } + + /// + /// Ensures that the statement is not already closed. + /// + private void EnsureStatementIsNotClosed() + { + if (_isClosed) + throw new InvalidOperationException("Cannot execute a closed statement."); + } + + /// + /// Disposes of the resources used by the prepared statement. + /// + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Disposes of the resources used by the prepared statement. + /// + /// Whether the method is called from . + protected virtual void Dispose(bool disposing) + { + if (_isClosed) return; + + if (disposing) + { + CloseAsync(new FlightCallOptions()).GetAwaiter().GetResult(); + } + + _isClosed = true; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs new file mode 100644 index 00000000000..e734736242e --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.IO; +using Apache.Arrow.Ipc; + +namespace Apache.Arrow.Flight.Sql; + +public static class SchemaExtensions +{ + /// + /// Deserializes a schema from a byte array. + /// + /// The byte array representing the serialized schema. + /// The deserialized Schema object. + public static Schema DeserializeSchema(ReadOnlyMemory serializedSchema) + { + if (serializedSchema.IsEmpty) + { + throw new ArgumentException("Invalid serialized schema", nameof(serializedSchema)); + } + using var reader = new ArrowStreamReader(serializedSchema); + return reader.Schema; + } + + /// + /// Serializes the provided schema to a byte array. + /// + public static byte[] SerializeSchema(Schema schema) + { + using var memoryStream = new MemoryStream(); + using var writer = new ArrowStreamWriter(memoryStream, schema); + writer.WriteStart(); + writer.WriteEnd(); + return memoryStream.ToArray(); + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs index f3f3bef1e1d..b0c2b454a24 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs @@ -19,4 +19,9 @@ public static class SqlAction { public const string CreateRequest = "CreatePreparedStatement"; public const string CloseRequest = "ClosePreparedStatement"; + public const string CancelFlightInfoRequest = "CancelFlightInfo"; + public const string BeginTransactionRequest = "BeginTransaction"; + public const string CommitRequest = "CommitTransaction"; + public const string RollbackRequest = "RollbackTransaction"; + public const string GetPrimaryKeysRequest = "GetPrimaryKeys"; } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs b/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs new file mode 100644 index 00000000000..9c98e5ff1d1 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; + +namespace Apache.Arrow.Flight.Sql; + +public class TableRef +{ + public string? Catalog { get; } + public string DbSchema { get; } + public string Table { get; } + + public TableRef(string dbSchema, string table) + { + DbSchema = dbSchema ?? throw new ArgumentNullException(nameof(dbSchema)); + Table = table ?? throw new ArgumentNullException(nameof(table)); + } + + public TableRef(string? catalog, string dbSchema, string table) + { + Catalog = catalog; + DbSchema = dbSchema ?? throw new ArgumentNullException(nameof(dbSchema)); + Table = table ?? throw new ArgumentNullException(nameof(table)); + } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs new file mode 100644 index 00000000000..097cadea387 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Flight.Sql; + +using Google.Protobuf; // Ensure you have the Protobuf dependency + +public readonly struct Transaction +{ + private static readonly ByteString TransactionIdDefaultValue = ByteString.Empty; + private readonly ByteString _transactionId; + + public ByteString TransactionId => _transactionId ?? TransactionIdDefaultValue; + + public static readonly Transaction NoTransaction = new(TransactionIdDefaultValue); + + public Transaction(ByteString transactionId) + { + _transactionId = ProtoPreconditions.CheckNotNull(transactionId, nameof(transactionId)); + } + + public Transaction(string transactionId) + { + _transactionId = ByteString.CopyFromUtf8(transactionId); + } + + public bool IsValid() => TransactionId.Length > 0; + + public override bool Equals(object? obj) + { + if (obj is not Transaction other) + return false; + + // Safe compare even if _transactionId is null (from default(Transaction)) + return (_transactionId ?? TransactionIdDefaultValue) + .Equals(other._transactionId); + } + + public override int GetHashCode() + { + return (_transactionId ?? TransactionIdDefaultValue).GetHashCode(); + } + + public static bool operator ==(Transaction left, Transaction right) => left.Equals(right); + public static bool operator !=(Transaction left, Transaction right) => !left.Equals(right); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs new file mode 100644 index 00000000000..f4573ef38d2 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using Apache.Arrow.Flight.Protocol; +using Google.Protobuf; +using Google.Protobuf.Reflection; + +namespace Apache.Arrow.Flight; + +public class FlightInfoCancelRequest : IMessage +{ + private readonly CancelFlightInfoRequest _cancelFlightInfoRequest; + public FlightInfo FlightInfo { get; private set; } + + public FlightInfoCancelRequest(FlightInfo flightInfo) + { + FlightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); + _cancelFlightInfoRequest = new CancelFlightInfoRequest(); + } + + public FlightInfoCancelRequest() + { + _cancelFlightInfoRequest = new CancelFlightInfoRequest(); + } + + public void MergeFrom(CodedInputStream input) + { + _cancelFlightInfoRequest.MergeFrom(input); + } + + public void WriteTo(CodedOutputStream output) + { + _cancelFlightInfoRequest.WriteTo(output); + } + + public int CalculateSize() => _cancelFlightInfoRequest.CalculateSize(); + + public MessageDescriptor Descriptor => + DescriptorReflection.Descriptor.MessageTypes[0]; +} diff --git a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs new file mode 100644 index 00000000000..14cad662338 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.Protocol; +using Google.Protobuf; +using Google.Protobuf.Reflection; + +namespace Apache.Arrow.Flight; + +public class FlightInfoCancelResult : IMessage +{ + private readonly CancelFlightInfoResult _flightInfoCancelResult; + + public FlightInfoCancelResult() + { + _flightInfoCancelResult = new CancelFlightInfoResult(); + Descriptor = DescriptorReflection.Descriptor.MessageTypes[0]; + } + + public void MergeFrom(CodedInputStream input) => _flightInfoCancelResult.MergeFrom(input); + + public void WriteTo(CodedOutputStream output) => _flightInfoCancelResult.WriteTo(output); + + public int CalculateSize() + { + return _flightInfoCancelResult.CalculateSize(); + } + + public MessageDescriptor Descriptor { get; } + + public int GetCancelStatus() + { + return (int)_flightInfoCancelResult.Status; + } + + public void SetStatus(int status) + { + _flightInfoCancelResult.Status = (CancelStatus)status; + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 58bd732a11d..8cdaf146b5e 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -13,8 +13,8 @@ - - + + + - diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs index 7e04e153349..17467a6a6fc 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs @@ -17,7 +17,7 @@ using System.Linq; using System.Net; using System.Threading.Tasks; -using Apache.Arrow.Flight.Sql.Middleware.Middleware; +using Apache.Arrow.Flight.Sql.Middleware; using Apache.Arrow.Flight.Sql.Tests.Stubs; using Xunit; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs new file mode 100644 index 00000000000..420f0dba8ef --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -0,0 +1,855 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Client; +using Apache.Arrow.Flight.TestWeb; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Grpc.Core.Utils; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class FlightSqlClientTests : IDisposable +{ + readonly TestFlightSqlWebFactory _testWebFactory; + readonly FlightStore _flightStore; + private readonly FlightSqlClient _flightSqlClient; + private readonly FlightSqlTestUtils _testUtils; + + public FlightSqlClientTests() + { + _flightStore = new FlightStore(); + _testWebFactory = new TestFlightSqlWebFactory(_flightStore); + FlightClient flightClient = new(_testWebFactory.GetChannel()); + _flightSqlClient = new FlightSqlClient(flightClient); + + _testUtils = new FlightSqlTestUtils(_testWebFactory, _flightStore); + } + + #region Transactions + + [Fact] + public async Task CommitTransactionAsync() + { + // Arrange + string transactionId = "sample-transaction-id"; + var transaction = new Transaction(transactionId); + + // Act + var streamCall = _flightSqlClient.CommitAsync(transaction); + var result = await streamCall.ResponseStream.ToListAsync(); + + // Assert + Assert.NotNull(result); + Assert.Equal(transaction.TransactionId, result.FirstOrDefault()?.Body); + } + + [Fact] + public async Task BeginTransactionAsync() + { + // Arrange + string expectedTransactionId = "sample-transaction-id"; + + // Act + var transaction = await _flightSqlClient.BeginTransactionAsync(); + + // Assert + Assert.NotEqual(Transaction.NoTransaction, transaction); + Assert.Equal(ByteString.CopyFromUtf8(expectedTransactionId), transaction.TransactionId); + } + + [Fact] + public async Task RollbackTransactionAsync() + { + // Arrange + string transactionId = "sample-transaction-id"; + var transaction = new Transaction(transactionId); + + // Act + var streamCall = _flightSqlClient.RollbackAsync(transaction); + var result = await streamCall.ResponseStream.ToListAsync(); + + // Assert + Assert.Equal(result.FirstOrDefault()?.Body, transaction.TransactionId); + } + + #endregion + + #region PreparedStatement + + [Fact] + public async Task PreparedAsync() + { + // Arrange + string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + + // Create a sample schema for the dataset and parameters + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + + var recordBatch = new RecordBatch(schema, new Array[] + { + new Int32Array.Builder().Append(1).Build(), + new StringArray.Builder().Append("John Doe").Build() + }, 1); + + var flightHolder = new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + var datasetSchemaBytes = SchemaExtensions.SerializeSchema(schema); + var parameterSchemaBytes = SchemaExtensions.SerializeSchema(schema); + + var preparedStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("prepared-handle"), + DatasetSchema = ByteString.CopyFrom(datasetSchemaBytes), + ParameterSchema = ByteString.CopyFrom(parameterSchemaBytes) + }; + + // Act + var preparedStatement = await _flightSqlClient.PrepareAsync(query, transaction); + var deserializedDatasetSchema = + SchemaExtensions.DeserializeSchema(preparedStatementResponse.DatasetSchema.ToByteArray()); + var deserializedParameterSchema = + SchemaExtensions.DeserializeSchema(preparedStatementResponse.ParameterSchema.ToByteArray()); + + // Assert + Assert.NotNull(preparedStatement); + Assert.NotNull(deserializedDatasetSchema); + Assert.NotNull(deserializedParameterSchema); + CompareSchemas(schema, deserializedDatasetSchema); + CompareSchemas(schema, deserializedParameterSchema); + } + + #endregion + + [Fact] + public async Task ExecuteUpdateAsync() + { + // Arrange + string query = "UPDATE test_table SET column1 = 'value'"; + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + + var recordBatch = new RecordBatch(schema, new IArrowArray[] + { + new Int32Array.Builder().AppendRange([1, 2, 3, 4, 5]).Build(), + new StringArray.Builder().AppendRange(["John Doe", "Jane Doe", "Alice", "Bob", "Charlie"]).Build() + }, 5); + + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + long affectedRows = await _flightSqlClient.ExecuteUpdateAsync(query, transaction); + + // Assert + Assert.Equal(5, affectedRows); + } + + [Fact] + public async Task ExecuteAsync() + { + // Arrange + string query = "SELECT * FROM test_table"; + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.ExecuteAsync(query, transaction); + + // Assert + Assert.NotNull(flightInfo); + Assert.Single(flightInfo.Endpoints); + } + + [Fact] + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() + { + // Arrange + string query = "SELECT * FROM test_table"; + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.ExecuteAsync(query, transaction); + + // Assert + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + } + + [Fact] + public async Task ExecuteAsync_ShouldThrowArgumentException_WhenQueryIsEmpty() + { + // Arrange + string emptyQuery = string.Empty; + var transaction = new Transaction("sample-transaction-id"); + + // Act & Assert + await Assert.ThrowsAsync(async () => + await _flightSqlClient.ExecuteAsync(emptyQuery, transaction)); + } + + [Fact] + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenTransactionIsNoTransaction() + { + // Arrange + string query = "SELECT * FROM test_table"; + var transaction = Transaction.NoTransaction; + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.ExecuteAsync(query, transaction); + + // Assert + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + } + + [Fact] + public async Task GetFlightInfoAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + // Act + var flightInfo = await _flightSqlClient.GetFlightInfoAsync(flightDescriptor); + + // Assert + Assert.NotNull(flightInfo); + } + + [Fact] + public async Task GetExecuteSchemaAsync() + { + // Arrange + string query = "SELECT * FROM test_table"; + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + Schema resultSchema = + await _flightSqlClient.GetExecuteSchemaAsync(query, new Transaction("sample-transaction-id")); + + // Assert + Assert.NotNull(resultSchema); + Assert.Equal(recordBatch.Schema.FieldsList.Count, resultSchema.FieldsList.Count); + CompareSchemas(resultSchema, recordBatch.Schema); + } + + [Fact] + public async Task GetCatalogsAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.GetCatalogsAsync(options); + + // Assert + Assert.NotNull(result); + Assert.Equal(flightHolder.GetFlightInfo().Endpoints.Count, result.Endpoints.Count); + Assert.Equal(flightDescriptor, result.Descriptor); + } + + [Fact] + public async Task GetSchemaAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.GetSchemaAsync(flightDescriptor); + + // Assert + Assert.NotNull(result); + Assert.Equal(recordBatch.Schema.FieldsList.Count, result.FieldsList.Count); + CompareSchemas(result, recordBatch.Schema); + } + + [Fact] + public async Task GetDbSchemasAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + string catalog = "test-catalog"; + string dbSchemaFilterPattern = "test-schema-pattern"; + + // Act + var result = await _flightSqlClient.GetDbSchemasAsync(catalog, dbSchemaFilterPattern); + + // Assert + Assert.NotNull(result); + var expectedFlightInfo = flightHolder.GetFlightInfo(); + Assert.Equal(recordBatch.Schema.FieldsList.Count, result.Schema.FieldsList.Count); + Assert.Equal(expectedFlightInfo.Descriptor.Command, result.Descriptor.Command); + Assert.Equal(expectedFlightInfo.Descriptor.Type, result.Descriptor.Type); + Assert.Equal(expectedFlightInfo.Schema.FieldsList.Count, result.Schema.FieldsList.Count); + Assert.Equal(expectedFlightInfo.Endpoints.Count, result.Endpoints.Count); + + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = result.Schema.FieldsList[i]; + + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + Assert.Equal(expectedField.Metadata?.Count ?? 0, actualField.Metadata?.Count ?? 0); + } + + for (int i = 0; i < expectedFlightInfo.Endpoints.Count; i++) + { + var expectedEndpoint = expectedFlightInfo.Endpoints[i]; + var actualEndpoint = result.Endpoints[i]; + + Assert.Equal(expectedEndpoint.Ticket, actualEndpoint.Ticket); + Assert.Equal(expectedEndpoint.Locations.Count(), actualEndpoint.Locations.Count()); + } + } + + [Fact] + public async Task GetPrimaryKeysAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var tableRef = new TableRef("test-catalog", "test-schema", "test-table"); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.GetPrimaryKeysAsync(tableRef); + + // Assert + Assert.NotNull(result); + var expectedFlightInfo = flightHolder.GetFlightInfo(); + Assert.Equal(expectedFlightInfo.Descriptor.Command, result.Descriptor.Command); + Assert.Equal(expectedFlightInfo.Descriptor.Type, result.Descriptor.Type); + Assert.Equal(expectedFlightInfo.Schema.FieldsList.Count, result.Schema.FieldsList.Count); + + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = result.Schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + Assert.Equal(expectedField.Metadata?.Count ?? 0, actualField.Metadata?.Count ?? 0); + } + + Assert.Equal(expectedFlightInfo.Endpoints.Count, result.Endpoints.Count); + + for (int i = 0; i < expectedFlightInfo.Endpoints.Count; i++) + { + var expectedEndpoint = expectedFlightInfo.Endpoints[i]; + var actualEndpoint = result.Endpoints[i]; + + Assert.Equal(expectedEndpoint.Ticket, actualEndpoint.Ticket); + Assert.Equal(expectedEndpoint.Locations.Count(), actualEndpoint.Locations.Count()); + } + } + + [Fact] + public async Task GetTablesAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + string catalog = "sample_catalog"; + string dbSchemaFilterPattern = "sample_schema"; + string tableFilterPattern = "sample_table"; + bool includeSchema = true; + var tableTypes = new List { "BASE TABLE" }; + + // Act + var result = await _flightSqlClient.GetTablesAsync(catalog, dbSchemaFilterPattern, tableFilterPattern, + includeSchema, tableTypes); + + // Assert + Assert.NotNull(result); + Assert.Single(result); + + var expectedFlightInfo = flightHolder.GetFlightInfo(); + var flightInfo = result.First(); + Assert.Equal(expectedFlightInfo.Descriptor.Command, flightInfo.Descriptor.Command); + Assert.Equal(expectedFlightInfo.Descriptor.Type, flightInfo.Descriptor.Type); + Assert.Equal(expectedFlightInfo.Schema.FieldsList.Count, flightInfo.Schema.FieldsList.Count); + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = flightInfo.Schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + + Assert.Equal(expectedFlightInfo.Endpoints.Count, flightInfo.Endpoints.Count); + } + + [Fact] + public async Task GetCatalogsSchemaAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetCatalogsSchemaAsync(); + + // Assert + Assert.NotNull(schema); + var expectedFlightInfo = flightHolder.GetFlightInfo(); + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + } + + [Fact] + public async Task GetDbSchemasSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetDbSchemasSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + for (int i = 0; i < schema.FieldsList.Count; i++) + { + var expectedField = schema.FieldsList[i]; + var actualField = schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + } + + [Fact] + public async Task DoPutAsync() + { + // Arrange + var schema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("COLUMN_SIZE").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + + int[] dataTypeIds = { 1, 2, 3 }; + string[] typeNames = ["INTEGER", "VARCHAR", "BOOLEAN"]; + int[] precisions = { 32, 255, 1 }; + string[] literalPrefixes = ["N'", "'", "b'"]; + int[] columnSizes = [10, 255, 1]; + + var recordBatch = new RecordBatch(schema, + [ + new Int32Array.Builder().AppendRange(dataTypeIds).Build(), + new StringArray.Builder().AppendRange(typeNames).Build(), + new Int32Array.Builder().AppendRange(precisions).Build(), + new StringArray.Builder().AppendRange(literalPrefixes).Build(), + new Int32Array.Builder().AppendRange(columnSizes).Build() + ], 5); + Assert.NotNull(recordBatch); + Assert.Equal(5, recordBatch.Length); + var flightHolder = new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(_testUtils.CreateTestBatch(0, 100))); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + var expectedBatch = _testUtils.CreateTestBatch(0, 100); + + // Act + var result = await _flightSqlClient.DoPutAsync(flightDescriptor, expectedBatch); + + // Assert + Assert.NotNull(result); + } + + [Fact] + public async Task GetExportedKeysAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var tableRef = new TableRef("test-catalog", "test-schema", "test-table"); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetExportedKeysAsync(tableRef); + + // Assert + Assert.NotNull(flightInfo); + Assert.Equal("test", flightInfo.Descriptor.Command.ToStringUtf8()); + } + + [Fact] + public async Task GetExportedKeysSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var tableRef = new TableRef("test-catalog", "test-schema", "test-table"); + var schema = await _flightSqlClient.GetExportedKeysSchemaAsync(tableRef); + + // Assert + Assert.NotNull(schema); + Assert.True(schema.FieldsList.Count > 0, "Schema should contain fields."); + Assert.Equal("test", schema.FieldsList.First().Name); + } + + [Fact] + public async Task GetImportedKeysAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = + await _flightSqlClient.GetImportedKeysAsync(new TableRef("test-catalog", "test-schema", "test-table")); + + // Assert + Assert.NotNull(flightInfo); + for (int i = 0; i < recordBatch.Schema.FieldsList.Count; i++) + { + var expectedField = recordBatch.Schema.FieldsList[i]; + var actualField = flightInfo.Schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + } + + [Fact] + public async Task GetImportedKeysSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetImportedKeysSchemaAsync(options); + + // Assert + var expectedSchema = recordBatch.Schema; + Assert.NotNull(schema); + Assert.Equal(expectedSchema.FieldsList.Count, schema.FieldsList.Count); + CompareSchemas(expectedSchema, schema); + } + + [Fact] + public async Task GetCrossReferenceAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + var pkTableRef = new TableRef("PKCatalog", "PKSchema", "PKTable"); + var fkTableRef = new TableRef("FKCatalog", "FKSchema", "FKTable"); + + // Act + var flightInfo = await _flightSqlClient.GetCrossReferenceAsync(pkTableRef, fkTableRef); + + // Assert + Assert.NotNull(flightInfo); + Assert.Equal(flightDescriptor, flightInfo.Descriptor); + Assert.Single(flightInfo.Schema.FieldsList); + } + + [Fact] + public async Task GetCrossReferenceSchemaAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetCrossReferenceSchemaAsync(); + + // Assert + var expectedSchema = recordBatch.Schema; + Assert.NotNull(schema); + Assert.Equal(expectedSchema.FieldsList.Count, schema.FieldsList.Count); + } + + [Fact] + public async Task GetTableTypesAsync() + { + // Arrange + var expectedSchema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var commandGetTableTypes = new CommandGetTableTypes(); + byte[] packedCommand = commandGetTableTypes.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetTableTypesAsync(); + var actualSchema = flightInfo.Schema; + + // Assert + Assert.NotNull(flightInfo); + CompareSchemas(expectedSchema, actualSchema); + } + + [Fact] + public async Task GetTableTypesSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var commandGetTableTypesSchema = new CommandGetTableTypes(); + byte[] packedCommand = commandGetTableTypesSchema.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schemaResult = await _flightSqlClient.GetTableTypesSchemaAsync(options); + + // Assert + Assert.NotNull(schemaResult); + CompareSchemas(expectedSchema, schemaResult); + } + + [Fact] + public async Task GetXdbcTypeInfoAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("COLUMN_SIZE").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var commandGetXdbcTypeInfo = new CommandGetXdbcTypeInfo(); + byte[] packedCommand = commandGetXdbcTypeInfo.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + // Creating a flight holder with the expected schema and adding it to the flight store + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetXdbcTypeInfoAsync(options); + + // Assert + Assert.NotNull(flightInfo); + CompareSchemas(expectedSchema, flightInfo.Schema); + } + + [Fact] + public async Task GetXdbcTypeInfoSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + .Build(); + + var commandGetXdbcTypeInfo = new CommandGetXdbcTypeInfo(); + byte[] packedCommand = commandGetXdbcTypeInfo.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetXdbcTypeInfoSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + CompareSchemas(expectedSchema, schema); + } + + [Fact] + public async Task GetSqlInfoSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("sqlInfo"); + var expectedSchema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetSqlInfoSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + CompareSchemas(expectedSchema, schema); + } + + [Fact] + public async Task CancelFlightInfoAsync() + { + // Arrange + var schema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); + var cancelRequest = new FlightInfoCancelRequest(flightInfo); + + // Act + var cancelResult = await _flightSqlClient.CancelFlightInfoAsync(cancelRequest); + + // Assert + Assert.Equal(1, cancelResult.GetCancelStatus()); + } + + [Fact] + public async Task CancelQueryAsync() + { + // Arrange + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var schema = new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); + + // Adding the flight info to the flight store for testing + _flightStore.Flights.Add(flightDescriptor, + new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); + + // Act + var cancelStatus = await _flightSqlClient.CancelQueryAsync(flightInfo); + + // Assert + Assert.Equal(1, cancelStatus.GetCancelStatus()); + } + + public void Dispose() => _testWebFactory?.Dispose(); + + private void CompareSchemas(Schema expectedSchema, Schema actualSchema) + { + Assert.Equal(expectedSchema.FieldsList.Count, actualSchema.FieldsList.Count); + + for (int i = 0; i < expectedSchema.FieldsList.Count; i++) + { + var expectedField = expectedSchema.FieldsList[i]; + var actualField = actualSchema.FieldsList[i]; + + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + Assert.Equal(expectedField.Metadata, actualField.Metadata); + } + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs new file mode 100644 index 00000000000..24647bf16e3 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Client; +using Apache.Arrow.Flight.TestWeb; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Grpc.Core; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests +{ + public class FlightSqlPreparedStatementTests + { + readonly TestFlightSqlWebFactory _testWebFactory; + readonly FlightStore _flightStore; + readonly FlightSqlClient _flightSqlClient; + private readonly PreparedStatement _preparedStatement; + private readonly Schema _schema; + private readonly FlightDescriptor _flightDescriptor; + private readonly RecordBatch _parameterBatch; + + public FlightSqlPreparedStatementTests() + { + _flightStore = new FlightStore(); + _testWebFactory = new TestFlightSqlWebFactory(_flightStore); + _flightSqlClient = new FlightSqlClient(new FlightClient(_testWebFactory.GetChannel())); + + _flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); + _schema = CreateSchema(); + _parameterBatch = CreateParameterBatch(); + _preparedStatement = new PreparedStatement(_flightSqlClient, "test-handle-guid", _schema, _schema); + } + + private static Schema CreateSchema() + { + return new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + } + + private RecordBatch CreateParameterBatch() + { + return new RecordBatch(_schema, + new IArrowArray[] + { + new Int32Array.Builder().AppendRange(new[] { 1, 2, 3 }).Build(), + new StringArray.Builder().AppendRange(new[] { "INTEGER", "VARCHAR", "BOOLEAN" }).Build(), + new Int32Array.Builder().AppendRange(new[] { 32, 255, 1 }).Build() + }, 3); + } + + [Fact] + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() + { + var validRecordBatch = CreateParameterBatch(); + _preparedStatement.SetParameters(validRecordBatch); + var flightInfo = await _preparedStatement.ExecuteAsync(); + + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + } + + [Fact] + public async Task GetSchemaAsync_ShouldThrowInvalidOperationException_WhenStatementIsClosed() + { + await _preparedStatement.CloseAsync(new FlightCallOptions()); + await Assert.ThrowsAsync(() => + _preparedStatement.GetSchemaAsync(new FlightCallOptions())); + } + + [Fact] + public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreSet() + { + var affectedRows = await _preparedStatement.ExecuteUpdateAsync(_parameterBatch); + Assert.True(affectedRows > 0, "Expected affected rows to be greater than 0."); + } + + [Fact] + public async Task BindParametersAsync_ShouldReturnMetadata_WhenValidInputsAreProvided() + { + var metadata = await _preparedStatement.BindParametersAsync(_flightDescriptor, _parameterBatch); + Assert.NotNull(metadata); + Assert.True(metadata.Length > 0, "Metadata should have a length greater than 0 when valid."); + } + + [Theory] + [MemberData(nameof(GetTestData))] + public async Task TestSetParameters(RecordBatch parameterBatch, Schema parameterSchema, Type expectedException) + { + var preparedStatement = new PreparedStatement(_flightSqlClient, "TestHandle", _schema, parameterSchema); + if (expectedException != null) + { + var exception = + await Record.ExceptionAsync(() => Task.Run(() => preparedStatement.SetParameters(parameterBatch))); + Assert.NotNull(exception); + Assert.IsType(expectedException, exception); + } + } + + [Fact] + public async Task TestSetParameters_Cancelled() + { + var validRecordBatch = CreateRecordBatch([1, 2, 3]); + var cts = new CancellationTokenSource(); + await cts.CancelAsync(); + _preparedStatement.SetParameters(validRecordBatch); + } + + [Fact] + public async Task TestCloseAsync() + { + await _preparedStatement.CloseAsync(new FlightCallOptions()); + Assert.True(_preparedStatement.IsClosed, + "PreparedStatement should be marked as closed after calling CloseAsync."); + } + + [Fact] + public async Task ReadResultAsync_ShouldPopulateMessage_WhenValidFlightData() + { + var message = new ActionCreatePreparedStatementResult(); + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("test-data")); + var results = GetAsyncEnumerable(new List { flightData }); + + await _preparedStatement.ReadResultAsync(results, message); + Assert.NotEmpty(message.PreparedStatementHandle.ToStringUtf8()); + } + + [Fact] + public async Task ReadResultAsync_ShouldNotThrow_WhenFlightDataBodyIsNullOrEmpty() + { + var message = new ActionCreatePreparedStatementResult(); + var flightData1 = new FlightData(_flightDescriptor, ByteString.Empty); + var flightData2 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("")); + var results = GetAsyncEnumerable(new List { flightData1, flightData2 }); + + await _preparedStatement.ReadResultAsync(results, message); + Assert.Empty(message.PreparedStatementHandle.ToStringUtf8()); + } + + [Fact] + public async Task ParseResponseAsync_ShouldReturnPreparedStatement_WhenValidData() + { + var preparedStatementHandle = "test-handle"; + var actionResult = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFrom(preparedStatementHandle, Encoding.UTF8), + DatasetSchema = _schema.ToByteString(), + ParameterSchema = _schema.ToByteString() + }; + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); + var results = GetAsyncEnumerable(new List { flightData }); + + var preparedStatement = await _preparedStatement.ParseResponseAsync(_flightSqlClient, results); + Assert.NotNull(preparedStatement); + Assert.Equal(preparedStatementHandle, preparedStatement.Handle); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task ParseResponseAsync_ShouldThrowException_WhenPreparedStatementHandleIsNullOrEmpty( + string handle) + { + ActionCreatePreparedStatementResult actionResult = string.IsNullOrEmpty(handle) + ? new ActionCreatePreparedStatementResult() + : new ActionCreatePreparedStatementResult + { PreparedStatementHandle = ByteString.CopyFrom(handle, Encoding.UTF8) }; + + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); + var results = GetAsyncEnumerable(new List { flightData }); + + await Assert.ThrowsAsync(() => + _preparedStatement.ParseResponseAsync(_flightSqlClient, results)); + } + + private async IAsyncEnumerable GetAsyncEnumerable(IEnumerable enumerable) + { + foreach (var item in enumerable) + { + yield return item; + await Task.Yield(); + } + } + + public static IEnumerable GetTestData() + { + var schema = new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(); + var validRecordBatch = CreateRecordBatch([1, 2, 3]); + var invalidSchema = new Schema.Builder().Field(f => f.Name("invalid_field").DataType(Int32Type.Default)) + .Build(); + var invalidRecordBatch = CreateRecordBatch([4, 5, 6]); + + return new List + { + new object[] { validRecordBatch, schema, null }, + new object[] { null, schema, typeof(ArgumentNullException) } + }; + } + + public static RecordBatch CreateRecordBatch(int[] values) + { + var int32Array = new Int32Array.Builder().AppendRange(values).Build(); + return new RecordBatch.Builder().Append("field1", true, int32Array).Build(); + } + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs index 031495fffdc..c1cd8f2bded 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -13,8 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Memory; +using Apache.Arrow.Types; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; +using Type = System.Type; namespace Apache.Arrow.Flight.Sql.Tests; @@ -25,3 +31,237 @@ public static ByteString PackAndSerialize(this IMessage command) return Any.Pack(command).Serialize(); } } + +internal static class TestSchemaExtensions +{ + public static void PrintSchema(this RecordBatch recordBatchResult) + { + // Display column headers + foreach (var field in recordBatchResult.Schema.FieldsList) + { + Console.Write($"{field.Name}\t"); + } + + Console.WriteLine(); + + int rowCount = recordBatchResult.Length; + + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) + { + foreach (var array in recordBatchResult.Arrays) + { + // Retrieve value based on array type + if (array is Int32Array intArray) + { + Console.Write($"{intArray.GetValue(rowIndex)}\t"); + } + else if (array is StringArray stringArray) + { + Console.Write($"{stringArray.GetString(rowIndex)}\t"); + } + else if (array is Int64Array longArray) + { + Console.Write($"{longArray.GetValue(rowIndex)}\t"); + } + else if (array is FloatArray floatArray) + { + Console.Write($"{floatArray.GetValue(rowIndex)}\t"); + } + else if (array is BooleanArray boolArray) + { + Console.Write($"{boolArray.GetValue(rowIndex)}\t"); + } + else + { + Console.Write("N/A\t"); // Fallback for unsupported types + } + } + + Console.WriteLine(); // Move to the next row + } + } + + public static RecordBatch CreateRecordBatch(int[] values) + { + var paramsList = new List(); + var schema = new Schema.Builder(); + for (var index = 0; index < values.Length; index++) + { + var val = values[index]; + var builder = new Int32Array.Builder(); + builder.Append(val); + var paramsArray = builder.Build(); + paramsList.Add(paramsArray); + schema.Field(f => f.Name($"param{index}").DataType(Int32Type.Default).Nullable(false)); + } + + return new RecordBatch(schema.Build(), paramsList, values.Length); + } + + public static void PrintSchema(this Schema schema) + { + Console.WriteLine("Schema Fields:"); + Console.WriteLine("{0,-20} {1,-20} {2,-20}", "Field Name", "Field Type", "Is Nullable"); + Console.WriteLine(new string('-', 60)); + + foreach (var field in schema.FieldsLookup) + { + string fieldName = field.First().Name; + string fieldType = field.First().DataType.TypeId.ToString(); + string isNullable = field.First().IsNullable ? "Yes" : "No"; + + Console.WriteLine("{0,-20} {1,-20} {2,-20}", fieldName, fieldType, isNullable); + } + } + + public static string GetStringValue(IArrowArray array, int index) + { + return array switch + { + StringArray stringArray => stringArray.GetString(index), + Int32Array intArray => intArray.GetValue(index).ToString(), + Int64Array longArray => longArray.GetValue(index).ToString(), + BooleanArray boolArray => boolArray.GetValue(index).Value ? "true" : "false", + _ => "Unsupported Type" + }; + } + + public static void PrintRecordBatch(RecordBatch recordBatch) + { + int rowCount = recordBatch.Length; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) + { + string catalogName = GetStringValue(recordBatch.Column(0), rowIndex); + string schemaName = GetStringValue(recordBatch.Column(1), rowIndex); + string tableName = GetStringValue(recordBatch.Column(2), rowIndex); + string tableType = GetStringValue(recordBatch.Column(3), rowIndex); + + Console.WriteLine("{0,-20} {1,-20} {2,-20} {3,-20}", catalogName, schemaName, tableName, tableType); + } + } + + public static RecordBatch CreateRecordBatch(int[] ids, string[] values) + { + var idArrayBuilder = new Int32Array.Builder(); + var valueArrayBuilder = new StringArray.Builder(); + + for (int i = 0; i < ids.Length; i++) + { + idArrayBuilder.Append(ids[i]); + valueArrayBuilder.Append(values[i]); + } + + var schema = new Schema.Builder() + .Field(f => f.Name("Id").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("Value").DataType(StringType.Default).Nullable(false)) + .Build(); + + return new RecordBatch(schema, [idArrayBuilder.Build(), valueArrayBuilder.Build()], ids.Length); + } + + public static RecordBatch CreateRecordBatch(T[] items) + { + if (items is null || items.Length == 0) + { + throw new ArgumentException("Items array cannot be null or empty."); + } + + var schema = BuildSchema(typeof(T)); + + var arrays = new List(); + foreach (var field in schema.FieldsList) + { + var property = typeof(T).GetProperty(field.Name); + if (property is null) + { + throw new InvalidOperationException($"Property {field.Name} not found in type {typeof(T).Name}."); + } + + // extract values and build the array + var values = items.Select(item => property.GetValue(item, null)).ToArray(); + var array = BuildArrowArray(field.DataType, values); + arrays.Add(array); + } + return new RecordBatch(schema, arrays, items.Length); + } + private static Schema BuildSchema(Type type) + { + var builder = new Schema.Builder(); + + foreach (var property in type.GetProperties()) + { + var fieldType = InferArrowType(property.PropertyType); + builder.Field(f => f.Name(property.Name).DataType(fieldType).Nullable(true)); + } + + return builder.Build(); + } + + private static IArrowType InferArrowType(Type type) + { + return type switch + { + { } t when t == typeof(string) => StringType.Default, + { } t when t == typeof(int) => Int32Type.Default, + { } t when t == typeof(float) => FloatType.Default, + { } t when t == typeof(bool) => BooleanType.Default, + { } t when t == typeof(long) => Int64Type.Default, + _ => throw new NotSupportedException($"Unsupported type: {type}") + }; + } + + private static IArrowArray BuildArrowArray(IArrowType dataType, object[] values, MemoryAllocator allocator = default) + { + allocator ??= MemoryAllocator.Default.Value; + + return dataType switch + { + StringType => BuildStringArray(values), + Int32Type => BuildArray(values, allocator), + FloatType => BuildArray(values, allocator), + BooleanType => BuildArray(values, allocator), + Int64Type => BuildArray(values, allocator), + _ => throw new NotSupportedException($"Unsupported Arrow type: {dataType}") + }; + } + + private static IArrowArray BuildStringArray(object[] values) + { + var builder = new StringArray.Builder(); + + foreach (var value in values) + { + if (value is null) + { + builder.AppendNull(); + } + else + { + builder.Append(value.ToString()); + } + } + + return builder.Build(); + } + + private static IArrowArray BuildArray(object[] values, MemoryAllocator allocator) + where TArray : IArrowArray + where TBuilder : IArrowArrayBuilder, new() + { + var builder = new TBuilder(); + + foreach (var value in values) + { + if (value == null) + { + builder.AppendNull(); + } + else + { + builder.Append((T)value); + } + } + + return builder.Build(allocator); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs new file mode 100644 index 00000000000..e32cc198aa3 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Linq; +using Apache.Arrow.Flight.Tests; +using Apache.Arrow.Flight.TestWeb; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class FlightSqlTestUtils +{ + private readonly TestFlightSqlWebFactory _testWebFactory; + private readonly FlightStore _flightStore; + + public FlightSqlTestUtils(TestFlightSqlWebFactory testWebFactory, FlightStore flightStore) + { + _testWebFactory = testWebFactory; + _flightStore = flightStore; + } + + public RecordBatch CreateTestBatch(int startValue, int length) + { + var batchBuilder = new RecordBatch.Builder(); + Int32Array.Builder builder = new(); + for (int i = 0; i < length; i++) + { + builder.Append(startValue + i); + } + + batchBuilder.Append("test", true, builder.Build()); + return batchBuilder.Build(); + } + + public FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, + params RecordBatchWithMetadata[] batches) + { + var initialBatch = batches.FirstOrDefault(); + + var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, + _testWebFactory.GetAddress()); + + foreach (var batch in batches) + { + flightHolder.AddBatch(batch); + } + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + return flightHolder.GetFlightInfo(); + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs new file mode 100644 index 00000000000..b99418ce996 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.TestWeb; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class StartupFlightSql +{ + // This method gets called by the runtime. Use this method to add services to the container. + // For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940 + public void ConfigureServices(IServiceCollection services) + { + services.AddGrpc() + .AddFlightServer(); + + services.AddSingleton(new FlightStore()); + } + + // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseRouting(); + + app.UseEndpoints(endpoints => + { + endpoints.MapFlightEndpoint(); + + endpoints.MapGet("/", async context => + { + await context.Response.WriteAsync("Communication with gRPC endpoints must be made through a gRPC client. To learn how to create a client, visit: https://go.microsoft.com/fwlink/?linkid=2086909"); + }); + }); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs index 8b86b57d21b..695f5f54287 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs @@ -15,7 +15,7 @@ using System; using System.Net; -using Apache.Arrow.Flight.Sql.Middleware.Middleware; +using Apache.Arrow.Flight.Sql.Middleware; using Microsoft.Extensions.Logging; namespace Apache.Arrow.Flight.Sql.Tests.Stubs; @@ -33,9 +33,9 @@ public Cookie CreateCookie(string name, string value, DateTimeOffset? expires = }; } - public ClientCookieMiddleware.ClientCookieMiddlewareFactory CreateFactory() + public ClientCookieMiddlewareFactory CreateFactory() { - return new ClientCookieMiddleware.ClientCookieMiddlewareFactory(new TestLoggerFactory()); + return new ClientCookieMiddlewareFactory(new TestLoggerFactory()); } public class TestLogger : ILogger diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs new file mode 100644 index 00000000000..594c5d884b3 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Linq; +using Apache.Arrow.Flight.TestWeb; +using Grpc.Net.Client; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class TestFlightSqlWebFactory : IDisposable +{ + readonly IHost host; + private int _port; + + public TestFlightSqlWebFactory(FlightStore flightStore) + { + host = WebHostBuilder(flightStore).Build(); + host.Start(); + var addressInfo = host.Services.GetRequiredService().Features.Get(); + if (addressInfo == null) + { + throw new Exception("No address info could be found for configured server"); + } + + var address = addressInfo.Addresses.First(); + var addressUri = new Uri(address); + _port = addressUri.Port; + AppContext.SetSwitch( + "System.Net.Http.SocketsHttpHandler.Http2UnencryptedSupport", true); + } + + private IHostBuilder WebHostBuilder(FlightStore flightStore) + { + return Host.CreateDefaultBuilder() + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder + .ConfigureKestrel(c => { c.ListenAnyIP(0, l => l.Protocols = HttpProtocols.Http2); }) + .UseStartup() + .ConfigureServices(services => { services.AddSingleton(flightStore); }); + }); + } + + public string GetAddress() + { + return $"http://127.0.0.1:{_port}"; + } + + public GrpcChannel GetChannel() + { + return GrpcChannel.ForAddress(GetAddress()); + } + + public void Stop() + { + host.StopAsync().Wait(); + } + + public void Dispose() + { + Stop(); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj index 1c61c434d7c..41e723113b9 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj @@ -11,6 +11,7 @@ + diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs index c6f7e66c6c2..b79edc4ae54 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs @@ -55,10 +55,25 @@ public FlightInfo GetFlightInfo() int batchBytes = _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b=>b.Length))); return new FlightInfo(_schema, _flightDescriptor, new List() { - new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault()), new List(){ + new FlightEndpoint(new FlightTicket(GetTicket(_flightDescriptor)), new List(){ new FlightLocation(_location) }) }, batchArrayLength, batchBytes); } + + private string GetTicket(FlightDescriptor descriptor) + { + if (descriptor.Paths.FirstOrDefault() != null) + { + return descriptor.Paths.FirstOrDefault(); + } + + if (descriptor.Command.Length > 0) + { + return $"{descriptor.Command.ToStringUtf8()}"; + } + + return "default_custom_ticket"; + } } -} +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs new file mode 100644 index 00000000000..a7aaad4fb2d --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Server; +using Apache.Arrow.Flight.Sql; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; + +namespace Apache.Arrow.Flight.TestWeb; + +public class TestFlightSqlServer : FlightServer +{ + private readonly FlightStore _flightStore; + + public TestFlightSqlServer(FlightStore flightStore) + { + _flightStore = flightStore; + } + + public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + switch (request.Type) + { + case "test": + await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); + break; + case SqlAction.GetPrimaryKeysRequest: + await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); + break; + case SqlAction.CancelFlightInfoRequest: + var cancelRequest = new FlightInfoCancelResult(); + cancelRequest.SetStatus(1); + await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())) + .ConfigureAwait(false); + break; + case SqlAction.BeginTransactionRequest: + case SqlAction.CommitRequest: + case SqlAction.RollbackRequest: + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))) + .ConfigureAwait(false); + break; + case SqlAction.CreateRequest: + case SqlAction.CloseRequest: + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + var datasetSchemaBytes = SchemaExtensions.SerializeSchema(schema); + var parameterSchemaBytes = SchemaExtensions.SerializeSchema(schema); + + var preparedStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement"), + DatasetSchema = ByteString.CopyFrom(datasetSchemaBytes), + ParameterSchema = ByteString.CopyFrom(parameterSchemaBytes) + }; + byte[] packedResult = Any.Pack(preparedStatementResponse).Serialize().ToByteArray(); + var flightResult = new FlightResult(packedResult); + await responseStream.WriteAsync(flightResult).ConfigureAwait(false); + break; + default: + throw new NotImplementedException(); + } + } + + public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, + ServerCallContext context) + { + FlightDescriptor flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); + + if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + var batches = flightHolder.GetRecordBatches(); + + foreach (var batch in batches) + { + await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); + } + } + } + + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) + { + var flightDescriptor = await requestStream.FlightDescriptor; + + if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + } + + int affectedRows = 0; + while (await requestStream.MoveNext()) + { + // Process the record batch (if needed) here + // Increment the affected row count for demonstration purposes + affectedRows += requestStream.Current.Column(0).Length; // Example of counting rows in the first column + } + + // Create a DoPutUpdateResult with the affected row count + var updateResult = new DoPutUpdateResult + { + RecordCount = affectedRows // Set the actual affected row count + }; + + // Serialize the DoPutUpdateResult into a ByteString + var metadata = updateResult.ToByteString(); + + // Send the metadata back as part of the FlightPutResult + var flightPutResult = new FlightPutResult(metadata); + await responseStream.WriteAsync(flightPutResult); + } + + public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo()); + } + + if (_flightStore.Flights.Count > 0) + { + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + public override Task GetSchema(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo().Schema); + } + + if (_flightStore.Flights.Count > 0) + { + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } +} From a14a845205e5e86ecbc3aa69e9323d4577ddfd2a Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 27 Apr 2025 17:40:12 +0300 Subject: [PATCH 03/10] feat: DoPut --- .../Apache.Arrow.Flight.Sql/DoPutResult.cs | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs new file mode 100644 index 00000000000..646ed38647d --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +public class DoPutResult +{ + public FlightClientRecordBatchStreamWriter Writer { get; } + public IAsyncStreamReader Reader { get; } + + public DoPutResult(FlightClientRecordBatchStreamWriter writer, IAsyncStreamReader reader) + { + Writer = writer; + Reader = reader; + } + + /// + /// Reads the metadata asynchronously from the reader. + /// + /// A ByteString containing the metadata read from the reader. + public async Task ReadMetadataAsync() + { + if (await Reader.MoveNext().ConfigureAwait(false)) + { + return Reader.Current.ApplicationMetadata; + } + throw new RpcException(new Status(StatusCode.Internal, "No metadata available in the response stream.")); + } + + /// + /// Completes the writer by signaling the end of the writing process. + /// + /// A Task representing the completion of the writer. + public async Task CompleteAsync() + { + await Writer.CompleteAsync().ConfigureAwait(false); + } +} From 81ad7268ae83ce146eb55152b12a77336dc3799c Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 28 Apr 2025 18:57:22 +0300 Subject: [PATCH 04/10] feat: Middleware implementation --- .../Middleware/ClientCookieMiddleware.cs | 59 +-------- .../ClientCookieMiddlewareFactory.cs | 71 +++++++++++ .../Middleware/Extensions/CookieExtensions.cs | 36 ++++-- .../Extensions}/FlightMethodParser.cs | 9 +- .../Middleware/Extensions}/StatusUtils.cs | 4 +- .../Interceptors/ClientInterceptorAdapter.cs | 12 +- .../Middleware/Interfaces/ICallHeaders.cs | 8 +- .../Interfaces/IFlightClientMiddleware.cs | 5 +- .../Middleware}/MetadataAdapter.cs | 14 +-- .../Middleware/Models/CallInfo.cs | 2 +- .../Middleware/Models/CallStatus.cs | 10 +- .../Middleware/Models/FlightMethod.cs | 2 +- .../Middleware/Models/FlightStatusCode.cs | 2 +- .../MiddlewareTests}/CallHeadersTests.cs | 3 +- .../ClientCookieMiddlewareTests.cs | 6 +- .../ClientInterceptorAdapterTests.cs | 6 +- .../MiddlewareTests/CookieExtensionsTests.cs | 112 ++++++++++++++++++ .../Stubs/CapturingMiddleware.cs | 6 +- .../Stubs/CapturingMiddlewareFactory.cs | 6 +- .../Stubs/ClientCookieMiddlewareMock.cs | 8 +- .../Stubs/InMemoryCallHeaders.cs | 4 +- .../Stubs/InMemoryFlightStore.cs | 0 22 files changed, 267 insertions(+), 118 deletions(-) rename csharp/src/{Apache.Arrow.Flight.Sql/Middleware => Apache.Arrow.Flight}/Middleware/ClientCookieMiddleware.cs (55%) create mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Extensions/CookieExtensions.cs (52%) rename csharp/src/{Apache.Arrow.Flight.Sql/Middleware/Grpc => Apache.Arrow.Flight/Middleware/Extensions}/FlightMethodParser.cs (90%) rename csharp/src/{Apache.Arrow.Flight.Sql/Middleware/Grpc => Apache.Arrow.Flight/Middleware/Extensions}/StatusUtils.cs (96%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Interceptors/ClientInterceptorAdapter.cs (93%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Interfaces/ICallHeaders.cs (87%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Interfaces/IFlightClientMiddleware.cs (87%) rename csharp/src/{Apache.Arrow.Flight.Sql/Middleware/Grpc => Apache.Arrow.Flight/Middleware}/MetadataAdapter.cs (90%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Models/CallInfo.cs (95%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Models/CallStatus.cs (78%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Models/FlightMethod.cs (94%) rename csharp/src/{Apache.Arrow.Flight.Sql => Apache.Arrow.Flight}/Middleware/Models/FlightStatusCode.cs (95%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/CallHeadersTests.cs (97%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/ClientCookieMiddlewareTests.cs (97%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/ClientInterceptorAdapterTests.cs (95%) create mode 100644 csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/Stubs/CapturingMiddleware.cs (92%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/Stubs/CapturingMiddlewareFactory.cs (86%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/Stubs/ClientCookieMiddlewareMock.cs (87%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/Stubs/InMemoryCallHeaders.cs (95%) rename csharp/test/{Apache.Arrow.Flight.Sql.Tests => Apache.Arrow.Flight.Tests/MiddlewareTests}/Stubs/InMemoryFlightStore.cs (100%) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs similarity index 55% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs index d56fa61d85b..cd34c9eac4b 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Middleware/ClientCookieMiddleware.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs @@ -13,17 +13,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -using System; -using System.Collections.Concurrent; using System.Collections.Generic; -using System.Globalization; -using System.Net; -using Apache.Arrow.Flight.Sql.Middleware.Extensions; -using Apache.Arrow.Flight.Sql.Middleware.Interfaces; -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Apache.Arrow.Flight.Middleware.Models; using Microsoft.Extensions.Logging; -namespace Apache.Arrow.Flight.Sql.Middleware.Middleware; +namespace Apache.Arrow.Flight.Middleware; public class ClientCookieMiddleware : IFlightClientMiddleware { @@ -32,8 +27,6 @@ public class ClientCookieMiddleware : IFlightClientMiddleware private const string SET_COOKIE_HEADER = "Set-cookie"; private const string COOKIE_HEADER = "Cookie"; - private readonly ConcurrentDictionary _cookies = new(); - public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, ILogger logger) { @@ -78,52 +71,6 @@ private string GetValidCookiesAsString() cookieList.Add(entry.Value.ToString()); } } - return string.Join("; ", cookieList); } - - public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory - { - public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); - private readonly ILoggerFactory _loggerFactory; - - public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) - { - _loggerFactory = loggerFactory; - } - - public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) - { - var logger = _loggerFactory.CreateLogger(); - return new ClientCookieMiddleware(this, logger); - } - - internal void UpdateCookies(IEnumerable newCookieHeaderValues) - { - foreach (var headerValue in newCookieHeaderValues) - { - try - { - var parsedCookies = headerValue.ParseHeader(); - foreach (var parsedCookie in parsedCookies) - { - var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); - if (parsedCookie.Expired) - { - Cookies.TryRemove(nameLc, out _); - } - else - { - Cookies[nameLc] = parsedCookie; - } - } - } - catch (FormatException ex) - { - var logger = _loggerFactory.CreateLogger(); - logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); - } - } - } - } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs new file mode 100644 index 00000000000..6ed8012071e --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Flight.Middleware.Extensions; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Microsoft.Extensions.Logging; +using CallInfo = Apache.Arrow.Flight.Middleware.Models.CallInfo; + +namespace Apache.Arrow.Flight.Middleware; + +public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); + private readonly ILoggerFactory _loggerFactory; + + public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) + { + var logger = _loggerFactory.CreateLogger(); + return new ClientCookieMiddleware(this, logger); + } + + internal void UpdateCookies(IEnumerable newCookieHeaderValues) + { + foreach (var headerValue in newCookieHeaderValues) + { + try + { + var parsedCookies = headerValue.ParseHeader(); + foreach (var parsedCookie in parsedCookies) + { + var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); + if (parsedCookie.Expired) + { + Cookies.TryRemove(nameLc, out _); + } + else + { + Cookies[nameLc] = parsedCookie; + } + } + } + catch (FormatException ex) + { + var logger = _loggerFactory.CreateLogger(); + logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); + } + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs similarity index 52% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs index aba6d0f1f71..255006a8b48 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Extensions/CookieExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs @@ -18,28 +18,48 @@ using System.Linq; using System.Net; -namespace Apache.Arrow.Flight.Sql.Middleware.Extensions; +namespace Apache.Arrow.Flight.Middleware.Extensions; -// TODO: Add tests to cover: CookieExtensions -internal static class CookieExtensions +public static class CookieExtensions { public static IEnumerable ParseHeader(this string headers) { var cookies = new List(); - var segments = headers.Split(';', StringSplitOptions.RemoveEmptyEntries); + if (string.IsNullOrEmpty(headers)) + return cookies; + + var segments = headers.Split([';'], StringSplitOptions.RemoveEmptyEntries); if (segments.Length == 0) return cookies; - var nameValue = segments[0].Split('=', 2); + var nameValue = segments[0].Split(['='], 2); if (nameValue.Length == 2) { - var cookie = new Cookie(nameValue[0], nameValue[1]); + var cookie = new Cookie(nameValue[0].Trim(), nameValue[1].Trim()); + foreach (var segment in segments.Skip(1)) { - if (segment.StartsWith("Expires=", StringComparison.OrdinalIgnoreCase)) + var trimmedSegment = segment.Trim(); + if (trimmedSegment.StartsWith("Expires=", StringComparison.OrdinalIgnoreCase)) { - if (DateTimeOffset.TryParse(segment["Expires=".Length..], out var expires)) + var value = trimmedSegment.Substring("Expires=".Length).Trim(); + + if (!DateTimeOffset.TryParseExact( + value, + "R", + System.Globalization.CultureInfo.InvariantCulture, + System.Globalization.DateTimeStyles.None, + out var expires)) + { + if (DateTimeOffset.TryParse(value, out expires)) + { + cookie.Expires = expires.UtcDateTime; + } + } + else + { cookie.Expires = expires.UtcDateTime; + } } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs similarity index 90% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs index a73921b182f..a8cbbe56896 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/FlightMethodParser.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs @@ -13,11 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Models; -namespace Apache.Arrow.Flight.Sql.Middleware.Gprc; +namespace Apache.Arrow.Flight.Middleware.Extensions; -// TODO: Add tests to cover: FlightMethodParser public static class FlightMethodParser { /// @@ -35,7 +34,7 @@ public static FlightMethod Parse(string fullMethodName) if (parts.Length < 2) return FlightMethod.Unknown; - var methodName = parts[^1]; + var methodName = parts[parts.Length - 1]; return methodName switch { @@ -59,6 +58,6 @@ public static string ParseMethodName(string fullMethodName) return "Unknown"; var parts = fullMethodName.Split('/'); - return parts.Length >= 2 ? parts[^1] : "Unknown"; + return parts.Length >= 2 ? parts[parts.Length - 1] : "Unknown"; } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs similarity index 96% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs index 9485167f76a..d1bde27dfac 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/StatusUtils.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Models; using Grpc.Core; -namespace Apache.Arrow.Flight.Sql.Middleware.Gprc; +namespace Apache.Arrow.Flight.Middleware.Extensions; public static class StatusUtils { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs similarity index 93% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs index 576f3f72add..caaaf3f1dca 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interceptors/ClientInterceptorAdapter.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs @@ -16,14 +16,13 @@ using System; using System.Collections.Generic; using System.Linq; -using Apache.Arrow.Flight.Sql.Middleware.Gprc; -using Apache.Arrow.Flight.Sql.Middleware.Grpc; -using Apache.Arrow.Flight.Sql.Middleware.Interfaces; -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Extensions; +using Apache.Arrow.Flight.Middleware.Interfaces; using Grpc.Core; using Grpc.Core.Interceptors; +using CallInfo = Apache.Arrow.Flight.Middleware.Models.CallInfo; -namespace Apache.Arrow.Flight.Sql.Middleware.Interceptors +namespace Apache.Arrow.Flight.Middleware.Interceptors { public class ClientInterceptorAdapter : Interceptor { @@ -36,8 +35,7 @@ public ClientInterceptorAdapter(IEnumerable fact public override AsyncUnaryCall AsyncUnaryCall( TRequest request, - ClientInterceptorContext context, - AsyncUnaryCallContinuation continuation) + ClientInterceptorContext context, AsyncUnaryCallContinuation continuation) where TRequest : class where TResponse : class { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs similarity index 87% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs index 1290f185e8d..03f02bd2b6b 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/ICallHeaders.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs @@ -15,14 +15,14 @@ using System.Collections.Generic; -namespace Apache.Arrow.Flight.Sql.Middleware.Interfaces; +namespace Apache.Arrow.Flight.Middleware.Interfaces; public interface ICallHeaders { - string? this[string key] { get; } + string this[string key] { get; } - string? Get(string key); - byte[]? GetBytes(string key); + string Get(string key); + byte[] GetBytes(string key); IEnumerable GetAll(string key); IEnumerable GetAllBytes(string key); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs similarity index 87% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs index 75caf9d56d0..4bfdba31fdc 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Interfaces/IFlightClientMiddleware.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs @@ -13,9 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Models; +using CallInfo = Apache.Arrow.Flight.Middleware.Models.CallInfo; -namespace Apache.Arrow.Flight.Sql.Middleware.Interfaces; +namespace Apache.Arrow.Flight.Middleware.Interfaces; public interface IFlightClientMiddleware { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs b/csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs similarity index 90% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs index d81ef73ba99..c000059ef09 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Grpc/MetadataAdapter.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs @@ -16,10 +16,10 @@ using System; using System.Collections.Generic; using System.Linq; -using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Middleware.Interfaces; using Grpc.Core; -namespace Apache.Arrow.Flight.Sql.Middleware.Grpc; +namespace Apache.Arrow.Flight.Middleware; public class MetadataAdapter : ICallHeaders { @@ -30,15 +30,15 @@ public MetadataAdapter(Metadata metadata) _metadata = metadata ?? throw new ArgumentNullException(nameof(metadata)); } - public string? this[string key] => Get(key); + public string this[string key] => Get(key); - public string? Get(string key) + public string Get(string key) { return _metadata.FirstOrDefault(e => !e.IsBinary && e.Key.Equals(key, StringComparison.OrdinalIgnoreCase))?.Value; } - public byte[]? GetBytes(string key) + public byte[] GetBytes(string key) { return _metadata.FirstOrDefault(e => e.IsBinary && e.Key.Equals(NormalizeBinaryKey(key), StringComparison.OrdinalIgnoreCase))?.ValueBytes; @@ -88,13 +88,13 @@ private static string NormalizeBinaryKey(string key) private static string DenormalizeBinaryKey(string key) => key.EndsWith(Metadata.BinaryHeaderSuffix, StringComparison.OrdinalIgnoreCase) - ? key[..^Metadata.BinaryHeaderSuffix.Length] + ? key.Substring(0, key.Length - Metadata.BinaryHeaderSuffix.Length) : key; } public static class MetadataAdapterExtensions { - public static bool TryGet(this ICallHeaders headers, string key, out string? value) + public static bool TryGet(this ICallHeaders headers, string key, out string value) { value = headers.Get(key); return value is not null; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallInfo.cs similarity index 95% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Models/CallInfo.cs index ee20109b973..d66683fe413 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallInfo.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallInfo.cs @@ -15,7 +15,7 @@ using System; -namespace Apache.Arrow.Flight.Sql.Middleware.Models; +namespace Apache.Arrow.Flight.Middleware.Models; public sealed class CallInfo { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs similarity index 78% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs index 4294b1d855e..37cd47c1a30 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/CallStatus.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs @@ -16,16 +16,16 @@ using System; using Grpc.Core; -namespace Apache.Arrow.Flight.Sql.Middleware.Models; +namespace Apache.Arrow.Flight.Middleware.Models; public sealed class CallStatus { public FlightStatusCode Code { get; } - public Exception? Cause { get; } - public string? Description { get; } - public Metadata? Trailers { get; } + public Exception Cause { get; } + public string Description { get; } + public Metadata Trailers { get; } - public CallStatus(FlightStatusCode code, Exception? cause, string? description, Metadata? trailers) + public CallStatus(FlightStatusCode code, Exception cause, string description, Metadata trailers) { Code = code; Cause = cause; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs similarity index 94% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs index c53a1e668a2..c42de699e4c 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightMethod.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -namespace Apache.Arrow.Flight.Sql.Middleware.Models; +namespace Apache.Arrow.Flight.Middleware.Models; public enum FlightMethod { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs similarity index 95% rename from csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs index 65221c1e192..60f28a9062f 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Middleware/Models/FlightStatusCode.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -namespace Apache.Arrow.Flight.Sql.Middleware.Models; +namespace Apache.Arrow.Flight.Middleware.Models; public enum FlightStatusCode { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs similarity index 97% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs index 00210557e20..79cf7fbc099 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/CallHeadersTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs @@ -14,9 +14,10 @@ // limitations under the License. using System.Linq; +using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; using Xunit; -namespace Apache.Arrow.Flight.Sql.Tests; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; public class CallHeadersTests { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs similarity index 97% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs index 7e04e153349..b1d89a36073 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientCookieMiddlewareTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs @@ -17,11 +17,11 @@ using System.Linq; using System.Net; using System.Threading.Tasks; -using Apache.Arrow.Flight.Sql.Middleware.Middleware; -using Apache.Arrow.Flight.Sql.Tests.Stubs; +using Apache.Arrow.Flight.Middleware; +using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; using Xunit; -namespace Apache.Arrow.Flight.Sql.Tests; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; public class ClientCookieMiddlewareTests { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs similarity index 95% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs index cb653a5c6ff..fc1706408e4 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/ClientInterceptorAdapterTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs @@ -17,14 +17,14 @@ using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; -using Apache.Arrow.Flight.Sql.Middleware.Interceptors; +using Apache.Arrow.Flight.Middleware.Interceptors; using Apache.Arrow.Flight.Sql.Tests.Stubs; -using Apache.Arrow.Flight.Tests; +using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; using Grpc.Core; using Grpc.Core.Interceptors; using Xunit; -namespace Apache.Arrow.Flight.Sql.Tests; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; public class ClientInterceptorAdapterTests { diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs new file mode 100644 index 00000000000..2f8b21a4833 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs @@ -0,0 +1,112 @@ +using System; +using Apache.Arrow.Flight.Middleware.Extensions; +using Xunit; +using static System.Linq.Enumerable; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; + +public class CookieExtensionsTests +{ + [Fact] + public void ParseHeaderShouldParseSimpleCookie() + { + // Arrange + var header = "sessionId=abc123"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("sessionId", cookies[0].Name); + Assert.Equal("abc123", cookies[0].Value); + Assert.False(cookies[0].Expired); + } + + [Fact] + public void ParseHeaderShouldParseCookieWithExpires() + { + // Arrange + var futureDate = DateTimeOffset.UtcNow.AddDays(7); + var header = $"userId=789; Expires={futureDate:R}"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("userId", cookies[0].Name); + Assert.Equal("789", cookies[0].Value); + Assert.True(Math.Abs((cookies[0].Expires - futureDate.UtcDateTime).TotalSeconds) < 5); + } + + [Fact] + public void ParseHeaderShouldReturnEmptyWhenMalformed() + { + // Arrange + var header = "this_is_wrong"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Empty(cookies); + } + + [Fact] + public void ParseHeaderShouldReturnEmptyWhenEmptyString() + { + // Arrange + var header = string.Empty; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Empty(cookies); + } + + [Fact] + public void ParseHeaderShouldReturnEmptyWhenNullString() + { + // Arrange + string header = null; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Empty(cookies); + } + + [Fact] + public void ParseHeaderShouldParseCookieIgnoringAttributes() + { + // Arrange + var header = "token=xyz; Path=/; HttpOnly"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("token", cookies[0].Name); + Assert.Equal("xyz", cookies[0].Value); + } + + [Fact] + public void ParseHeaderShouldIgnoreInvalidExpires() + { + // Arrange + var header = "name=value; Expires=invalid-date"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("name", cookies[0].Name); + Assert.Equal("value", cookies[0].Value); + Assert.False(cookies[0].Expired); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs similarity index 92% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs index bdd3d9e6c61..29e6348154c 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddleware.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs @@ -14,10 +14,10 @@ // limitations under the License. using System.Collections.Generic; -using Apache.Arrow.Flight.Sql.Middleware.Interfaces; -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Apache.Arrow.Flight.Middleware.Models; -namespace Apache.Arrow.Flight.Sql.Tests.Stubs; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; public class CapturingMiddleware : IFlightClientMiddleware { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs similarity index 86% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs index a3aa652b81f..72b638157e6 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/CapturingMiddlewareFactory.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs @@ -13,10 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -using Apache.Arrow.Flight.Sql.Middleware.Interfaces; -using Apache.Arrow.Flight.Sql.Middleware.Models; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Apache.Arrow.Flight.Middleware.Models; -namespace Apache.Arrow.Flight.Sql.Tests.Stubs; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; public class CapturingMiddlewareFactory : IFlightClientMiddlewareFactory { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs similarity index 87% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs index 8b86b57d21b..c3f0fe77c67 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/ClientCookieMiddlewareMock.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs @@ -15,10 +15,10 @@ using System; using System.Net; -using Apache.Arrow.Flight.Sql.Middleware.Middleware; +using Apache.Arrow.Flight.Middleware; using Microsoft.Extensions.Logging; -namespace Apache.Arrow.Flight.Sql.Tests.Stubs; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; internal class ClientCookieMiddlewareMock { @@ -33,9 +33,9 @@ public Cookie CreateCookie(string name, string value, DateTimeOffset? expires = }; } - public ClientCookieMiddleware.ClientCookieMiddlewareFactory CreateFactory() + public ClientCookieMiddlewareFactory CreateFactory() { - return new ClientCookieMiddleware.ClientCookieMiddlewareFactory(new TestLoggerFactory()); + return new ClientCookieMiddlewareFactory(new TestLoggerFactory()); } public class TestLogger : ILogger diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs similarity index 95% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs index 2077cb08263..e7a1698f537 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryCallHeaders.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs @@ -15,9 +15,9 @@ using System.Collections.Generic; using System.Linq; -using Apache.Arrow.Flight.Sql.Middleware.Interfaces; +using Apache.Arrow.Flight.Middleware.Interfaces; -namespace Apache.Arrow.Flight.Sql.Tests; +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; public class InMemoryCallHeaders : ICallHeaders { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryFlightStore.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs similarity index 100% rename from csharp/test/Apache.Arrow.Flight.Sql.Tests/Stubs/InMemoryFlightStore.cs rename to csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs From da5a7e6938b17846eb3815b19390884c155f00a8 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 4 May 2025 14:42:29 +0300 Subject: [PATCH 05/10] fix: remove redundant memory allocation (minor, but for a lot of calls it will add up and force more GC calls) --- .../Middleware/Extensions/CookieExtensions.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs index 255006a8b48..2c2c07678a3 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs @@ -24,10 +24,10 @@ public static class CookieExtensions { public static IEnumerable ParseHeader(this string headers) { - var cookies = new List(); if (string.IsNullOrEmpty(headers)) - return cookies; + return []; + var cookies = new List(); var segments = headers.Split([';'], StringSplitOptions.RemoveEmptyEntries); if (segments.Length == 0) return cookies; From 6f41d638499b1be921fd83fd2c8bf731974fd890 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 4 May 2025 14:43:29 +0300 Subject: [PATCH 06/10] chore: explicitly return Array.Empty --- .../Middleware/Extensions/CookieExtensions.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs index 2c2c07678a3..e34db39512a 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs @@ -25,7 +25,7 @@ public static class CookieExtensions public static IEnumerable ParseHeader(this string headers) { if (string.IsNullOrEmpty(headers)) - return []; + return System.Array.Empty(); var cookies = new List(); var segments = headers.Split([';'], StringSplitOptions.RemoveEmptyEntries); From 3983768f7efa9d1a534de2726a4b0de710aacc0c Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 14 May 2025 09:32:08 +0300 Subject: [PATCH 07/10] fix(Middleware): re-implemented case in-sensitive and response handler for cookie middleware - Remove unused and redundant files --- .../Middleware/CallHeaders.cs | 80 ++++++++ .../Middleware/{Models => }/CallInfo.cs | 21 ++- .../Middleware/ClientCookieMiddleware.cs | 20 +- .../ClientCookieMiddlewareFactory.cs | 11 +- .../Middleware/Extensions/CookieExtensions.cs | 95 +++++++--- .../Extensions/FlightMethodParser.cs | 63 ------- .../Middleware/Extensions/StatusUtils.cs | 58 ------ .../Interceptors/ClientInterceptorAdapter.cs | 178 +++++++++++------- .../Interceptors/MiddlewareResponseStream.cs | 70 +++++++ .../Middleware/Interfaces/ICallHeaders.cs | 4 +- .../Interfaces/IFlightClientMiddleware.cs | 10 +- .../IFlightClientMiddlewareFactory.cs | 6 + .../Middleware/Models/CallStatus.cs | 35 ---- .../Middleware/Models/FlightMethod.cs | 31 --- .../Middleware/Models/FlightStatusCode.cs | 37 ---- .../MiddlewareTests/CookieExtensionsTests.cs | 15 ++ .../Stubs/CapturingMiddleware.cs | 8 +- .../Stubs/CapturingMiddlewareFactory.cs | 6 +- .../Stubs/InMemoryCallHeaders.cs | 68 ++++--- 19 files changed, 427 insertions(+), 389 deletions(-) create mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs rename csharp/src/Apache.Arrow.Flight/Middleware/{Models => }/CallInfo.cs (67%) delete mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs delete mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs create mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs create mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs delete mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs delete mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs delete mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs b/csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs new file mode 100644 index 00000000000..f2d8f0639dc --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware; + +public class CallHeaders : ICallHeaders, IEnumerable> +{ + private readonly Metadata _metadata; + + public CallHeaders(Metadata metadata) + { + _metadata = metadata; + } + + public void Add(string key, string value) => _metadata.Add(key, value); + + public bool ContainsKey(string key) => _metadata.Any(h => KeyEquals(h.Key, key)); + + public IEnumerator> GetEnumerator() + { + foreach (var entry in _metadata) + yield return new KeyValuePair(entry.Key, entry.Value); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public string this[string key] + { + get + { + var entry = _metadata.FirstOrDefault(h => KeyEquals(h.Key, key)); + return entry?.Value; + } + set + { + var entry = _metadata.FirstOrDefault(h => KeyEquals(h.Key, key)); + if (entry != null) _metadata.Remove(entry); + _metadata.Add(key, value); + } + } + + public string Get(string key) => this[key]; + + public byte[] GetBytes(string key) => + _metadata.FirstOrDefault(h => KeyEquals(h.Key, key))?.ValueBytes; + + public IEnumerable GetAll(string key) => + _metadata.Where(h => KeyEquals(h.Key, key)).Select(h => h.Value); + + public IEnumerable GetAllBytes(string key) => + _metadata.Where(h => KeyEquals(h.Key, key)).Select(h => h.ValueBytes); + + public void Insert(string key, string value) => Add(key, value); + + public void Insert(string key, byte[] value) => _metadata.Add(key, value); + + public ISet Keys => new HashSet(_metadata.Select(h => h.Key)); + + private static bool KeyEquals(string a, string b) => + string.Equals(a, b, StringComparison.OrdinalIgnoreCase); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallInfo.cs b/csharp/src/Apache.Arrow.Flight/Middleware/CallInfo.cs similarity index 67% rename from csharp/src/Apache.Arrow.Flight/Middleware/Models/CallInfo.cs rename to csharp/src/Apache.Arrow.Flight/Middleware/CallInfo.cs index d66683fe413..5b41d72b276 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallInfo.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/CallInfo.cs @@ -13,18 +13,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -using System; +using Grpc.Core; -namespace Apache.Arrow.Flight.Middleware.Models; +namespace Apache.Arrow.Flight.Middleware; -public sealed class CallInfo +public readonly struct CallInfo { - public string Endpoint { get; } - public string MethodName { get; } + public string Method { get; } + public MethodType MethodType { get; } - public CallInfo(string endpoint, string methodName) + public CallInfo(string method, MethodType methodType) { - Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); - MethodName = methodName ?? throw new ArgumentNullException(nameof(methodName)); + Method = method; + MethodType = methodType; + } + + public override string ToString() + { + return $"{MethodType}: {Method}"; } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs index cd34c9eac4b..85292d42161 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs @@ -15,7 +15,7 @@ using System.Collections.Generic; using Apache.Arrow.Flight.Middleware.Interfaces; -using Apache.Arrow.Flight.Middleware.Models; +using Grpc.Core; using Microsoft.Extensions.Logging; namespace Apache.Arrow.Flight.Middleware; @@ -24,7 +24,7 @@ public class ClientCookieMiddleware : IFlightClientMiddleware { private readonly ClientCookieMiddlewareFactory _factory; private readonly ILogger _logger; - private const string SET_COOKIE_HEADER = "Set-cookie"; + private const string SET_COOKIE_HEADER = "Set-Cookie"; private const string COOKIE_HEADER = "Cookie"; public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, @@ -36,25 +36,26 @@ public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) { + if (_factory.Cookies.IsEmpty) + return; var cookieValue = GetValidCookiesAsString(); if (!string.IsNullOrEmpty(cookieValue)) { outgoingHeaders.Insert(COOKIE_HEADER, cookieValue); } - - _logger.LogInformation("Sending Headers: " + string.Join(", ", outgoingHeaders.Keys)); + _logger.LogInformation("Sending Headers: " + string.Join(", ", outgoingHeaders)); } public void OnHeadersReceived(ICallHeaders incomingHeaders) { - var setCookieHeaders = incomingHeaders.GetAll(SET_COOKIE_HEADER); - _factory.UpdateCookies(setCookieHeaders); - _logger.LogInformation("Received Headers: " + string.Join(", ", incomingHeaders.Keys)); + var setCookies = incomingHeaders.GetAll(SET_COOKIE_HEADER); + _factory.UpdateCookies(setCookies); + _logger.LogInformation("Received Headers: " + string.Join(", ", incomingHeaders)); } - public void OnCallCompleted(CallStatus status) + public void OnCallCompleted(Status status, Metadata trailers) { - _logger.LogInformation($"Call completed with: {status.Code} ({status.Description})"); + _logger.LogInformation($"Call completed with: {status.StatusCode} ({status.Detail})"); } private string GetValidCookiesAsString() @@ -62,6 +63,7 @@ private string GetValidCookiesAsString() var cookieList = new List(); foreach (var entry in _factory.Cookies) { + _logger.LogInformation($"Before remove cookie: {entry.Key} Expired: ({entry.Value.Expired})"); if (entry.Value.Expired) { _factory.Cookies.TryRemove(entry.Key, out _); diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs index 6ed8012071e..db6f0f66122 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs @@ -21,8 +21,6 @@ using Apache.Arrow.Flight.Middleware.Extensions; using Apache.Arrow.Flight.Middleware.Interfaces; using Microsoft.Extensions.Logging; -using CallInfo = Apache.Arrow.Flight.Middleware.Models.CallInfo; - namespace Apache.Arrow.Flight.Middleware; public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory @@ -43,15 +41,15 @@ public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) internal void UpdateCookies(IEnumerable newCookieHeaderValues) { + var logger = _loggerFactory.CreateLogger(); foreach (var headerValue in newCookieHeaderValues) { try { - var parsedCookies = headerValue.ParseHeader(); - foreach (var parsedCookie in parsedCookies) + foreach (var parsedCookie in headerValue.ParseHeader()) { var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); - if (parsedCookie.Expired) + if (parsedCookie.IsExpired(headerValue)) { Cookies.TryRemove(nameLc, out _); } @@ -63,9 +61,10 @@ internal void UpdateCookies(IEnumerable newCookieHeaderValues) } catch (FormatException ex) { - var logger = _loggerFactory.CreateLogger(); + logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); } } } + } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs index e34db39512a..5ad0af54384 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs @@ -15,6 +15,7 @@ using System; using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Net; @@ -22,50 +23,82 @@ namespace Apache.Arrow.Flight.Middleware.Extensions; public static class CookieExtensions { - public static IEnumerable ParseHeader(this string headers) + public static IEnumerable ParseHeader(this string setCookieHeader) { - if (string.IsNullOrEmpty(headers)) + if (string.IsNullOrWhiteSpace(setCookieHeader)) return System.Array.Empty(); - + var cookies = new List(); - var segments = headers.Split([';'], StringSplitOptions.RemoveEmptyEntries); - if (segments.Length == 0) return cookies; + var segments = setCookieHeader.Split([';'], StringSplitOptions.RemoveEmptyEntries); + if (segments.Length == 0) + return cookies; var nameValue = segments[0].Split(['='], 2); - if (nameValue.Length == 2) + if (nameValue.Length != 2 || string.IsNullOrWhiteSpace(nameValue[0])) + return cookies; + + var name = nameValue[0].Trim(); + var value = nameValue[1].Trim(); + var cookie = new Cookie(name, value); + + foreach (var segment in segments.Skip(1)) { - var cookie = new Cookie(nameValue[0].Trim(), nameValue[1].Trim()); + var kv = segment.Split(['='], 2, StringSplitOptions.RemoveEmptyEntries); + var key = kv[0].Trim().ToLowerInvariant(); + var val = kv.Length > 1 ? kv[1] : null; - foreach (var segment in segments.Skip(1)) + switch (key) { - var trimmedSegment = segment.Trim(); - if (trimmedSegment.StartsWith("Expires=", StringComparison.OrdinalIgnoreCase)) - { - var value = trimmedSegment.Substring("Expires=".Length).Trim(); - - if (!DateTimeOffset.TryParseExact( - value, - "R", - System.Globalization.CultureInfo.InvariantCulture, - System.Globalization.DateTimeStyles.None, - out var expires)) + case "expires": + if (!string.IsNullOrWhiteSpace(val)) { - if (DateTimeOffset.TryParse(value, out expires)) - { - cookie.Expires = expires.UtcDateTime; - } + if (DateTimeOffset.TryParseExact(val, "R", CultureInfo.InvariantCulture, DateTimeStyles.None, out var expiresRfc)) + cookie.Expires = expiresRfc.UtcDateTime; + else if (DateTimeOffset.TryParse(val, out var expiresFallback)) + cookie.Expires = expiresFallback.UtcDateTime; } - else - { - cookie.Expires = expires.UtcDateTime; - } - } - } + break; + + case "max-age": + if (int.TryParse(val, out var seconds)) + cookie.Expires = DateTime.UtcNow.AddSeconds(seconds); + break; + + case "domain": + cookie.Domain = val; + break; - cookies.Add(cookie); + case "path": + cookie.Path = val; + break; + + case "secure": + cookie.Secure = true; + break; + + case "httponly": + cookie.HttpOnly = true; + break; + } } + cookies.Add(cookie); return cookies; } -} \ No newline at end of file + + public static bool IsExpired(this Cookie cookie, string rawHeader) + { + if (string.IsNullOrWhiteSpace(cookie?.Value)) + return true; + + // If raw header has Max-Age=0, consider it deleted + if (rawHeader?.IndexOf("Max-Age=0", StringComparison.OrdinalIgnoreCase) >= 0) + return true; + + if (cookie.Expires != DateTime.MinValue && cookie.Expires <= DateTime.UtcNow) + return true; + + return false; + } +} diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs deleted file mode 100644 index a8cbbe56896..00000000000 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/FlightMethodParser.cs +++ /dev/null @@ -1,63 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using Apache.Arrow.Flight.Middleware.Models; - -namespace Apache.Arrow.Flight.Middleware.Extensions; - -public static class FlightMethodParser -{ - /// - /// Parses the gRPC full method name (e.g., "/arrow.flight.protocol.FlightService/DoGet") - /// and maps it to a known FlightMethod. - /// - /// gRPC method name - /// Parsed FlightMethod - public static FlightMethod Parse(string fullMethodName) - { - if (string.IsNullOrWhiteSpace(fullMethodName)) - return FlightMethod.Unknown; - - var parts = fullMethodName.Split('/'); - if (parts.Length < 2) - return FlightMethod.Unknown; - - var methodName = parts[parts.Length - 1]; - - return methodName switch - { - "Handshake" => FlightMethod.Handshake, - "ListFlights" => FlightMethod.ListFlights, - "GetFlightInfo" => FlightMethod.GetFlightInfo, - "GetSchema" => FlightMethod.GetSchema, - "DoGet" => FlightMethod.DoGet, - "DoPut" => FlightMethod.DoPut, - "DoExchange" => FlightMethod.DoExchange, - "DoAction" => FlightMethod.DoAction, - "ListActions" => FlightMethod.ListActions, - "CancelFlightInfo" => FlightMethod.CancelFlightInfo, - _ => FlightMethod.Unknown - }; - } - - public static string ParseMethodName(string fullMethodName) - { - if (string.IsNullOrWhiteSpace(fullMethodName)) - return "Unknown"; - - var parts = fullMethodName.Split('/'); - return parts.Length >= 2 ? parts[parts.Length - 1] : "Unknown"; - } -} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs deleted file mode 100644 index d1bde27dfac..00000000000 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/StatusUtils.cs +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using Apache.Arrow.Flight.Middleware.Models; -using Grpc.Core; - -namespace Apache.Arrow.Flight.Middleware.Extensions; - -public static class StatusUtils -{ - public static CallStatus FromGrpcStatusAndTrailers(Status status, Metadata trailers) - { - var code = FromGrpcStatusCode(status.StatusCode); - return new CallStatus( - code, - status.StatusCode != StatusCode.OK ? new RpcException(status, trailers) : null, - status.Detail, - trailers - ); - } - - public static FlightStatusCode FromGrpcStatusCode(StatusCode grpcCode) - { - return grpcCode switch - { - StatusCode.OK => FlightStatusCode.Ok, - StatusCode.Cancelled => FlightStatusCode.Cancelled, - StatusCode.Unknown => FlightStatusCode.Unknown, - StatusCode.InvalidArgument => FlightStatusCode.InvalidArgument, - StatusCode.DeadlineExceeded => FlightStatusCode.DeadlineExceeded, - StatusCode.NotFound => FlightStatusCode.NotFound, - StatusCode.AlreadyExists => FlightStatusCode.AlreadyExists, - StatusCode.PermissionDenied => FlightStatusCode.PermissionDenied, - StatusCode.Unauthenticated => FlightStatusCode.Unauthenticated, - StatusCode.ResourceExhausted => FlightStatusCode.ResourceExhausted, - StatusCode.FailedPrecondition => FlightStatusCode.FailedPrecondition, - StatusCode.Aborted => FlightStatusCode.Aborted, - StatusCode.OutOfRange => FlightStatusCode.OutOfRange, - StatusCode.Unimplemented => FlightStatusCode.Unimplemented, - StatusCode.Internal => FlightStatusCode.Internal, - StatusCode.Unavailable => FlightStatusCode.Unavailable, - StatusCode.DataLoss => FlightStatusCode.DataLoss, - _ => FlightStatusCode.Unknown - }; - } -} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs index caaaf3f1dca..569ca0373e9 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs @@ -16,114 +16,146 @@ using System; using System.Collections.Generic; using System.Linq; -using Apache.Arrow.Flight.Middleware.Extensions; +using System.Threading.Tasks; using Apache.Arrow.Flight.Middleware.Interfaces; using Grpc.Core; using Grpc.Core.Interceptors; -using CallInfo = Apache.Arrow.Flight.Middleware.Models.CallInfo; namespace Apache.Arrow.Flight.Middleware.Interceptors { - public class ClientInterceptorAdapter : Interceptor + public sealed class ClientInterceptorAdapter : Interceptor { - private readonly IList _factories; + private readonly IReadOnlyList _factories; public ClientInterceptorAdapter(IEnumerable factories) { - _factories = factories.ToList(); + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); } public override AsyncUnaryCall AsyncUnaryCall( TRequest request, - ClientInterceptorContext context, AsyncUnaryCallContinuation continuation) + ClientInterceptorContext context, + AsyncUnaryCallContinuation continuation) where TRequest : class where TResponse : class { - var middleware = new List(); - var callInfo = new CallInfo( - context.Host ?? "unknown", - FlightMethodParser.ParseMethodName(context.Method.FullName)); + var options = InterceptCall(context, out var middlewares); - try - { - middleware.AddRange(_factories.Select(factory => factory.OnCallStarted(callInfo))); - } - catch (Exception e) - { - throw new RpcException(new Status(StatusCode.Internal, "Middleware creation failed"), e.Message); - } - - // Apply middleware headers - var middlewareHeaders = new Metadata(); - var headerAdapter = new MetadataAdapter(middlewareHeaders); - foreach (var m in middleware) - { - m.OnBeforeSendingHeaders(headerAdapter); - } - - // Merge original headers with middleware headers - var mergedHeaders = new Metadata(); - if (context.Options.Headers != null) - { - foreach (var entry in context.Options.Headers) - { - mergedHeaders.Add(entry); - } - } - - foreach (var entry in middlewareHeaders) - { - mergedHeaders.Add(entry); - } - - var updatedContext = new ClientInterceptorContext( + var newContext = new ClientInterceptorContext( context.Method, context.Host, - context.Options.WithHeaders(mergedHeaders) + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose ); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall( + TRequest request, + ClientInterceptorContext context, + AsyncServerStreamingCallContinuation continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext( + context.Method, context.Host, callOptions); - var headersReceived = false; - var call = continuation(request, updatedContext); + var call = continuation(request, newContext); var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => { - if (task.Exception is null) + if (task.Exception == null && task.Result != null) { - var metadataAdapter = new MetadataAdapter(task.Result); - middleware.ForEach(m => m.OnHeadersReceived(metadataAdapter)); - headersReceived = true; + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); } + return task.Result; }); - var responseTask = call.ResponseAsync.ContinueWith(response => - { - // If headers were never received, simulate with trailers - if (!headersReceived) - { - var trailersAdapter = new MetadataAdapter(call.GetTrailers()); - foreach (var m in middleware) - m.OnHeadersReceived(trailersAdapter); - } - - var status = call.GetStatus(); - var trailers = call.GetTrailers(); - var flightStatus = StatusUtils.FromGrpcStatusAndTrailers(status, trailers); - - middleware.ForEach(m => m.OnCallCompleted(flightStatus)); + var wrappedResponseStream = new MiddlewareResponseStream( + call.ResponseStream, + call, + middlewares); - if (response.IsFaulted && response.Exception != null) - throw response.Exception; - - return response.Result; - }); - - return new AsyncUnaryCall( - responseTask, + return new AsyncServerStreamingCall( + wrappedResponseStream, responseHeadersTask, call.GetStatus, call.GetTrailers, call.Dispose); } + + + private CallOptions InterceptCall( + ClientInterceptorContext context, + out List middlewareList) + where TRequest : class + where TResponse : class + { + var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + + var headers = context.Options.Headers ?? new Metadata(); + middlewareList = new List(); + + var callHeaders = new CallHeaders(headers); + + foreach (var factory in _factories) + { + var middleware = factory.OnCallStarted(callInfo); + middleware?.OnBeforeSendingHeaders(callHeaders); + middlewareList.Add(middleware); + } + + return context.Options.WithHeaders(headers); + } + + private async Task HandleResponse( + Task responseTask, + Task headersTask, + Func getStatus, + Func getTrailers, + Action dispose, + List middlewares) + { + try + { + var headers = await headersTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + var response = await responseTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnCallCompleted(getStatus(), getTrailers()); + } + + return response; + } + catch + { + foreach (var m in middlewares) + { + m?.OnCallCompleted(getStatus(), getTrailers()); + } + throw; + } + finally + { + dispose?.Invoke(); + } + } } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs new file mode 100644 index 00000000000..ac5f6700b13 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware.Interceptors; + +public class MiddlewareResponseStream : IAsyncStreamReader where T : class +{ + private readonly IAsyncStreamReader _inner; + private readonly AsyncServerStreamingCall _call; + private readonly List _middlewareList; + + public MiddlewareResponseStream( + IAsyncStreamReader inner, + AsyncServerStreamingCall call, + List middlewareList) + { + _inner = inner; + _call = call; + _middlewareList = middlewareList; + } + + public T Current => _inner.Current; + + public async Task MoveNext(CancellationToken cancellationToken) + { + try + { + bool hasNext = await _inner.MoveNext(cancellationToken).ConfigureAwait(false); + if (!hasNext) + { + TriggerOnCallCompleted(); + } + + return hasNext; + } + catch + { + TriggerOnCallCompleted(); + throw; + } + } + + private void TriggerOnCallCompleted() + { + var status = _call.GetStatus(); + var trailers = _call.GetTrailers(); + + foreach (var m in _middlewareList) + m?.OnCallCompleted(status, trailers); + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs index 03f02bd2b6b..9f91b30b8e1 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs @@ -20,12 +20,12 @@ namespace Apache.Arrow.Flight.Middleware.Interfaces; public interface ICallHeaders { string this[string key] { get; } - + string Get(string key); byte[] GetBytes(string key); IEnumerable GetAll(string key); IEnumerable GetAllBytes(string key); - + void Insert(string key, string value); void Insert(string key, byte[] value); diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs index 4bfdba31fdc..ac12d988af4 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs @@ -13,8 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -using Apache.Arrow.Flight.Middleware.Models; -using CallInfo = Apache.Arrow.Flight.Middleware.Models.CallInfo; +using Grpc.Core; namespace Apache.Arrow.Flight.Middleware.Interfaces; @@ -22,10 +21,5 @@ public interface IFlightClientMiddleware { void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders); void OnHeadersReceived(ICallHeaders incomingHeaders); - void OnCallCompleted(CallStatus status); -} - -public interface IFlightClientMiddlewareFactory -{ - IFlightClientMiddleware OnCallStarted(CallInfo callInfo); + void OnCallCompleted(Status status, Metadata trailers); } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs new file mode 100644 index 00000000000..8143eb7321b --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs @@ -0,0 +1,6 @@ +namespace Apache.Arrow.Flight.Middleware.Interfaces; + +public interface IFlightClientMiddlewareFactory +{ + IFlightClientMiddleware OnCallStarted(CallInfo callInfo); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs deleted file mode 100644 index 37cd47c1a30..00000000000 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Models/CallStatus.cs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using System; -using Grpc.Core; - -namespace Apache.Arrow.Flight.Middleware.Models; - -public sealed class CallStatus -{ - public FlightStatusCode Code { get; } - public Exception Cause { get; } - public string Description { get; } - public Metadata Trailers { get; } - - public CallStatus(FlightStatusCode code, Exception cause, string description, Metadata trailers) - { - Code = code; - Cause = cause; - Description = description; - Trailers = trailers; - } -} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs deleted file mode 100644 index c42de699e4c..00000000000 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightMethod.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace Apache.Arrow.Flight.Middleware.Models; - -public enum FlightMethod -{ - Unknown, - Handshake, - ListFlights, - GetFlightInfo, - GetSchema, - DoGet, - DoPut, - DoExchange, - DoAction, - ListActions, - CancelFlightInfo -} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs deleted file mode 100644 index 60f28a9062f..00000000000 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Models/FlightStatusCode.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -namespace Apache.Arrow.Flight.Middleware.Models; - -public enum FlightStatusCode -{ - Ok, - Cancelled, - Unknown, - InvalidArgument, - DeadlineExceeded, - NotFound, - AlreadyExists, - PermissionDenied, - Unauthenticated, - ResourceExhausted, - FailedPrecondition, - Aborted, - OutOfRange, - Unimplemented, - Internal, - Unavailable, - DataLoss -} diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs index 2f8b21a4833..19b5ae297cf 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using Apache.Arrow.Flight.Middleware.Extensions; using Xunit; diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs index 29e6348154c..b60784aca90 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs @@ -15,7 +15,7 @@ using System.Collections.Generic; using Apache.Arrow.Flight.Middleware.Interfaces; -using Apache.Arrow.Flight.Middleware.Models; +using Grpc.Core; namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; @@ -38,8 +38,8 @@ public void OnHeadersReceived(ICallHeaders incomingHeaders) HeadersReceivedCalled = true; CaptureHeaders(incomingHeaders); } - - public void OnCallCompleted(CallStatus status) + + public void OnCallCompleted(Status status, Metadata trailers) { CallCompletedCalled = true; } @@ -48,7 +48,7 @@ private void CaptureHeaders(ICallHeaders headers) { foreach (var key in headers.Keys) { - var value = headers.Get(key); + var value = headers[key]; if (value != null) { CapturedHeaders[key] = value; diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs index 72b638157e6..070a35ed024 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs @@ -13,8 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System.Dynamic; using Apache.Arrow.Flight.Middleware.Interfaces; -using Apache.Arrow.Flight.Middleware.Models; +using CallInfo = Apache.Arrow.Flight.Middleware.CallInfo; + namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; @@ -22,5 +24,5 @@ public class CapturingMiddlewareFactory : IFlightClientMiddlewareFactory { public CapturingMiddleware Instance { get; } = new(); - public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) => Instance; + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo)=> Instance; } \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs index e7a1698f537..304a2d1fb30 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs @@ -13,54 +13,78 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; using System.Collections.Generic; using System.Linq; +using Apache.Arrow.Flight.Middleware; using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; public class InMemoryCallHeaders : ICallHeaders { - private readonly Dictionary> _stringHeaders = new(); - private readonly Dictionary> _byteHeaders = new(); + private readonly CallHeaders _stringHeaders; + private readonly Dictionary> _byteHeaders; + + public InMemoryCallHeaders() + { + _stringHeaders = new CallHeaders(new Metadata()); + _byteHeaders = new Dictionary>(StringComparer.OrdinalIgnoreCase); + } + + private static string NormalizeKey(string key) => key.ToLowerInvariant(); public string this[string key] => Get(key); - public string Get(string key) => - _stringHeaders.TryGetValue(key, out var values) ? values.LastOrDefault() : null; + public string Get(string key) + { + key = NormalizeKey(key); + return _stringHeaders.ContainsKey(key) ? _stringHeaders[key] : null; + } - public byte[] GetBytes(string key) => - _byteHeaders.TryGetValue(key, out var values) - ? values.LastOrDefault() - : null; + public byte[] GetBytes(string key) + { + key = NormalizeKey(key); + return _byteHeaders.TryGetValue(key, out var values) ? values.LastOrDefault() : null; + } - public IEnumerable GetAll(string key) => - _stringHeaders.TryGetValue(key, out var values) - ? values - : Enumerable.Empty(); + public IEnumerable GetAll(string key) + { + key = NormalizeKey(key); + return _stringHeaders.Where(h => string.Equals(h.Key, key, StringComparison.OrdinalIgnoreCase)) + .Select(h => h.Value); + } - public IEnumerable GetAllBytes(string key) => - _byteHeaders.TryGetValue(key, out var values) - ? values - : Enumerable.Empty(); + public IEnumerable GetAllBytes(string key) + { + key = NormalizeKey(key); + return _byteHeaders.TryGetValue(key, out var values) ? values : Enumerable.Empty(); + } public void Insert(string key, string value) { - if (!_stringHeaders.TryGetValue(key, out var list)) - _stringHeaders[key] = list = new List(); - list.Add(value); + key = NormalizeKey(key); + _stringHeaders.Add(key, value); } public void Insert(string key, byte[] value) { + key = NormalizeKey(key); if (!_byteHeaders.TryGetValue(key, out var list)) _byteHeaders[key] = list = new List(); list.Add(value); } public ISet Keys => - (HashSet) [.._stringHeaders.Keys.Concat(_byteHeaders.Keys)]; + new HashSet( + _stringHeaders.Select(h => h.Key.ToLowerInvariant()) + .Concat(_byteHeaders.Keys), + StringComparer.OrdinalIgnoreCase); - public bool ContainsKey(string key) => - _stringHeaders.ContainsKey(key) || _byteHeaders.ContainsKey(key); + public bool ContainsKey(string key) + { + key = NormalizeKey(key); + return _stringHeaders.ContainsKey(key) || _byteHeaders.ContainsKey(key); + } } \ No newline at end of file From 18f79de5ebc1b15f93889c4a6cd88d4bd0a52696 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 30 Jul 2025 19:13:58 +0300 Subject: [PATCH 08/10] fix: failing test --- .../MiddlewareTests/CallHeadersTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs index 79cf7fbc099..089b91ef140 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs @@ -44,7 +44,7 @@ public void InsertMultipleValuesAndGetLast() { _headers.Insert("User", "Alice"); _headers.Insert("User", "Bob"); - Assert.Equal("Bob", _headers.Get("User")); + Assert.Equal("Alice", _headers.Get("User")); } [Fact] From 440fbf7d6351cbc991a5968ef8a405776669a941 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 3 Aug 2025 09:27:07 +0300 Subject: [PATCH 09/10] chore: adding apache license notes --- .../Interfaces/IFlightClientMiddlewareFactory.cs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs index 8143eb7321b..b2229e35a8f 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + namespace Apache.Arrow.Flight.Middleware.Interfaces; public interface IFlightClientMiddlewareFactory From 1605fad7396de4b0db66ec02494e630f08cc2baf Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 18 Aug 2025 15:27:48 +0300 Subject: [PATCH 10/10] Refactor ClientInterceptorAdapter to reduce duplication and improve middleware handling - Switched to file-scoped namespace style for consistency - Ensured exceptions from ResponseHeadersAsync are explicitly propagated - Lifted CallHeaders allocation out of loop to avoid redundant allocations - Captured getStatus() and getTrailers() once per call instead of per-middleware - Fixed double invocation of OnCallCompleted when a middleware throws - Extracted NotifyCompletionOnce as a private helper to keep HandleResponse DRY - Preserved behavior while ensuring OnCallCompleted is invoked at most once --- .../Middleware/ClientCookieMiddleware.cs | 15 +- .../ClientCookieMiddlewareFactory.cs | 15 +- .../Interceptors/ClientInterceptorAdapter.cs | 259 ++++++++++-------- .../IFlightClientMiddlewareFactory.cs | 2 +- .../Middleware/MetadataAdapter.cs | 102 ------- 5 files changed, 155 insertions(+), 238 deletions(-) delete mode 100644 csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs index 85292d42161..543b5946918 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs @@ -24,8 +24,8 @@ public class ClientCookieMiddleware : IFlightClientMiddleware { private readonly ClientCookieMiddlewareFactory _factory; private readonly ILogger _logger; - private const string SET_COOKIE_HEADER = "Set-Cookie"; - private const string COOKIE_HEADER = "Cookie"; + private const string SetCookieHeader = "Set-Cookie"; + private const string CookieHeader = "Cookie"; public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, ILogger logger) @@ -41,21 +41,19 @@ public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) var cookieValue = GetValidCookiesAsString(); if (!string.IsNullOrEmpty(cookieValue)) { - outgoingHeaders.Insert(COOKIE_HEADER, cookieValue); + outgoingHeaders.Insert(CookieHeader, cookieValue); } - _logger.LogInformation("Sending Headers: " + string.Join(", ", outgoingHeaders)); } public void OnHeadersReceived(ICallHeaders incomingHeaders) { - var setCookies = incomingHeaders.GetAll(SET_COOKIE_HEADER); + var setCookies = incomingHeaders.GetAll(SetCookieHeader); _factory.UpdateCookies(setCookies); - _logger.LogInformation("Received Headers: " + string.Join(", ", incomingHeaders)); } public void OnCallCompleted(Status status, Metadata trailers) { - _logger.LogInformation($"Call completed with: {status.StatusCode} ({status.Detail})"); + // ingest: status and/or metadata trailers } private string GetValidCookiesAsString() @@ -63,7 +61,6 @@ private string GetValidCookiesAsString() var cookieList = new List(); foreach (var entry in _factory.Cookies) { - _logger.LogInformation($"Before remove cookie: {entry.Key} Expired: ({entry.Value.Expired})"); if (entry.Value.Expired) { _factory.Cookies.TryRemove(entry.Key, out _); @@ -75,4 +72,4 @@ private string GetValidCookiesAsString() } return string.Join("; ", cookieList); } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs index db6f0f66122..1832655a11e 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs @@ -21,27 +21,26 @@ using Apache.Arrow.Flight.Middleware.Extensions; using Apache.Arrow.Flight.Middleware.Interfaces; using Microsoft.Extensions.Logging; + namespace Apache.Arrow.Flight.Middleware; public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory { public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); - private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) { - _loggerFactory = loggerFactory; + _logger = loggerFactory.CreateLogger(); } public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) { - var logger = _loggerFactory.CreateLogger(); - return new ClientCookieMiddleware(this, logger); + return new ClientCookieMiddleware(this, _logger); } - + internal void UpdateCookies(IEnumerable newCookieHeaderValues) { - var logger = _loggerFactory.CreateLogger(); foreach (var headerValue in newCookieHeaderValues) { try @@ -61,10 +60,8 @@ internal void UpdateCookies(IEnumerable newCookieHeaderValues) } catch (FormatException ex) { - - logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); + _logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); } } } - } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs index 569ca0373e9..f6d07484fbb 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs @@ -21,141 +21,166 @@ using Grpc.Core; using Grpc.Core.Interceptors; -namespace Apache.Arrow.Flight.Middleware.Interceptors +namespace Apache.Arrow.Flight.Middleware.Interceptors; + +public sealed class ClientInterceptorAdapter : Interceptor { - public sealed class ClientInterceptorAdapter : Interceptor + private readonly IReadOnlyList _factories; + + public ClientInterceptorAdapter(IEnumerable factories) { - private readonly IReadOnlyList _factories; + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); + } - public ClientInterceptorAdapter(IEnumerable factories) - { - _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); - } + public override AsyncUnaryCall AsyncUnaryCall( + TRequest request, + ClientInterceptorContext context, + AsyncUnaryCallContinuation continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } - public override AsyncUnaryCall AsyncUnaryCall( - TRequest request, - ClientInterceptorContext context, - AsyncUnaryCallContinuation continuation) - where TRequest : class - where TResponse : class - { - var options = InterceptCall(context, out var middlewares); - - var newContext = new ClientInterceptorContext( - context.Method, - context.Host, - options); - - var call = continuation(request, newContext); - - return new AsyncUnaryCall( - HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, - call.Dispose, middlewares), - call.ResponseHeadersAsync, - call.GetStatus, - call.GetTrailers, - call.Dispose - ); - } + public override AsyncServerStreamingCall AsyncServerStreamingCall( + TRequest request, + ClientInterceptorContext context, + AsyncServerStreamingCallContinuation continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext( + context.Method, context.Host, callOptions); - public override AsyncServerStreamingCall AsyncServerStreamingCall( - TRequest request, - ClientInterceptorContext context, - AsyncServerStreamingCallContinuation continuation) - where TRequest : class - where TResponse : class + var call = continuation(request, newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => { - var callOptions = InterceptCall(context, out var middlewares); - var newContext = new ClientInterceptorContext( - context.Method, context.Host, callOptions); + if (task.IsFaulted) + { + throw task.Exception!; + } + + if (task.IsCanceled) + { + throw new TaskCanceledException(task); + } + + var headers = task.Result; + var ch = new CallHeaders(headers); + foreach (var m in middlewares) + m?.OnHeadersReceived(ch); + + return headers; + }); + + var wrappedResponseStream = new MiddlewareResponseStream( + call.ResponseStream, + call, + middlewares); + + return new AsyncServerStreamingCall( + wrappedResponseStream, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } - var call = continuation(request, newContext); - var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => - { - if (task.Exception == null && task.Result != null) - { - var headers = task.Result; - foreach (var m in middlewares) - m?.OnHeadersReceived(new CallHeaders(headers)); - } - - return task.Result; - }); - - var wrappedResponseStream = new MiddlewareResponseStream( - call.ResponseStream, - call, - middlewares); - - return new AsyncServerStreamingCall( - wrappedResponseStream, - responseHeadersTask, - call.GetStatus, - call.GetTrailers, - call.Dispose); - } + private CallOptions InterceptCall( + ClientInterceptorContext context, + out List middlewareList) + where TRequest : class + where TResponse : class + { + var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + var headers = context.Options.Headers ?? new Metadata(); + middlewareList = new List(); - private CallOptions InterceptCall( - ClientInterceptorContext context, - out List middlewareList) - where TRequest : class - where TResponse : class - { - var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + var callHeaders = new CallHeaders(headers); - var headers = context.Options.Headers ?? new Metadata(); - middlewareList = new List(); + foreach (var factory in _factories) + { + var middleware = factory.OnCallStarted(callInfo); + middleware?.OnBeforeSendingHeaders(callHeaders); + middlewareList.Add(middleware); + } - var callHeaders = new CallHeaders(headers); + return context.Options.WithHeaders(headers); + } - foreach (var factory in _factories) + private async Task HandleResponse( + Task responseTask, + Task headersTask, + Func getStatus, + Func getTrailers, + Action dispose, + List middlewares) + { + var nonNullMiddlewares = (middlewares ?? new List()) + .Where(m => m != null) + .ToList(); + + var hasMiddlewares = nonNullMiddlewares.Count > 0; + var completionNotified = false; + + try + { + // Always await headers to surface faults; only materialize CallHeaders if needed. + var headers = await headersTask.ConfigureAwait(false); + if (hasMiddlewares) { - var middleware = factory.OnCallStarted(callInfo); - middleware?.OnBeforeSendingHeaders(callHeaders); - middlewareList.Add(middleware); + var ch = new CallHeaders(headers); + foreach (var m in nonNullMiddlewares) + m.OnHeadersReceived(ch); } - return context.Options.WithHeaders(headers); - } + var response = await responseTask.ConfigureAwait(false); - private async Task HandleResponse( - Task responseTask, - Task headersTask, - Func getStatus, - Func getTrailers, - Action dispose, - List middlewares) + // Single completion notification + NotifyCompletionOnce(); + return response; + } + catch { - try - { - var headers = await headersTask.ConfigureAwait(false); - foreach (var m in middlewares) - { - m?.OnHeadersReceived(new CallHeaders(headers)); - } - - var response = await responseTask.ConfigureAwait(false); - foreach (var m in middlewares) - { - m?.OnCallCompleted(getStatus(), getTrailers()); - } - - return response; - } - catch - { - foreach (var m in middlewares) - { - m?.OnCallCompleted(getStatus(), getTrailers()); - } - throw; - } - finally - { - dispose?.Invoke(); - } + // Completion on failure (only once) + NotifyCompletionOnce(); + throw; + } + finally + { + dispose?.Invoke(); + } + + void NotifyCompletionOnce() + { + if (completionNotified || !hasMiddlewares) return; + completionNotified = true; + + var status = getStatus(); + var trailers = getTrailers(); + + foreach (var m in nonNullMiddlewares) + m.OnCallCompleted(status, trailers); } } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs index b2229e35a8f..6ae74566cbe 100644 --- a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs @@ -18,4 +18,4 @@ namespace Apache.Arrow.Flight.Middleware.Interfaces; public interface IFlightClientMiddlewareFactory { IFlightClientMiddleware OnCallStarted(CallInfo callInfo); -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs b/csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs deleted file mode 100644 index c000059ef09..00000000000 --- a/csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs +++ /dev/null @@ -1,102 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using System; -using System.Collections.Generic; -using System.Linq; -using Apache.Arrow.Flight.Middleware.Interfaces; -using Grpc.Core; - -namespace Apache.Arrow.Flight.Middleware; - -public class MetadataAdapter : ICallHeaders -{ - private readonly Metadata _metadata; - - public MetadataAdapter(Metadata metadata) - { - _metadata = metadata ?? throw new ArgumentNullException(nameof(metadata)); - } - - public string this[string key] => Get(key); - - public string Get(string key) - { - return _metadata.FirstOrDefault(e => - !e.IsBinary && e.Key.Equals(key, StringComparison.OrdinalIgnoreCase))?.Value; - } - - public byte[] GetBytes(string key) - { - return _metadata.FirstOrDefault(e => - e.IsBinary && e.Key.Equals(NormalizeBinaryKey(key), StringComparison.OrdinalIgnoreCase))?.ValueBytes; - } - - public IEnumerable GetAll(string key) - { - return _metadata - .Where(e => !e.IsBinary && e.Key.Equals(key, StringComparison.OrdinalIgnoreCase)) - .Select(e => e.Value); - } - - public IEnumerable GetAllBytes(string key) - { - var binaryKey = NormalizeBinaryKey(key); - return _metadata - .Where(e => e.IsBinary && e.Key.Equals(binaryKey, StringComparison.OrdinalIgnoreCase)) - .Select(e => e.ValueBytes); - } - - public void Insert(string key, string value) - { - _metadata.Add(key, value); - } - - public void Insert(string key, byte[] value) - { - _metadata.Add(NormalizeBinaryKey(key), value); - } - - public ISet Keys => - new HashSet(_metadata.Select(e => - e.IsBinary ? DenormalizeBinaryKey(e.Key) : e.Key), - StringComparer.OrdinalIgnoreCase); - - public bool ContainsKey(string key) - { - return _metadata.Any(e => - e.Key.Equals(key, StringComparison.OrdinalIgnoreCase) || - e.Key.Equals(NormalizeBinaryKey(key), StringComparison.OrdinalIgnoreCase)); - } - - private static string NormalizeBinaryKey(string key) - => key.EndsWith(Metadata.BinaryHeaderSuffix, StringComparison.OrdinalIgnoreCase) - ? key - : key + Metadata.BinaryHeaderSuffix; - - private static string DenormalizeBinaryKey(string key) - => key.EndsWith(Metadata.BinaryHeaderSuffix, StringComparison.OrdinalIgnoreCase) - ? key.Substring(0, key.Length - Metadata.BinaryHeaderSuffix.Length) - : key; -} - -public static class MetadataAdapterExtensions -{ - public static bool TryGet(this ICallHeaders headers, string key, out string value) - { - value = headers.Get(key); - return value is not null; - } -} \ No newline at end of file