diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs index cbccc9dab34..227ea4812a6 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightSqlServer.cs @@ -285,6 +285,7 @@ public override Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWri public override Task DoAction(FlightAction action, IAsyncStreamWriter responseStream, ServerCallContext context) { Logger?.LogTrace("Executing Flight SQL DoAction: {ActionType}", action.Type); + switch (action.Type) { case SqlAction.CreateRequest: diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs b/csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs new file mode 100644 index 00000000000..f2d8f0639dc --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware; + +public class CallHeaders : ICallHeaders, IEnumerable> +{ + private readonly Metadata _metadata; + + public CallHeaders(Metadata metadata) + { + _metadata = metadata; + } + + public void Add(string key, string value) => _metadata.Add(key, value); + + public bool ContainsKey(string key) => _metadata.Any(h => KeyEquals(h.Key, key)); + + public IEnumerator> GetEnumerator() + { + foreach (var entry in _metadata) + yield return new KeyValuePair(entry.Key, entry.Value); + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public string this[string key] + { + get + { + var entry = _metadata.FirstOrDefault(h => KeyEquals(h.Key, key)); + return entry?.Value; + } + set + { + var entry = _metadata.FirstOrDefault(h => KeyEquals(h.Key, key)); + if (entry != null) _metadata.Remove(entry); + _metadata.Add(key, value); + } + } + + public string Get(string key) => this[key]; + + public byte[] GetBytes(string key) => + _metadata.FirstOrDefault(h => KeyEquals(h.Key, key))?.ValueBytes; + + public IEnumerable GetAll(string key) => + _metadata.Where(h => KeyEquals(h.Key, key)).Select(h => h.Value); + + public IEnumerable GetAllBytes(string key) => + _metadata.Where(h => KeyEquals(h.Key, key)).Select(h => h.ValueBytes); + + public void Insert(string key, string value) => Add(key, value); + + public void Insert(string key, byte[] value) => _metadata.Add(key, value); + + public ISet Keys => new HashSet(_metadata.Select(h => h.Key)); + + private static bool KeyEquals(string a, string b) => + string.Equals(a, b, StringComparison.OrdinalIgnoreCase); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/CallInfo.cs b/csharp/src/Apache.Arrow.Flight/Middleware/CallInfo.cs new file mode 100644 index 00000000000..5b41d72b276 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/CallInfo.cs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware; + +public readonly struct CallInfo +{ + public string Method { get; } + public MethodType MethodType { get; } + + public CallInfo(string method, MethodType methodType) + { + Method = method; + MethodType = methodType; + } + + public override string ToString() + { + return $"{MethodType}: {Method}"; + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs new file mode 100644 index 00000000000..543b5946918 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Middleware; + +public class ClientCookieMiddleware : IFlightClientMiddleware +{ + private readonly ClientCookieMiddlewareFactory _factory; + private readonly ILogger _logger; + private const string SetCookieHeader = "Set-Cookie"; + private const string CookieHeader = "Cookie"; + + public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, + ILogger logger) + { + _factory = factory; + _logger = logger; + } + + public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) + { + if (_factory.Cookies.IsEmpty) + return; + var cookieValue = GetValidCookiesAsString(); + if (!string.IsNullOrEmpty(cookieValue)) + { + outgoingHeaders.Insert(CookieHeader, cookieValue); + } + } + + public void OnHeadersReceived(ICallHeaders incomingHeaders) + { + var setCookies = incomingHeaders.GetAll(SetCookieHeader); + _factory.UpdateCookies(setCookies); + } + + public void OnCallCompleted(Status status, Metadata trailers) + { + // ingest: status and/or metadata trailers + } + + private string GetValidCookiesAsString() + { + var cookieList = new List(); + foreach (var entry in _factory.Cookies) + { + if (entry.Value.Expired) + { + _factory.Cookies.TryRemove(entry.Key, out _); + } + else + { + cookieList.Add(entry.Value.ToString()); + } + } + return string.Join("; ", cookieList); + } +} diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs new file mode 100644 index 00000000000..1832655a11e --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Flight.Middleware.Extensions; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Middleware; + +public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public readonly ConcurrentDictionary Cookies = new(StringComparer.OrdinalIgnoreCase); + private readonly ILogger _logger; + + public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + } + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) + { + return new ClientCookieMiddleware(this, _logger); + } + + internal void UpdateCookies(IEnumerable newCookieHeaderValues) + { + foreach (var headerValue in newCookieHeaderValues) + { + try + { + foreach (var parsedCookie in headerValue.ParseHeader()) + { + var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); + if (parsedCookie.IsExpired(headerValue)) + { + Cookies.TryRemove(nameLc, out _); + } + else + { + Cookies[nameLc] = parsedCookie; + } + } + } + catch (FormatException ex) + { + _logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue); + } + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs new file mode 100644 index 00000000000..5ad0af54384 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Net; + +namespace Apache.Arrow.Flight.Middleware.Extensions; + +public static class CookieExtensions +{ + public static IEnumerable ParseHeader(this string setCookieHeader) + { + if (string.IsNullOrWhiteSpace(setCookieHeader)) + return System.Array.Empty(); + + var cookies = new List(); + + var segments = setCookieHeader.Split([';'], StringSplitOptions.RemoveEmptyEntries); + if (segments.Length == 0) + return cookies; + + var nameValue = segments[0].Split(['='], 2); + if (nameValue.Length != 2 || string.IsNullOrWhiteSpace(nameValue[0])) + return cookies; + + var name = nameValue[0].Trim(); + var value = nameValue[1].Trim(); + var cookie = new Cookie(name, value); + + foreach (var segment in segments.Skip(1)) + { + var kv = segment.Split(['='], 2, StringSplitOptions.RemoveEmptyEntries); + var key = kv[0].Trim().ToLowerInvariant(); + var val = kv.Length > 1 ? kv[1] : null; + + switch (key) + { + case "expires": + if (!string.IsNullOrWhiteSpace(val)) + { + if (DateTimeOffset.TryParseExact(val, "R", CultureInfo.InvariantCulture, DateTimeStyles.None, out var expiresRfc)) + cookie.Expires = expiresRfc.UtcDateTime; + else if (DateTimeOffset.TryParse(val, out var expiresFallback)) + cookie.Expires = expiresFallback.UtcDateTime; + } + break; + + case "max-age": + if (int.TryParse(val, out var seconds)) + cookie.Expires = DateTime.UtcNow.AddSeconds(seconds); + break; + + case "domain": + cookie.Domain = val; + break; + + case "path": + cookie.Path = val; + break; + + case "secure": + cookie.Secure = true; + break; + + case "httponly": + cookie.HttpOnly = true; + break; + } + } + + cookies.Add(cookie); + return cookies; + } + + public static bool IsExpired(this Cookie cookie, string rawHeader) + { + if (string.IsNullOrWhiteSpace(cookie?.Value)) + return true; + + // If raw header has Max-Age=0, consider it deleted + if (rawHeader?.IndexOf("Max-Age=0", StringComparison.OrdinalIgnoreCase) >= 0) + return true; + + if (cookie.Expires != DateTime.MinValue && cookie.Expires <= DateTime.UtcNow) + return true; + + return false; + } +} diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs new file mode 100644 index 00000000000..f6d07484fbb --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors; + +public sealed class ClientInterceptorAdapter : Interceptor +{ + private readonly IReadOnlyList _factories; + + public ClientInterceptorAdapter(IEnumerable factories) + { + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall AsyncUnaryCall( + TRequest request, + ClientInterceptorContext context, + AsyncUnaryCallContinuation continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall( + TRequest request, + ClientInterceptorContext context, + AsyncServerStreamingCallContinuation continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext( + context.Method, context.Host, callOptions); + + var call = continuation(request, newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.IsFaulted) + { + throw task.Exception!; + } + + if (task.IsCanceled) + { + throw new TaskCanceledException(task); + } + + var headers = task.Result; + var ch = new CallHeaders(headers); + foreach (var m in middlewares) + m?.OnHeadersReceived(ch); + + return headers; + }); + + var wrappedResponseStream = new MiddlewareResponseStream( + call.ResponseStream, + call, + middlewares); + + return new AsyncServerStreamingCall( + wrappedResponseStream, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + + private CallOptions InterceptCall( + ClientInterceptorContext context, + out List middlewareList) + where TRequest : class + where TResponse : class + { + var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + + var headers = context.Options.Headers ?? new Metadata(); + middlewareList = new List(); + + var callHeaders = new CallHeaders(headers); + + foreach (var factory in _factories) + { + var middleware = factory.OnCallStarted(callInfo); + middleware?.OnBeforeSendingHeaders(callHeaders); + middlewareList.Add(middleware); + } + + return context.Options.WithHeaders(headers); + } + + private async Task HandleResponse( + Task responseTask, + Task headersTask, + Func getStatus, + Func getTrailers, + Action dispose, + List middlewares) + { + var nonNullMiddlewares = (middlewares ?? new List()) + .Where(m => m != null) + .ToList(); + + var hasMiddlewares = nonNullMiddlewares.Count > 0; + var completionNotified = false; + + try + { + // Always await headers to surface faults; only materialize CallHeaders if needed. + var headers = await headersTask.ConfigureAwait(false); + if (hasMiddlewares) + { + var ch = new CallHeaders(headers); + foreach (var m in nonNullMiddlewares) + m.OnHeadersReceived(ch); + } + + var response = await responseTask.ConfigureAwait(false); + + // Single completion notification + NotifyCompletionOnce(); + return response; + } + catch + { + // Completion on failure (only once) + NotifyCompletionOnce(); + throw; + } + finally + { + dispose?.Invoke(); + } + + void NotifyCompletionOnce() + { + if (completionNotified || !hasMiddlewares) return; + completionNotified = true; + + var status = getStatus(); + var trailers = getTrailers(); + + foreach (var m in nonNullMiddlewares) + m.OnCallCompleted(status, trailers); + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs new file mode 100644 index 00000000000..ac5f6700b13 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware.Interceptors; + +public class MiddlewareResponseStream : IAsyncStreamReader where T : class +{ + private readonly IAsyncStreamReader _inner; + private readonly AsyncServerStreamingCall _call; + private readonly List _middlewareList; + + public MiddlewareResponseStream( + IAsyncStreamReader inner, + AsyncServerStreamingCall call, + List middlewareList) + { + _inner = inner; + _call = call; + _middlewareList = middlewareList; + } + + public T Current => _inner.Current; + + public async Task MoveNext(CancellationToken cancellationToken) + { + try + { + bool hasNext = await _inner.MoveNext(cancellationToken).ConfigureAwait(false); + if (!hasNext) + { + TriggerOnCallCompleted(); + } + + return hasNext; + } + catch + { + TriggerOnCallCompleted(); + throw; + } + } + + private void TriggerOnCallCompleted() + { + var status = _call.GetStatus(); + var trailers = _call.GetTrailers(); + + foreach (var m in _middlewareList) + m?.OnCallCompleted(status, trailers); + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs new file mode 100644 index 00000000000..9f91b30b8e1 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; + +namespace Apache.Arrow.Flight.Middleware.Interfaces; + +public interface ICallHeaders +{ + string this[string key] { get; } + + string Get(string key); + byte[] GetBytes(string key); + IEnumerable GetAll(string key); + IEnumerable GetAllBytes(string key); + + void Insert(string key, string value); + void Insert(string key, byte[] value); + + ISet Keys { get; } + bool ContainsKey(string key); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs new file mode 100644 index 00000000000..ac12d988af4 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware.Interfaces; + +public interface IFlightClientMiddleware +{ + void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders); + void OnHeadersReceived(ICallHeaders incomingHeaders); + void OnCallCompleted(Status status, Metadata trailers); +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs new file mode 100644 index 00000000000..6ae74566cbe --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Apache.Arrow.Flight.Middleware.Interfaces; + +public interface IFlightClientMiddlewareFactory +{ + IFlightClientMiddleware OnCallStarted(CallInfo callInfo); +} diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs new file mode 100644 index 00000000000..089b91ef140 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Linq; +using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; +using Xunit; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; + +public class CallHeadersTests +{ + private readonly InMemoryCallHeaders _headers = new(); + + [Fact] + public void InsertAndGetStringValue() + { + _headers.Insert("Auth", "Bearer 123"); + Assert.Equal("Bearer 123", _headers.Get("Auth")); + Assert.Equal("Bearer 123", _headers["Auth"]); + } + + [Fact] + public void InsertAndGetByteArrayValue() + { + var bytes = new byte[] { 1, 2, 3, 4, 5 }; + _headers.Insert("Data", bytes); + Assert.Equal(bytes, _headers.GetBytes("Data")); + } + + [Fact] + public void InsertMultipleValuesAndGetLast() + { + _headers.Insert("User", "Alice"); + _headers.Insert("User", "Bob"); + Assert.Equal("Alice", _headers.Get("User")); + } + + [Fact] + public void GetAllShouldReturnAllStringValues() + { + _headers.Insert("Header", "v1"); + _headers.Insert("Header", "v2"); + var all = _headers.GetAll("Header").ToList(); + Assert.Contains("v1", all); + Assert.Contains("v2", all); + Assert.Equal(2, all.Count); + } + + [Fact] + public void GetAllBytesShouldReturnAllByteArrayValues() + { + var a = new byte[] { 1 }; + var b = new byte[] { 2 }; + _headers.Insert("Binary", a); + _headers.Insert("Binary", b); + var all = _headers.GetAllBytes("Binary").ToList(); + Assert.Contains(a, all); + Assert.Contains(b, all); + Assert.Equal(2, all.Count); + } + + [Fact] + public void KeysShouldReturnAllKeys() + { + _headers.Insert("A", "x"); + _headers.Insert("B", "y"); + Assert.Contains("A", _headers.Keys); + Assert.Contains("B", _headers.Keys); + } + + [Fact] + public void ContainsKeyShouldWork() + { + _headers.Insert("Check", "yes"); + Assert.True(_headers.ContainsKey("Check")); + Assert.False(_headers.ContainsKey("Missing")); + } + + [Fact] + public void GetNonExistentKeyShouldReturnNull() + { + Assert.Null(_headers.Get("MissingKey")); + Assert.Null(_headers.GetBytes("MissingKey")); + Assert.Empty(_headers.GetAll("MissingKey")); + Assert.Empty(_headers.GetAllBytes("MissingKey")); + } + + [Fact] + public void ContainsKeyShouldBeFalseForMissingKey() + { + Assert.False(_headers.ContainsKey("DefinitelyMissing")); + } + + [Fact] + public void KeysShouldBeEmptyWhenNoHeaders() + { + Assert.Empty(_headers.Keys); + } + + [Fact] + public void IndexerShouldReturnNullForMissingKey() + { + string value = _headers["nonexistent"]; + Assert.Null(value); + } + + [Fact] + public void InsertEmptyStringsShouldStillStore() + { + _headers.Insert("Empty", ""); + Assert.Equal("", _headers.Get("Empty")); + Assert.Single(_headers.GetAll("Empty")); + } + + [Fact] + public void InsertEmptyByteArrayShouldStillStore() + { + var empty = System.Array.Empty(); + _headers.Insert("BinaryEmpty", empty); + Assert.Equal(empty, _headers.GetBytes("BinaryEmpty")); + Assert.Single(_headers.GetAllBytes("BinaryEmpty")); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs new file mode 100644 index 00000000000..b1d89a36073 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Linq; +using System.Net; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware; +using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; +using Xunit; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; + +public class ClientCookieMiddlewareTests +{ + private readonly ClientCookieMiddlewareMock _middlewareMock = new(); + + [Fact] + public void NoCookiesReturnsEmptyString() + { + var factory = _middlewareMock.CreateFactory(); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + Assert.Empty(headers.GetAll("Cookie")); + } + + [Fact] + public void OnlyExpiredCookiesRemovesAll() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["expired"] = + _middlewareMock.CreateCookie("expired", "value", DateTimeOffset.UtcNow.AddMinutes(-5)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + Assert.Empty(headers.GetAll("Cookie")); + Assert.Empty(factory.Cookies); + } + + [Fact] + public void OnlyValidCookiesReturnsCookieHeader() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["valid"] = _middlewareMock.CreateCookie("valid", "abc", DateTimeOffset.UtcNow.AddMinutes(10)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + var header = headers.GetAll("Cookie").FirstOrDefault(); + Assert.NotNull(header); + Assert.Contains("valid=abc", header); + } + + [Fact] + public void MixedCookiesRemovesExpiredOnly() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["expired"] = + _middlewareMock.CreateCookie("expired", "x", DateTimeOffset.UtcNow.AddMinutes(-10)); + factory.Cookies["valid"] = _middlewareMock.CreateCookie("valid", "y", DateTimeOffset.UtcNow.AddMinutes(10)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + var header = headers.GetAll("Cookie").FirstOrDefault(); + Assert.NotNull(header); + Assert.Contains("valid=y", header); + Assert.DoesNotContain("expired=x", header); + Assert.Single(factory.Cookies); + } + + [Fact] + public void DuplicateCookieKeysLastValidRemains() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["token"] = _middlewareMock.CreateCookie("token", "old", DateTimeOffset.UtcNow.AddMinutes(-5)); + factory.Cookies["token"] = _middlewareMock.CreateCookie("token", "new", DateTimeOffset.UtcNow.AddMinutes(5)); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + var header = headers.GetAll("Cookie").FirstOrDefault(); + Assert.NotNull(header); + Assert.Contains("token=new", header); + } + + [Fact] + public void FalsePositiveValidDateButMarkedExpired() + { + var factory = _middlewareMock.CreateFactory(); + factory.Cookies["wrong"] = + _middlewareMock.CreateCookie("wrong", "v", DateTimeOffset.UtcNow.AddMinutes(10), expiredOverride: true); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + Assert.Empty(headers.GetAll("Cookie")); + } + + [Fact] + public async Task ConcurrentInsertRemoveDoesNotCorrupt() + { + var factory = _middlewareMock.CreateFactory(); + var middleware = + new ClientCookieMiddleware(factory, new ClientCookieMiddlewareMock.TestLogger()); + + for (int i = 0; i < 100; i++) + factory.Cookies[$"cookie{i}"] = + _middlewareMock.CreateCookie($"cookie{i}", $"{i}", DateTimeOffset.UtcNow.AddMinutes(5)); + + var tasks = Enumerable.Range(0, 20).Select(_ => Task.Run(() => + { + var headers = new InMemoryCallHeaders(); + middleware.OnBeforeSendingHeaders(headers); + foreach (var key in factory.Cookies.Keys) + factory.Cookies.TryRemove(key, out Cookie _); + })); + + await Task.WhenAll(tasks); + Assert.True(factory.Cookies.Count >= 0); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs new file mode 100644 index 00000000000..fc1706408e4 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Middleware.Interceptors; +using Apache.Arrow.Flight.Sql.Tests.Stubs; +using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; +using Grpc.Core; +using Grpc.Core.Interceptors; +using Xunit; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; + +public class ClientInterceptorAdapterTests +{ + private readonly TestWebFactory _testWebFactory; + private readonly FlightClient _client; + private readonly CapturingMiddlewareFactory _middlewareFactory; + + public ClientInterceptorAdapterTests() + { + _testWebFactory = new TestWebFactory(new InMemoryFlightStore()); + + _middlewareFactory = new CapturingMiddlewareFactory(); + var interceptor = new ClientInterceptorAdapter([_middlewareFactory]); + + _client = new FlightClient(_testWebFactory.GetChannel().Intercept(interceptor)); + } + + [Fact] + public async Task MiddlewareFlowIsCalledCorrectly() + { + // Arrange + var descriptor = FlightDescriptor.CreatePathDescriptor("test"); + + // Act + var info = await _client.GetInfo(descriptor); + var middleware = _middlewareFactory.Instance; + + // Assert + Assert.NotNull(info); + Assert.True(middleware.BeforeHeadersCalled, "BeforeHeaders not called"); + Assert.True(middleware.HeadersReceivedCalled, "HeadersReceived not called"); + Assert.True(middleware.CallCompletedCalled, "CallCompleted not called"); + } + + [Fact] + public async Task CookieAndHeaderValuesArePersistedThroughMiddleware() + { + // Arrange + var descriptor = FlightDescriptor.CreatePathDescriptor("test"); + + // Act + try + { + await _client.GetInfo(descriptor); + } + catch (RpcException) + { + // Expected: Flight not found, but middleware should have run + } + + // Assert Middleware captured the headers and cookies correctly + var middleware = _middlewareFactory.Instance; + + Assert.True(middleware.BeforeHeadersCalled, "OnBeforeSendingHeaders not called"); + Assert.True(middleware.HeadersReceivedCalled, "OnHeadersReceived not called"); + Assert.True(middleware.CallCompletedCalled, "OnCallCompleted not called"); + + // Validate Cookies captured correctly + Assert.True(middleware.CapturedHeaders.ContainsKey("cookie")); + var cookies = ParseCookies(middleware.CapturedHeaders["cookie"]); + + Assert.Equal("abc123", cookies["sessionId"]); + Assert.Equal("xyz789", cookies["token"]); + } + + private static IDictionary ParseCookies(string cookieHeader) + { + return cookieHeader.Split(';') + .Select(pair => pair.Split('=')) + .ToDictionary(parts => parts[0].Trim(), parts => parts[1].Trim()); + } +} + + \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs new file mode 100644 index 00000000000..19b5ae297cf --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using Apache.Arrow.Flight.Middleware.Extensions; +using Xunit; +using static System.Linq.Enumerable; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests; + +public class CookieExtensionsTests +{ + [Fact] + public void ParseHeaderShouldParseSimpleCookie() + { + // Arrange + var header = "sessionId=abc123"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("sessionId", cookies[0].Name); + Assert.Equal("abc123", cookies[0].Value); + Assert.False(cookies[0].Expired); + } + + [Fact] + public void ParseHeaderShouldParseCookieWithExpires() + { + // Arrange + var futureDate = DateTimeOffset.UtcNow.AddDays(7); + var header = $"userId=789; Expires={futureDate:R}"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("userId", cookies[0].Name); + Assert.Equal("789", cookies[0].Value); + Assert.True(Math.Abs((cookies[0].Expires - futureDate.UtcDateTime).TotalSeconds) < 5); + } + + [Fact] + public void ParseHeaderShouldReturnEmptyWhenMalformed() + { + // Arrange + var header = "this_is_wrong"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Empty(cookies); + } + + [Fact] + public void ParseHeaderShouldReturnEmptyWhenEmptyString() + { + // Arrange + var header = string.Empty; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Empty(cookies); + } + + [Fact] + public void ParseHeaderShouldReturnEmptyWhenNullString() + { + // Arrange + string header = null; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Empty(cookies); + } + + [Fact] + public void ParseHeaderShouldParseCookieIgnoringAttributes() + { + // Arrange + var header = "token=xyz; Path=/; HttpOnly"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("token", cookies[0].Name); + Assert.Equal("xyz", cookies[0].Value); + } + + [Fact] + public void ParseHeaderShouldIgnoreInvalidExpires() + { + // Arrange + var header = "name=value; Expires=invalid-date"; + + // Act + var cookies = header.ParseHeader().ToList(); + + // Assert + Assert.Single(cookies); + Assert.Equal("name", cookies[0].Name); + Assert.Equal("value", cookies[0].Value); + Assert.False(cookies[0].Expired); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs new file mode 100644 index 00000000000..b60784aca90 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; + +public class CapturingMiddleware : IFlightClientMiddleware +{ + public Dictionary CapturedHeaders { get; } = new(); + + public bool BeforeHeadersCalled { get; private set; } + public bool HeadersReceivedCalled { get; private set; } + public bool CallCompletedCalled { get; private set; } + public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) + { + BeforeHeadersCalled = true; + outgoingHeaders.Insert("x-test-header", "test-value"); + outgoingHeaders.Insert("cookie", "sessionId=abc123; token=xyz789"); + CaptureHeaders(outgoingHeaders); + } + public void OnHeadersReceived(ICallHeaders incomingHeaders) + { + HeadersReceivedCalled = true; + CaptureHeaders(incomingHeaders); + } + + public void OnCallCompleted(Status status, Metadata trailers) + { + CallCompletedCalled = true; + } + + private void CaptureHeaders(ICallHeaders headers) + { + foreach (var key in headers.Keys) + { + var value = headers[key]; + if (value != null) + { + CapturedHeaders[key] = value; + } + } + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs new file mode 100644 index 00000000000..070a35ed024 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Dynamic; +using Apache.Arrow.Flight.Middleware.Interfaces; +using CallInfo = Apache.Arrow.Flight.Middleware.CallInfo; + + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; + +public class CapturingMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public CapturingMiddleware Instance { get; } = new(); + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo)=> Instance; +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs new file mode 100644 index 00000000000..c3f0fe77c67 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Net; +using Apache.Arrow.Flight.Middleware; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; + +internal class ClientCookieMiddlewareMock +{ + public Cookie CreateCookie(string name, string value, DateTimeOffset? expires = null, bool? expiredOverride = null) + { + return new Cookie + { + Name = name, + Value = value, + Expires = expires!.Value.UtcDateTime, + Expired = expiredOverride ?? (expires.HasValue && expires.Value < DateTimeOffset.UtcNow) + }; + } + + public ClientCookieMiddlewareFactory CreateFactory() + { + return new ClientCookieMiddlewareFactory(new TestLoggerFactory()); + } + + public class TestLogger : ILogger + { + public IDisposable BeginScope(TState state) => null; + public bool IsEnabled(LogLevel logLevel) => false; + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, + Func formatter) + { + } + } + + internal class TestLoggerFactory : ILoggerFactory + { + public void AddProvider(ILoggerProvider provider) + { + } + + public ILogger CreateLogger(string categoryName) => new TestLogger(); + + public void Dispose() + { + } + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs new file mode 100644 index 00000000000..304a2d1fb30 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Middleware; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs; + +public class InMemoryCallHeaders : ICallHeaders +{ + private readonly CallHeaders _stringHeaders; + private readonly Dictionary> _byteHeaders; + + public InMemoryCallHeaders() + { + _stringHeaders = new CallHeaders(new Metadata()); + _byteHeaders = new Dictionary>(StringComparer.OrdinalIgnoreCase); + } + + private static string NormalizeKey(string key) => key.ToLowerInvariant(); + + public string this[string key] => Get(key); + + public string Get(string key) + { + key = NormalizeKey(key); + return _stringHeaders.ContainsKey(key) ? _stringHeaders[key] : null; + } + + public byte[] GetBytes(string key) + { + key = NormalizeKey(key); + return _byteHeaders.TryGetValue(key, out var values) ? values.LastOrDefault() : null; + } + + public IEnumerable GetAll(string key) + { + key = NormalizeKey(key); + return _stringHeaders.Where(h => string.Equals(h.Key, key, StringComparison.OrdinalIgnoreCase)) + .Select(h => h.Value); + } + + public IEnumerable GetAllBytes(string key) + { + key = NormalizeKey(key); + return _byteHeaders.TryGetValue(key, out var values) ? values : Enumerable.Empty(); + } + + public void Insert(string key, string value) + { + key = NormalizeKey(key); + _stringHeaders.Add(key, value); + } + + public void Insert(string key, byte[] value) + { + key = NormalizeKey(key); + if (!_byteHeaders.TryGetValue(key, out var list)) + _byteHeaders[key] = list = new List(); + list.Add(value); + } + + public ISet Keys => + new HashSet( + _stringHeaders.Select(h => h.Key.ToLowerInvariant()) + .Concat(_byteHeaders.Keys), + StringComparer.OrdinalIgnoreCase); + + public bool ContainsKey(string key) + { + key = NormalizeKey(key); + return _stringHeaders.ContainsKey(key) || _byteHeaders.ContainsKey(key); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs new file mode 100644 index 00000000000..619b4aa3b0f --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Flight.TestWeb; +using Apache.Arrow.Types; + +namespace Apache.Arrow.Flight.Sql.Tests.Stubs; + +public class InMemoryFlightStore : FlightStore +{ + public InMemoryFlightStore() + { + // Pre-register a dummy flight so GetFlightInfo can resolve it + var descriptor = FlightDescriptor.CreatePathDescriptor("test"); + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + + var recordBatch = new RecordBatch(schema, new Array[] + { + new Int32Array.Builder().Append(1).Build(), + new StringArray.Builder().Append("John Doe").Build() + }, 1); + + var location = new FlightLocation("grpc+tcp://localhost:50051"); + + var flightHolder = new FlightHolder(descriptor, schema, location.Uri); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + Flights.Add(descriptor, flightHolder); + } + + public override string ToString() + { + return $"InMemoryFlightStore(Flights={Flights.Count})"; + } +} \ No newline at end of file