diff --git a/PingPong.cs b/PingPong.cs index c57ca0d..1364dd4 100644 --- a/PingPong.cs +++ b/PingPong.cs @@ -9,86 +9,40 @@ namespace net.vieapps.Components.WebSockets { - /// - /// Pong EventArgs - /// - internal class PongEventArgs : EventArgs - { - /// - /// The data extracted from a Pong WebSocket frame - /// - public ArraySegment Payload { get; private set; } - - /// - /// Initialises a new instance of the PongEventArgs class - /// - /// The pong payload must be 125 bytes or less (can be zero bytes) - public PongEventArgs(ArraySegment payload) => this.Payload = payload; - } - - // -------------------------------------------------- - - /// - /// Ping Pong Manager used to facilitate ping pong WebSocket messages - /// - internal interface IPingPongManager - { - /// - /// Raised when a Pong frame is received - /// - event EventHandler Pong; - - /// - /// Sends a ping frame - /// - /// The payload (must be 125 bytes of less) - /// The cancellation token - Task SendPingAsync(ArraySegment payload, CancellationToken cancellation = default(CancellationToken)); - } - - // -------------------------------------------------- - - /// - /// Ping Pong Manager used to facilitate ping pong WebSocket messages - /// - internal class PingPongManager : IPingPongManager + internal class PingPongManager { readonly WebSocketImplementation _websocket; - readonly Task _pingTask; readonly CancellationToken _cancellationToken; - readonly Stopwatch _stopwatch; - long _pingSentTicks; - - /// - /// Raised when a Pong frame is received - /// - public event EventHandler Pong; + readonly Action _onPong; + readonly Func _getPongPayload; + readonly Func _getPingPayload; + long _pingTimestamp = 0; - /// - /// Initialises a new instance of the PingPongManager to facilitate ping pong WebSocket messages. - /// - /// The WebSocket instance used to listen to ping messages and send pong messages - /// The token used to cancel a pending ping send AND the automatic sending of ping messages if KeepAliveInterval is positive - public PingPongManager(WebSocketImplementation websocket, CancellationToken cancellationToken) + public PingPongManager(WebSocketImplementation websocket, WebSocketOptions options, CancellationToken cancellationToken) { this._websocket = websocket; - this._websocket.Pong += this.DoPong; this._cancellationToken = cancellationToken; - this._stopwatch = Stopwatch.StartNew(); - this._pingTask = Task.Run(this.DoPingAsync); + this._getPongPayload = options.GetPongPayload; + this._onPong = options.OnPong; + if (this._websocket.KeepAliveInterval != TimeSpan.Zero) + { + this._getPingPayload = options.GetPingPayload; + Task.Run(this.SendPingAsync).ConfigureAwait(false); + } + } + + public void OnPong(byte[] pong) + { + this._pingTimestamp = 0; + this._onPong?.Invoke(this._websocket, pong); } - /// - /// Sends a ping frame - /// - /// The payload (must be 125 bytes of less) - /// The cancellation token - public Task SendPingAsync(ArraySegment payload, CancellationToken cancellationToken = default(CancellationToken)) - => this._websocket.SendPingAsync(payload, cancellationToken); + public Task SendPongAsync(byte[] ping) + => this._websocket.SendPongAsync((this._getPongPayload?.Invoke(this._websocket, ping) ?? ping).ToArraySegment(), this._cancellationToken); - async Task DoPingAsync() + public async Task SendPingAsync() { - Events.Log.PingPongManagerStarted(this._websocket.ID, (int)this._websocket.KeepAliveInterval.TotalSeconds); + Events.Log.PingPongManagerStarted(this._websocket.ID, this._websocket.KeepAliveInterval.TotalSeconds.CastAs()); try { while (!this._cancellationToken.IsCancellationRequested) @@ -97,30 +51,19 @@ async Task DoPingAsync() if (this._websocket.State != WebSocketState.Open) break; - if (this._pingSentTicks != 0) + if (this._pingTimestamp != 0) { Events.Log.KeepAliveIntervalExpired(this._websocket.ID, (int)this._websocket.KeepAliveInterval.TotalSeconds); - await this._websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, $"No Pong message received in response to a Ping after KeepAliveInterval ({this._websocket.KeepAliveInterval})", this._cancellationToken).ConfigureAwait(false); + await this._websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, $"No PONG message received in response to a PING message after keep-alive interval ({this._websocket.KeepAliveInterval})", this._cancellationToken).ConfigureAwait(false); break; } - this._pingSentTicks = this._stopwatch.Elapsed.Ticks; - await this.SendPingAsync(this._pingSentTicks.ToArraySegment(), this._cancellationToken).ConfigureAwait(false); + this._pingTimestamp = DateTime.Now.ToUnixTimestamp(); + await this._websocket.SendPingAsync((this._getPingPayload?.Invoke(this._websocket) ?? this._pingTimestamp.ToBytes()).ToArraySegment(), this._cancellationToken).ConfigureAwait(false); } } - catch (OperationCanceledException) - { - // normal, do nothing - } + catch { } Events.Log.PingPongManagerEnded(this._websocket.ID); } - - protected virtual void OnPong(PongEventArgs args) => this.Pong?.Invoke(this, args); - - void DoPong(object sender, PongEventArgs arg) - { - this._pingSentTicks = 0; - this.OnPong(arg); - } } } \ No newline at end of file diff --git a/README.md b/README.md index 21933dd..3bb5741 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,9 @@ public EndPoint LocalEndPoint { get; } // Extra information public Dictionary Extra { get; } + +// Headers information +public Dictionary Headers { get; } ``` ## Fly on the sky with Event-liked driven @@ -117,8 +120,8 @@ var websocket = new WebSocket And this class has some methods for working on both side of client and server role: ```csharp -void Connect(Uri uri, WebSocketOptions options, Action onSuccess, Action onFailed); -void StartListen(int port, X509Certificate2 certificate, Action onSuccess, Action onFailed); +void Connect(Uri uri, WebSocketOptions options, Action onSuccess, Action onFailure); +void StartListen(int port, X509Certificate2 certificate, Action onSuccess, Action onFailure, Func getPingPayload, Func getPongPayload, Action onPong); void StopListen(); ``` @@ -149,7 +152,7 @@ websocket.StartListen(); Want to have a free SSL certificate? Take a look at [Let's Encrypt](https://letsencrypt.org/). -Special: A simple tool named [lets-encrypt-win-simple](https://github.com/PKISharp/win-acme) will help your IIS works with Let's Encrypt very well. +Special: A simple tool named [win-acme](https://github.com/PKISharp/win-acme) will help your IIS works with Let's Encrypt very well. ### SubProtocol Negotiation @@ -185,7 +188,7 @@ When integrate this component with your app that hosted by ASP.NET / ASP.NET Cor then the method **WrapAsync** is here to help. This method will return a task that run a process for receiving messages from this WebSocket connection. ```csharp -Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, EndPoint remoteEndPoint, EndPoint localEndPoint, string userAgent, string urlReferrer, string headers, string cookies, Action onSuccess); +Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, EndPoint remoteEndPoint, EndPoint localEndPoint, Dictionary headers, Action onSuccess); ``` And might be you need an extension method to wrap an existing WebSocket connection, then take a look at some lines of code below: @@ -316,7 +319,7 @@ bool CloseWebSocket(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus Our prefers: - [Microsoft.Extensions.Logging.Console](https://www.nuget.org/packages/Microsoft.Extensions.Logging.Console): live logs -- [Serilog.Extensions.Logging.File](https://www.nuget.org/packages/Serilog.Extensions.Logging.File): rolling log files (by date) - high performance, and very simple to use +- [Serilog.Extensions.Logging.File](https://www.nuget.org/packages/Serilog.Extensions.Logging.File): rolling log files (by hour or date) - high performance, and very simple to use ### Namespaces diff --git a/VIEApps.Components.WebSockets.csproj b/VIEApps.Components.WebSockets.csproj index 2e44917..5c457ee 100644 --- a/VIEApps.Components.WebSockets.csproj +++ b/VIEApps.Components.WebSockets.csproj @@ -9,14 +9,14 @@ VIEApps NGX WebSockets false ../VIEApps.Components.snk - 10.2.1903.1 - 10.2.1903.1 - v10.2.netstandard-2+rev:2019.03.01-latest.components-ignore.cert.errors + 10.2.1906.2 + 10.2.1906.2 + v10.2.netstandard-2+rev:2019.06.12-async.close VIEApps NGX VIEApps.net false VIEApps.Components.WebSockets - 10.2.1903.1 + 10.2.1906.2 VIEApps NGX WebSockets High performance WebSocket on .NET Standard 2.0 (both server and client - standalone or wrapper of System.Net.WebSockets.WebSocket) VIEApps.net @@ -24,9 +24,9 @@ LICENSE.md ..\ websocket;websockets;websocket-client;websocket-server;websocket-wrapper;vieapps;vieapps.components - Upgrade to latest components, add options to ignore remote certificate errors (support for client certificates) + Improvement: add async close methods + https://vieapps.net/ https://github.com/vieapps/Components.Utility/raw/master/logo.png - https://github.com/vieapps/Components.WebSockets https://github.com/vieapps/Components.WebSockets @@ -39,7 +39,7 @@ - + \ No newline at end of file diff --git a/WebSocket.cs b/WebSocket.cs index d051288..0cc5935 100644 --- a/WebSocket.cs +++ b/WebSocket.cs @@ -40,32 +40,27 @@ public class WebSocket : IDisposable bool _disposing = false, _disposed = false; /// - /// Gets the listening port of the listener + /// Gets or Sets the SSL certificate for securing connections (server) /// - public int Port { get; private set; } = 46429; + public X509Certificate2 Certificate { get; set; } /// - /// Gets or sets the SSL certificate for securing connections + /// Gets or Sets the SSL protocol for securing connections with SSL Certificate (server) /// - public X509Certificate2 Certificate { get; set; } = null; + public SslProtocols SslProtocol { get; set; } = SslProtocols.Tls12; /// - /// Gets or sets the SSL protocol for securing connections with SSL Certificate - /// - public SslProtocols SslProtocol { get; set; } = SslProtocols.Tls; - - /// - /// Gets or sets the collection of supported sub-protocol + /// Gets or Sets the collection of supported sub-protocol (server) /// public IEnumerable SupportedSubProtocols { get; set; } = new string[0]; /// - /// Gets or sets keep-alive interval (seconds) for sending ping messages from server + /// Gets or Sets the keep-alive interval for sending ping messages (server) /// public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromSeconds(60); /// - /// Gets or sets a value that specifies whether the listener is disable the Nagle algorithm or not (default is true - means disable for better performance) + /// Gets or Sets a value that specifies whether the listener is disable the Nagle algorithm or not (default is true - means disable for better performance) /// /// /// Set to true to send a message immediately with the least amount of latency (typical usage for chat) @@ -76,9 +71,9 @@ public class WebSocket : IDisposable public bool NoDelay { get; set; } = true; /// - /// Gets or sets await interval (miliseconds) while receiving messages + /// Gets or Sets await interval between two rounds of receiving messages /// - public int AwaitInterval { get; set; } = 0; + public TimeSpan ReceivingAwaitInterval { get; set; } = TimeSpan.Zero; #endregion #region Event Handlers @@ -139,12 +134,20 @@ public Action OnMessageReceive } #endregion + /// + /// Creates new an instance of the centralized WebSocket + /// + /// The cancellation token + public WebSocket(CancellationToken cancellationToken) + : this(null, cancellationToken) { } + /// /// Creates new an instance of the centralized WebSocket /// /// The logger factory /// The cancellation token - public WebSocket(ILoggerFactory loggerFactory, CancellationToken cancellationToken) : this(loggerFactory, null, cancellationToken) { } + public WebSocket(ILoggerFactory loggerFactory, CancellationToken cancellationToken) + : this(loggerFactory, null, cancellationToken) { } /// /// Creates new an instance of the centralized WebSocket @@ -186,7 +189,10 @@ public static string AgentName /// The SSL Certificate to secure connections /// Action to fire when start successful /// Action to fire when failed to start - public void StartListen(int port = 46429, X509Certificate2 certificate = null, Action onSuccess = null, Action onFailure = null) + /// The function to get the custom 'PING' playload to send a 'PING' message + /// The function to get the custom 'PONG' playload to response to a 'PING' message + /// The action to fire when a 'PONG' message has been sent + public void StartListen(int port = 46429, X509Certificate2 certificate = null, Action onSuccess = null, Action onFailure = null, Func getPingPayload = null, Func getPongPayload = null, Action onPong = null) { // check if (this._tcpListener != null) @@ -197,72 +203,93 @@ public void StartListen(int port = 46429, X509Certificate2 certificate = null, A } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } return; } - // listen + // set X.509 certificate + this.Certificate = certificate ?? this.Certificate; + + // open the listener and listen for incoming requests try { // open the listener - this.Port = port > IPEndPoint.MinPort && port < IPEndPoint.MaxPort ? port : 46429; - this.Certificate = certificate ?? this.Certificate; + this._tcpListener = new TcpListener(IPAddress.IPv6Any, port > IPEndPoint.MinPort && port < IPEndPoint.MaxPort ? port : 46429); + this._tcpListener.Server.SetOptions(this.NoDelay, true); + this._tcpListener.Start(512); - this._tcpListener = new TcpListener(IPAddress.Any, this.Port); - this._tcpListener.Server.NoDelay = this.NoDelay; - this._tcpListener.Server.SetKeepAliveInterval(); - this._tcpListener.Start(1024); - - var platform = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) - ? "Windows" - : RuntimeInformation.IsOSPlatform(OSPlatform.Linux) + if (this._logger.IsEnabled(LogLevel.Debug)) + { + var platform = RuntimeInformation.IsOSPlatform(OSPlatform.Linux) ? "Linux" - : "macOS"; - platform += $" ({RuntimeInformation.FrameworkDescription.Trim()}) - SSL: {this.Certificate != null}"; - if (this.Certificate != null) - platform += $" ({this.Certificate.GetNameInfo(X509NameType.DnsName, false)} :: Issued by {this.Certificate.GetNameInfo(X509NameType.DnsName, true)})"; + : RuntimeInformation.IsOSPlatform(OSPlatform.OSX) + ? "macOS" + : "Windows"; + platform += $" {RuntimeInformation.OSArchitecture.ToString().ToLower()} ({RuntimeInformation.FrameworkDescription.Trim()}) - SSL: {this.Certificate != null}"; + if (this.Certificate != null) + platform += $" ({this.Certificate.GetNameInfo(X509NameType.DnsName, false)} :: Issued by {this.Certificate.GetNameInfo(X509NameType.DnsName, true)})"; + this._logger.LogInformation($"The listener is started => {this._tcpListener.Server.LocalEndPoint}\r\nPlatform: {platform}\r\nPowered by {WebSocketHelper.AgentName} v{this.GetType().Assembly.GetVersion()}"); + } - this._logger.LogInformation($"The listener is started (listening port: {this.Port})\r\nPlatform: {platform}\r\nPowered by {WebSocketHelper.AgentName} {this.GetType().Assembly.GetVersion()}"); + // callback when success try { onSuccess?.Invoke(); } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } // listen for incoming connection requests - this.Listen(); + this.Listen(getPingPayload, getPongPayload, onPong); } catch (SocketException ex) { - var message = $"Error occurred while listening on port \"{this.Port}\". Make sure another application is not running and consuming this port."; - this._logger.Log(LogLevel.Debug, LogLevel.Error, message, ex); + var message = $"Error occurred while listening on port \"{(port > IPEndPoint.MinPort && port < IPEndPoint.MaxPort ? port : 46429)}\". Make sure another application is not running and consuming this port."; + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, message, ex); try { onFailure?.Invoke(new ListenerSocketException(message, ex)); } catch (Exception e) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } } catch (Exception ex) { - this._logger.Log(LogLevel.Debug, LogLevel.Error, $"Got an unexpected error while listening: {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Got an unexpected error while listening: {ex.Message}", ex); try { onFailure?.Invoke(ex); } catch (Exception e) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } } } + /// + /// Starts to listen for client requests as a WebSocket server + /// + /// The port for listening + /// Action to fire when start successful + /// Action to fire when failed to start + /// The function to get the custom 'PING' playload to send a 'PING' message + /// The function to get the custom 'PONG' playload to response to a 'PING' message + /// The action to fire when a 'PONG' message has been sent + public void StartListen(int port, Action onSuccess, Action onFailure, Func getPingPayload, Func getPongPayload, Action onPong) + => this.StartListen(port, null, onSuccess, onFailure, getPingPayload, getPongPayload, onPong); + /// /// Starts to listen for client requests as a WebSocket server /// @@ -270,14 +297,24 @@ public void StartListen(int port = 46429, X509Certificate2 certificate = null, A /// Action to fire when start successful /// Action to fire when failed to start public void StartListen(int port, Action onSuccess, Action onFailure) - => this.StartListen(port, null, onSuccess, onFailure); + => this.StartListen(port, onSuccess, onFailure, null, null, null); + + /// + /// Starts to listen for client requests as a WebSocket server + /// + /// The port for listening + /// The function to get the custom 'PING' playload to send a 'PING' message + /// The function to get the custom 'PONG' playload to response to a 'PING' message + /// The action to fire when a 'PONG' message has been sent + public void StartListen(int port, Func getPingPayload, Func getPongPayload, Action onPong) + => this.StartListen(port, null, null, getPingPayload, getPongPayload, onPong); /// /// Starts to listen for client requests as a WebSocket server /// /// The port for listening public void StartListen(int port) - => this.StartListen(port, null, null); + => this.StartListen(port, null, null, null); /// /// Stops listen @@ -305,51 +342,50 @@ public void StopListen(bool cancelPendings = true) } } - Task Listen() + Task Listen(Func getPingPayload, Func getPongPayload, Action onPong) { this._listeningCTS = CancellationTokenSource.CreateLinkedTokenSource(this._processingCTS.Token); - return this.ListenAsync(); + return this.ListenAsync(getPingPayload, getPongPayload, onPong); } - async Task ListenAsync() + async Task ListenAsync(Func getPingPayload, Func getPongPayload, Action onPong) { try { while (!this._listeningCTS.IsCancellationRequested) - { - var tcpClient = await this._tcpListener.AcceptTcpClientAsync().WithCancellationToken(this._listeningCTS.Token).ConfigureAwait(false); - tcpClient.Client.SetKeepAliveInterval(); - this.AcceptClient(tcpClient); - } + this.AcceptClient(await this._tcpListener.AcceptTcpClientAsync().WithCancellationToken(this._listeningCTS.Token).ConfigureAwait(false), getPingPayload, getPongPayload, onPong); } catch (Exception ex) { this.StopListen(false); if (ex is OperationCanceledException || ex is TaskCanceledException || ex is ObjectDisposedException || ex is SocketException || ex is IOException) - this._logger.LogInformation($"The listener is stopped {(this._logger.IsEnabled(LogLevel.Debug) ? $"({ex.GetType()})" : "")}"); + this._logger.LogDebug($"The listener is stopped {(this._logger.IsEnabled(LogLevel.Debug) ? $"({ex.GetType()})" : "")}"); else this._logger.LogError($"The listener is stopped ({ex.Message})", ex); } } - void AcceptClient(TcpClient tcpClient) - => Task.Run(() => this.AcceptClientAsync(tcpClient)).ConfigureAwait(false); + void AcceptClient(TcpClient tcpClient, Func getPingPayload, Func getPongPayload, Action onPong) + => Task.Run(() => this.AcceptClientAsync(tcpClient, getPingPayload, getPongPayload, onPong)).ConfigureAwait(false); - async Task AcceptClientAsync(TcpClient tcpClient) + async Task AcceptClientAsync(TcpClient tcpClient, Func getPingPayload, Func getPongPayload, Action onPong) { ManagedWebSocket websocket = null; try { - var id = Guid.NewGuid(); - var endpoint = tcpClient.Client.RemoteEndPoint; + // set optins + tcpClient.Client.SetOptions(this.NoDelay); // get stream + var id = Guid.NewGuid(); + var endpoint = tcpClient.Client.RemoteEndPoint; Stream stream = null; if (this.Certificate != null) try { Events.Log.AttemptingToSecureConnection(id); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Attempting to secure the connection ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Attempting to secure the connection ({id} @ {endpoint})"); stream = new SslStream(tcpClient.GetStream(), false); await (stream as SslStream).AuthenticateAsServerAsync( @@ -360,7 +396,8 @@ async Task AcceptClientAsync(TcpClient tcpClient) ).WithCancellationToken(this._listeningCTS.Token).ConfigureAwait(false); Events.Log.ConnectionSecured(id); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"The connection successfully secured ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The connection successfully secured ({id} @ {endpoint})"); } catch (OperationCanceledException) { @@ -377,15 +414,18 @@ async Task AcceptClientAsync(TcpClient tcpClient) else { Events.Log.ConnectionNotSecured(id); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Use insecured connection ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Use insecured connection ({id} @ {endpoint})"); stream = tcpClient.GetStream(); } // parse request - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"The connection is opened, then parse the request ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The connection is opened, then parse the request ({id} @ {endpoint})"); var header = await stream.ReadHeaderAsync(this._listeningCTS.Token).ConfigureAwait(false); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Handshake request ({id} @ {endpoint}) => \r\n{header.Trim()}"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Handshake request ({id} @ {endpoint}) => \r\n{header.Trim()}"); var isWebSocketRequest = false; var path = string.Empty; @@ -400,32 +440,34 @@ async Task AcceptClientAsync(TcpClient tcpClient) // verify request if (!isWebSocketRequest) { - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"The request contains no WebSocket upgrade request, then ignore ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The request contains no WebSocket upgrade request, then ignore ({id} @ {endpoint})"); stream.Close(); tcpClient.Close(); return; } // accept the request + Events.Log.AcceptWebSocketStarted(id); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The request has requested an upgrade to WebSocket protocol, negotiating WebSocket handshake ({id} @ {endpoint})"); + var options = new WebSocketOptions { - KeepAliveInterval = this.KeepAliveInterval + KeepAliveInterval = this.KeepAliveInterval.Ticks < 0 ? TimeSpan.FromSeconds(60) : this.KeepAliveInterval, + GetPingPayload = getPingPayload, + GetPongPayload = getPongPayload, + OnPong = onPong }; - Events.Log.AcceptWebSocketStarted(id); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"The request has requested an upgrade to WebSocket protocol, negotiating WebSocket handshake ({id} @ {endpoint})"); try { // check the version (support version 13 and above) match = new Regex("Sec-WebSocket-Version: (.*)").Match(header); - if (match.Success) - { - var secWebSocketVersion = match.Groups[1].Value.Trim().CastAs(); - if (secWebSocketVersion < 13) - throw new VersionNotSupportedException($"WebSocket Version {secWebSocketVersion} is not supported, must be 13 or above"); - } - else + if (!match.Success || !Int32.TryParse(match.Groups[1].Value, out int version)) throw new VersionNotSupportedException("Unable to find \"Sec-WebSocket-Version\" in the upgrade request"); + else if (version < 13) + throw new VersionNotSupportedException($"WebSocket Version {version} is not supported, must be 13 or above"); // get the request key match = new Regex("Sec-WebSocket-Key: (.*)").Match(header); @@ -436,7 +478,7 @@ async Task AcceptClientAsync(TcpClient tcpClient) // negotiate subprotocol match = new Regex("Sec-WebSocket-Protocol: (.*)").Match(header); options.SubProtocol = match.Success - ? match.Groups[1].Value.Trim().Split(new[] { ',', ' ' }, StringSplitOptions.RemoveEmptyEntries).NegotiateSubProtocol(this.SupportedSubProtocols) + ? match.Groups[1].Value?.Trim().Split(new[] { ',', ' ' }, StringSplitOptions.RemoveEmptyEntries).NegotiateSubProtocol(this.SupportedSubProtocols) : null; // handshake @@ -454,7 +496,8 @@ async Task AcceptClientAsync(TcpClient tcpClient) Events.Log.SendingHandshake(id, handshake); await stream.WriteHeaderAsync(handshake, this._listeningCTS.Token).ConfigureAwait(false); Events.Log.HandshakeSent(id, handshake); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Handshake response ({id} @ {endpoint}) => \r\n{handshake.Trim()}"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Handshake response ({id} @ {endpoint}) => \r\n{handshake.Trim()}"); } catch (VersionNotSupportedException ex) { @@ -470,7 +513,8 @@ async Task AcceptClientAsync(TcpClient tcpClient) } Events.Log.ServerHandshakeSuccess(id); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"WebSocket handshake response has been sent, the stream is ready ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"WebSocket handshake response has been sent, the stream is ready ({id} @ {endpoint})"); // update the connected WebSocket connection match = new Regex("Sec-WebSocket-Extensions: (.*)").Match(header); @@ -483,26 +527,11 @@ async Task AcceptClientAsync(TcpClient tcpClient) ? match.Groups[1].Value.Trim() : string.Empty; - websocket = new WebSocketImplementation(id, false, this._recycledStreamFactory, stream, options, new Uri($"ws{(this.Certificate != null ? "s" : "")}://{host}{path}"), endpoint, tcpClient.Client.LocalEndPoint); - - match = new Regex("User-Agent: (.*)").Match(header); - websocket.Extra["User-Agent"] = match.Success - ? match.Groups[1].Value.Trim() - : string.Empty; - - match = new Regex("Referer: (.*)").Match(header); - if (match.Success) - websocket.Extra["Referer"] = match.Groups[1].Value.Trim(); - else - { - match = new Regex("Origin: (.*)").Match(header); - websocket.Extra["Referer"] = match.Success - ? match.Groups[1].Value.Trim() - : string.Empty; - } - // add into the collection + websocket = new WebSocketImplementation(id, false, this._recycledStreamFactory, stream, options, new Uri($"ws{(this.Certificate != null ? "s" : "")}://{host}{path}"), endpoint, tcpClient.Client.LocalEndPoint, header.ToDictionary()); await this.AddWebSocketAsync(websocket).ConfigureAwait(false); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The server WebSocket connection was successfully established ({websocket.ID} @ {websocket.RemoteEndPoint})\r\n- URI: {websocket.RequestUri}\r\n- Headers:\r\n\t{websocket.Headers.ToString("\r\n\t", kvp => $"{kvp.Key}: {kvp.Value}")}"); // callback try @@ -511,7 +540,8 @@ async Task AcceptClientAsync(TcpClient tcpClient) } catch (Exception e) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } // receive messages @@ -525,14 +555,16 @@ async Task AcceptClientAsync(TcpClient tcpClient) } else { - this._logger.Log(LogLevel.Debug, LogLevel.Error, $"Error occurred while accepting an incoming connection request: {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while accepting an incoming connection request: {ex.Message}", ex); try { this.ErrorHandler?.Invoke(websocket, ex); } catch (Exception e) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } } } @@ -542,16 +574,16 @@ async Task AcceptClientAsync(TcpClient tcpClient) #region Connect to remote endpoints as client async Task ConnectAsync(Uri uri, WebSocketOptions options, Action onSuccess = null, Action onFailure = null) { - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Attempting to connect ({uri})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Attempting to connect ({uri})"); + try { // connect the TCP client var id = Guid.NewGuid(); - var tcpClient = new TcpClient - { - NoDelay = options.NoDelay - }; - tcpClient.Client.SetKeepAliveInterval(); + + var tcpClient = new TcpClient(); + tcpClient.Client.SetOptions(options.NoDelay); if (IPAddress.TryParse(uri.Host, out IPAddress ipAddress)) { @@ -565,7 +597,11 @@ async Task ConnectAsync(Uri uri, WebSocketOptions options, Action options.IgnoreCertificateErrors || RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || sslPolicyErrors == SslPolicyErrors.None, + userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => sslPolicyErrors == SslPolicyErrors.None || options.IgnoreCertificateErrors || RuntimeInformation.IsOSPlatform(OSPlatform.Linux), userCertificateSelectionCallback: (sender, host, certificates, certificate, issuers) => this.Certificate ); - await (stream as SslStream).AuthenticateAsClientAsync(targetHost: uri.Host).WithCancellationToken(this._processingCTS.Token).ConfigureAwait(false); + await (stream as SslStream).AuthenticateAsClientAsync(targetHost: sniHost ?? uri.Host).WithCancellationToken(this._processingCTS.Token).ConfigureAwait(false); Events.Log.ConnectionSecured(id); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"The connection successfully secured ({id} @ {endpoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The connection successfully secured ({id} @ {endpoint})"); } catch (OperationCanceledException) { @@ -601,16 +639,19 @@ async Task ConnectAsync(Uri uri, WebSocketOptions options, Action handshake += $"{kvp.Key}: {kvp.Value}\r\n"); + options.AdditionalHeaders?.Where(x => !x.Key.Equals("Host", StringComparison.OrdinalIgnoreCase)).ForEach(kvp => handshake += $"{kvp.Key}: {kvp.Value}\r\n"); Events.Log.SendingHandshake(id, handshake); await stream.WriteHeaderAsync(handshake, this._processingCTS.Token).ConfigureAwait(false); Events.Log.HandshakeSent(id, handshake); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Handshake request ({id} @ {endpoint}) => \r\n{handshake.Trim()}"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Handshake request ({id} @ {endpoint}) => \r\n{handshake.Trim()}"); // read response Events.Log.ReadingResponse(id); @@ -635,7 +677,8 @@ async Task ConnectAsync(Uri uri, WebSocketOptions options, Action \r\n{response.Trim()}"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Handshake response ({id} @ {endpoint}) => \r\n{response.Trim()}"); } catch (Exception ex) { @@ -650,7 +693,7 @@ async Task ConnectAsync(Uri uri, WebSocketOptions options, Action $"{kvp.Key}: {kvp.Value}")}"); // callback try @@ -697,7 +743,8 @@ async Task ConnectAsync(Uri uri, WebSocketOptions options, Action {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } // receive messages @@ -709,14 +756,16 @@ async Task ConnectAsync(Uri uri, WebSocketOptions options, Action {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } } } @@ -761,7 +810,7 @@ public void Connect(string location, Action onSuccess, Action< => this.Connect(location, null, onSuccess, onFailure); #endregion - #region Wrap a WebSocket connection of ASP.NET / ASP.NET Core + #region Wrap a WebSocket connection /// /// Wraps a WebSocket connection of ASP.NET / ASP.NET Core and acts like a WebSocket server /// @@ -769,33 +818,18 @@ public void Connect(string location, Action onSuccess, Action< /// The original request URI of the WebSocket connection /// The remote endpoint of the WebSocket connection /// The local endpoint of the WebSocket connection - /// The string that presents the user agent of the client that made this request to the WebSocket connection - /// The string that presents the url referer of the client that made this request to the WebSocket connection - /// The string that presents the headers of the client that made this request to the WebSocket connection - /// The string that presents the cookies of the client that made this request to the WebSocket connection + /// The collection that presents the headers of the client that made this request to the WebSocket connection /// The action to fire when the WebSocket connection is wrap success /// A task that run the receiving process when wrap successful or an exception when failed - public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, EndPoint remoteEndPoint = null, EndPoint localEndPoint = null, string userAgent = null, string urlReferer = null, string headers = null, string cookies = null, Action onSuccess = null) + public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, EndPoint remoteEndPoint = null, EndPoint localEndPoint = null, Dictionary headers = null, Action onSuccess = null) { try { // create - var websocket = new WebSocketWrapper(webSocket, requestUri, remoteEndPoint, localEndPoint); - - if (!string.IsNullOrWhiteSpace(userAgent)) - websocket.Extra["User-Agent"] = userAgent; - - if (!string.IsNullOrWhiteSpace(urlReferer)) - websocket.Extra["Referer"] = urlReferer; - - if (!string.IsNullOrWhiteSpace(headers)) - websocket.Extra["Headers"] = headers; - - if (!string.IsNullOrWhiteSpace(cookies)) - websocket.Extra["Cookies"] = cookies; - + var websocket = new WebSocketWrapper(webSocket, requestUri, remoteEndPoint, localEndPoint, headers); this.AddWebSocket(websocket); - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Wrap a WebSocket connection [{webSocket.GetType()}] successful ({websocket.ID} @ {websocket.RemoteEndPoint}"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Wrap a WebSocket connection [{webSocket.GetType()}] successful ({websocket.ID} @ {websocket.RemoteEndPoint})\r\n- URI: {websocket.RequestUri}\r\n- Headers:\r\n\t{websocket.Headers.ToString("\r\n\t", kvp => $"{kvp.Key}: {kvp.Value}")}"); // callback try @@ -805,7 +839,8 @@ public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } // receive messages @@ -813,25 +848,12 @@ public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, } catch (Exception ex) { - this._logger.Log(LogLevel.Debug, LogLevel.Error, $"Unable to wrap a WebSocket connection [{webSocket.GetType()}]: {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Unable to wrap a WebSocket connection [{webSocket.GetType()}]: {ex.Message}", ex); return Task.FromException(new WrapWebSocketFailedException($"Unable to wrap a WebSocket connection [{webSocket.GetType()}]", ex)); } } - /// - /// Wraps a WebSocket connection of ASP.NET / ASP.NET Core and acts like a WebSocket server - /// - /// The WebSocket connection of ASP.NET / ASP.NET Core - /// The original request URI of the WebSocket connection - /// The remote endpoint of the WebSocket connection - /// The local endpoint of the WebSocket connection - /// The string that presents the user agent of the client that made this request to the WebSocket connection - /// The string that presents the url referer of the client that made this request to the WebSocket connection - /// The action to fire when the WebSocket connection is wrap success - /// A task that run the receiving process when wrap successful or an exception when failed - public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, EndPoint remoteEndPoint, EndPoint localEndPoint, string userAgent, string urlReferer, Action onSuccess) - => this.WrapAsync(webSocket, requestUri, remoteEndPoint, localEndPoint, userAgent, urlReferer, null, null, onSuccess); - /// /// Wraps a WebSocket connection of ASP.NET / ASP.NET Core and acts like a WebSocket server /// @@ -840,7 +862,7 @@ public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, /// The remote endpoint of the WebSocket connection /// A task that run the receiving process when wrap successful or an exception when failed public Task WrapAsync(System.Net.WebSockets.WebSocket webSocket, Uri requestUri, EndPoint remoteEndPoint) - => this.WrapAsync(webSocket, requestUri, remoteEndPoint, null, null, null, null, null, null); + => this.WrapAsync(webSocket, requestUri, remoteEndPoint, null, new Dictionary(), null); #endregion #region Receive messages @@ -868,28 +890,34 @@ async Task ReceiveAsync(ManagedWebSocket websocket) closeStatusDescription = websocket.IsClient ? "Disconnected" : "Service is unavailable"; } - this.CloseWebSocket(websocket, closeStatus, closeStatusDescription); + await this.CloseWebSocketAsync(websocket, closeStatus, closeStatusDescription).ConfigureAwait(false); try { this.ConnectionBrokenHandler?.Invoke(websocket); } catch (Exception e) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } if (ex is OperationCanceledException || ex is TaskCanceledException || ex is ObjectDisposedException || ex is WebSocketException || ex is SocketException || ex is IOException) - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"Stop receiving process when got an error: {ex.Message} ({ex.GetType().GetTypeName(true)})"); + { + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Stop receiving process when got an error: {ex.Message} ({ex.GetType().GetTypeName(true)})"); + } else { - this._logger.Log(LogLevel.Debug, LogLevel.Error, closeStatusDescription, ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, closeStatusDescription, ex); try { this.ErrorHandler?.Invoke(websocket, ex); } catch (Exception e) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); } } return; @@ -898,15 +926,17 @@ async Task ReceiveAsync(ManagedWebSocket websocket) // message to close if (result.MessageType == WebSocketMessageType.Close) { - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"The remote endpoint is initiated to close - Status: {result.CloseStatus} - Description: {result.CloseStatusDescription ?? "N/A"} ({websocket.ID} @ {websocket.RemoteEndPoint})"); - this.CloseWebSocket(websocket); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"The remote endpoint is initiated to close - Status: {result.CloseStatus} - Description: {result.CloseStatusDescription ?? "N/A"} ({websocket.ID} @ {websocket.RemoteEndPoint})"); + await this.CloseWebSocketAsync(websocket).ConfigureAwait(false); try { this.ConnectionBrokenHandler?.Invoke(websocket); } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } return; } @@ -915,10 +945,11 @@ async Task ReceiveAsync(ManagedWebSocket websocket) if (result.Count > WebSocketHelper.ReceiveBufferSize) { var message = $"WebSocket frame cannot exceed buffer size of {WebSocketHelper.ReceiveBufferSize:#,##0} bytes"; - this._logger.Log(LogLevel.Debug, LogLevel.Debug, $"Close the connection because {message} ({websocket.ID} @ {websocket.RemoteEndPoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"Close the connection because {message} ({websocket.ID} @ {websocket.RemoteEndPoint})"); await websocket.CloseAsync(WebSocketCloseStatus.MessageTooBig, $"{message}, send multiple frames instead.", CancellationToken.None).ConfigureAwait(false); - this.CloseWebSocket(websocket); + await this.CloseWebSocketAsync(websocket).ConfigureAwait(false); try { this.ConnectionBrokenHandler?.Invoke(websocket); @@ -926,7 +957,8 @@ async Task ReceiveAsync(ManagedWebSocket websocket) } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } return; } @@ -934,33 +966,36 @@ async Task ReceiveAsync(ManagedWebSocket websocket) // got a message if (result.Count > 0) { - this._logger.Log(LogLevel.Trace, LogLevel.Debug, $"A message was received - Type: {result.MessageType} - End of message: {result.EndOfMessage} - Length: {result.Count:#,##0} ({websocket.ID} @ {websocket.RemoteEndPoint})"); + if (this._logger.IsEnabled(LogLevel.Trace)) + this._logger.Log(LogLevel.Debug, $"A message was received - Type: {result.MessageType} - End of message: {result.EndOfMessage} - Length: {result.Count:#,##0} ({websocket.ID} @ {websocket.RemoteEndPoint})"); try { this.MessageReceivedHandler?.Invoke(websocket, result, buffer.Take(result.Count).ToArray()); } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } } // wait for next round - if (this.AwaitInterval > 0) + if (this.ReceivingAwaitInterval.Ticks > 0) try { - await Task.Delay(this.AwaitInterval, this._processingCTS.Token).ConfigureAwait(false); + await Task.Delay(this.ReceivingAwaitInterval, this._processingCTS.Token).ConfigureAwait(false); } catch { - this.CloseWebSocket(websocket, websocket.IsClient ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.EndpointUnavailable, websocket.IsClient ? "Disconnected" : "Service is unavailable"); + await this.CloseWebSocketAsync(websocket, websocket.IsClient ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.EndpointUnavailable, websocket.IsClient ? "Disconnected" : "Service is unavailable").ConfigureAwait(false); try { this.ConnectionBrokenHandler?.Invoke(websocket); } catch (Exception ex) { - this._logger.Log(LogLevel.Information, LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {ex.Message}", ex); } return; } @@ -982,10 +1017,29 @@ async Task ReceiveAsync(ManagedWebSocket websocket) /// true if this message is a standalone message (this is the norm), false if it is a multi-part message (and true for the last message) /// The cancellation token /// - public Task SendAsync(Guid id, ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) - => this._websockets.TryGetValue(id, out ManagedWebSocket websocket) - ? websocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken) - : Task.FromException(new InformationNotFoundException($"WebSocket connection with identity \"{id}\" is not found")); + public async Task SendAsync(Guid id, ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) + { + ManagedWebSocket websocket = null; + try + { + if (this._websockets.TryGetValue(id, out websocket)) + await websocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken).ConfigureAwait(false); + else + throw new InformationNotFoundException($"WebSocket connection with identity \"{id}\" is not found"); + } + catch (Exception ex) + { + try + { + this.ErrorHandler?.Invoke(websocket, ex); + } + catch (Exception e) + { + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + } + } + } /// /// Sends the message to a WebSocket connection @@ -995,10 +1049,29 @@ async Task ReceiveAsync(ManagedWebSocket websocket) /// true if this message is a standalone message (this is the norm), false if it is a multi-part message (and true for the last message) /// The cancellation token /// - public Task SendAsync(Guid id, string message, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) - => this._websockets.TryGetValue(id, out ManagedWebSocket websocket) - ? websocket.SendAsync(message, endOfMessage, cancellationToken) - : Task.FromException(new InformationNotFoundException($"WebSocket connection with identity \"{id}\" is not found")); + public async Task SendAsync(Guid id, string message, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) + { + ManagedWebSocket websocket = null; + try + { + if (this._websockets.TryGetValue(id, out websocket)) + await websocket.SendAsync(message, endOfMessage, cancellationToken).ConfigureAwait(false); + else + throw new InformationNotFoundException($"WebSocket connection with identity \"{id}\" is not found"); + } + catch (Exception ex) + { + try + { + this.ErrorHandler?.Invoke(websocket, ex); + } + catch (Exception e) + { + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + } + } + } /// /// Sends the message to a WebSocket connection @@ -1008,10 +1081,29 @@ async Task ReceiveAsync(ManagedWebSocket websocket) /// true if this message is a standalone message (this is the norm), false if it is a multi-part message (and true for the last message) /// The cancellation token /// - public Task SendAsync(Guid id, byte[] message, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) - => this._websockets.TryGetValue(id, out ManagedWebSocket websocket) - ? websocket.SendAsync(message, endOfMessage, cancellationToken) - : Task.FromException(new InformationNotFoundException($"WebSocket connection with identity \"{id}\" is not found")); + public async Task SendAsync(Guid id, byte[] message, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) + { + ManagedWebSocket websocket = null; + try + { + if (this._websockets.TryGetValue(id, out websocket)) + await websocket.SendAsync(message, endOfMessage, cancellationToken).ConfigureAwait(false); + else + throw new InformationNotFoundException($"WebSocket connection with identity \"{id}\" is not found"); + } + catch (Exception ex) + { + try + { + this.ErrorHandler?.Invoke(websocket, ex); + } + catch (Exception e) + { + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + } + } + } /// /// Sends the message to the WebSocket connections that matched with the predicate @@ -1022,8 +1114,26 @@ async Task ReceiveAsync(ManagedWebSocket websocket) /// true if this message is a standalone message (this is the norm), false if it is a multi-part message (and true for the last message) /// The cancellation token /// - public Task SendAsync(Func predicate, ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) => this.GetWebSockets(predicate).ToList().ForEachAsync((connection, token) - => connection.SendAsync(buffer.Clone(), messageType, endOfMessage, token), cancellationToken); + public Task SendAsync(Func predicate, ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken = default(CancellationToken)) + => this.GetWebSockets(predicate).ForEachAsync(async (websocket, token) => + { + try + { + await websocket.SendAsync(buffer.Clone(), messageType, endOfMessage, token).ConfigureAwait(false); + } + catch (Exception ex) + { + try + { + this.ErrorHandler?.Invoke(websocket, ex); + } + catch (Exception e) + { + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.Log(LogLevel.Error, $"Error occurred while calling the handler => {e.Message}", e); + } + } + }, cancellationToken); /// /// Sends the message to the WebSocket connections that matched with the predicate @@ -1059,7 +1169,7 @@ async Task AddWebSocketAsync(ManagedWebSocket websocket) if (!this.AddWebSocket(websocket)) { if (websocket != null) - await Task.Delay(UtilityService.GetRandomNumber(123, 456)).ConfigureAwait(false); + await Task.Delay(UtilityService.GetRandomNumber(123, 456), this._processingCTS.Token).ConfigureAwait(false); return this.AddWebSocket(websocket); } return true; @@ -1092,15 +1202,28 @@ public IEnumerable GetWebSockets(Func /// The close status to use /// A description of why we are closing /// - bool CloseWebsocket(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus, string closeStatusDescription) + async Task CloseWebsocketAsync(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus, string closeStatusDescription) { if (websocket.State == WebSocketState.Open) - Task.Run(() => websocket.DisposeAsync(closeStatus, closeStatusDescription)).ConfigureAwait(false); + await websocket.DisposeAsync(closeStatus, closeStatusDescription).ConfigureAwait(false); else websocket.Close(); return true; } + /// + /// Closes the WebSocket connection and remove from the centralized collections + /// + /// The WebSocket connection to close + /// The close status to use + /// A description of why we are closing + /// + bool CloseWebsocket(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus, string closeStatusDescription) + { + Task.Run(() => this.CloseWebsocketAsync(websocket, closeStatus, closeStatusDescription)).ConfigureAwait(false); + return true; + } + /// /// Closes the WebSocket connection and remove from the centralized collections /// @@ -1109,10 +1232,22 @@ bool CloseWebsocket(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus /// A description of why we are closing /// true if closed and destroyed public bool CloseWebSocket(Guid id, WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable") - => this._websockets.TryRemove(id, out ManagedWebSocket websocket) + => this._websockets.TryRemove(id, out var websocket) ? this.CloseWebsocket(websocket, closeStatus, closeStatusDescription) : false; + /// + /// Closes the WebSocket connection and remove from the centralized collections + /// + /// The identity of a WebSocket connection to close + /// The close status to use + /// A description of why we are closing + /// true if closed and destroyed + public Task CloseWebSocketAsync(Guid id, WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable") + => this._websockets.TryRemove(id, out var websocket) + ? this.CloseWebsocketAsync(websocket, closeStatus, closeStatusDescription) + : Task.FromResult(false); + /// /// Closes the WebSocket connection and remove from the centralized collections /// @@ -1123,7 +1258,19 @@ public bool CloseWebSocket(Guid id, WebSocketCloseStatus closeStatus = WebSocket public bool CloseWebSocket(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable") => websocket == null ? false - : this.CloseWebsocket(this._websockets.TryRemove(websocket.ID, out ManagedWebSocket webSocket) ? webSocket : websocket, closeStatus, closeStatusDescription); + : this.CloseWebsocket(this._websockets.TryRemove(websocket.ID, out var webSocket) ? webSocket : websocket, closeStatus, closeStatusDescription); + + /// + /// Closes the WebSocket connection and remove from the centralized collections + /// + /// The WebSocket connection to close + /// The close status to use + /// A description of why we are closing + /// true if closed and destroyed + public Task CloseWebSocketAsync(ManagedWebSocket websocket, WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable") + => websocket == null + ? Task.FromResult(false) + : this.CloseWebsocketAsync(this._websockets.TryRemove(websocket.ID, out var webSocket) ? webSocket : websocket, closeStatus, closeStatusDescription); #endregion #region Dispose @@ -1218,6 +1365,10 @@ public abstract class ManagedWebSocket : System.Net.WebSockets.WebSocket /// Gets the state to include the full exception (with stack trace) in the close response when an exception is encountered and the WebSocket connection is closed /// protected abstract bool IncludeExceptionInCloseResponse { get; } + + protected bool IsDisposing { get; set; } = false; + + protected bool IsDisposed { get; set; } = false; #endregion #region Methods @@ -1291,21 +1442,23 @@ await Task.WhenAll( /// /// Cleans up unmanaged resources (will send a close frame if the connection is still open) /// - public override void Dispose() => this.DisposeAsync().Wait(4321); - - protected bool _disposing = false, _disposed = false; + public override void Dispose() + => this.DisposeAsync().GetAwaiter().GetResult(); - internal virtual async Task DisposeAsync(WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable", CancellationToken cancellationToken = default(CancellationToken), Action onCompleted = null) + internal virtual async Task DisposeAsync(WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable", CancellationToken cancellationToken = default(CancellationToken), Action onDisposed = null) { - if (!this._disposing && !this._disposed) + if (!this.IsDisposing && !this.IsDisposed) { - this._disposing = true; + this.IsDisposing = true; Events.Log.WebSocketDispose(this.ID, this.State); - if (this.State == WebSocketState.Open) - await this.CloseOutputTimeoutAsync(closeStatus, closeStatusDescription, null, () => Events.Log.WebSocketDisposeCloseTimeout(this.ID, this.State), ex => Events.Log.WebSocketDisposeError(this.ID, this.State, ex.ToString())).ConfigureAwait(false); - onCompleted?.Invoke(); - this._disposed = true; - this._disposing = false; + await Task.WhenAll(this.State == WebSocketState.Open ? this.CloseOutputTimeoutAsync(closeStatus, closeStatusDescription, null, () => Events.Log.WebSocketDisposeCloseTimeout(this.ID, this.State), ex => Events.Log.WebSocketDisposeError(this.ID, this.State, ex.ToString())) : Task.CompletedTask).ConfigureAwait(false); + try + { + onDisposed?.Invoke(); + } + catch { } + this.IsDisposed = true; + this.IsDisposing = false; } } @@ -1318,5 +1471,54 @@ internal virtual void Close() { } } #endregion + #region Extra information + /// + /// Sets the value of a specified key of the extra information + /// + /// + /// + /// + public void Set(string key, T value) + => this.Extra[key] = value; + + /// + /// Gets the value of a specified key from the extra information + /// + /// + /// + /// + /// + public T Get(string key, T @default = default(T)) + => this.Extra.TryGetValue(key, out object value) && value != null && value is T + ? (T)value + : @default; + + /// + /// Removes the value of a specified key from the extra information + /// + /// + /// + public bool Remove(string key) + => this.Extra.Remove(key); + + /// + /// Removes the value of a specified key from the extra information + /// + /// + /// + /// + /// + public bool Remove(string key, out T value) + { + value = this.Get(key); + return this.Remove(key); + } + + /// + /// Gets the header information of the WebSocket connection + /// + public Dictionary Headers => this.Get("Headers", new Dictionary()); + #endregion + } } \ No newline at end of file diff --git a/WebSocketHelper.cs b/WebSocketHelper.cs index f663cce..b4d81f9 100644 --- a/WebSocketHelper.cs +++ b/WebSocketHelper.cs @@ -44,8 +44,7 @@ public static Func GetRecyclableMemoryStreamFactory() { var buffer = new byte[WebSocketHelper.ReceiveBufferSize]; var offset = 0; - var read = 0; - + int read; do { if (offset >= WebSocketHelper.ReceiveBufferSize) @@ -74,35 +73,39 @@ public static Func GetRecyclableMemoryStreamFactory() public static Task WriteHeaderAsync(this Stream stream, string header, CancellationToken cancellationToken = default(CancellationToken)) => stream.WriteAsync((header.Trim() + "\r\n\r\n").ToArraySegment(), cancellationToken); - /// - /// Computes a WebSocket accept key from a given key - /// - /// The WebSocket request key - /// A WebSocket accept key - public static string ComputeAcceptKey(this string key) + internal static string ComputeAcceptKey(this string key) => (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").GetHash("SHA1").ToBase64(); - /// - /// Negotiates sub-protocol - /// - /// - /// - /// - public static string NegotiateSubProtocol(this IEnumerable requestedSubProtocols, IEnumerable supportedSubProtocols) + internal static string NegotiateSubProtocol(this IEnumerable requestedSubProtocols, IEnumerable supportedSubProtocols) => requestedSubProtocols == null || supportedSubProtocols == null || !requestedSubProtocols.Any() || !supportedSubProtocols.Any() ? null : requestedSubProtocols.Intersect(supportedSubProtocols).FirstOrDefault() ?? throw new SubProtocolNegotiationFailedException("Unable to negotiate a sub-protocol"); - /// - /// Set keep-alive interval to something more reasonable (because the TCP keep-alive default values of Windows are huge ~7200s) - /// - /// - /// - /// - public static void SetKeepAliveInterval(this Socket socket, uint keepaliveInterval = 60000, uint retryInterval = 10000) + internal static void SetOptions(this Socket socket, bool noDelay = true, bool dualMode = false, uint keepaliveInterval = 60000, uint retryInterval = 10000) { + // general options + socket.NoDelay = noDelay; + if (dualMode) + { + socket.DualMode = true; + socket.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.IPv6Only, false); + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, 1); + } + + // specifict options (only avalable when running on Windows) if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) socket.IOControl(IOControlCode.KeepAliveValues, ((uint)1).ToBytes().Concat(keepaliveInterval.ToBytes(), retryInterval.ToBytes()), null); } + + internal static Dictionary ToDictionary(this string @string, Action> onPreCompleted = null) + { + var dictionary = string.IsNullOrWhiteSpace(@string) + ? new Dictionary(StringComparer.OrdinalIgnoreCase) + : @string.Replace("\r", "").ToList("\n") + .Where(header => header.IndexOf(":") > 0) + .ToDictionary(header => header.Left(header.IndexOf(":")).Trim(), header => header.Right(header.Length - header.IndexOf(":") - 1).Trim(), StringComparer.OrdinalIgnoreCase); + onPreCompleted?.Invoke(dictionary); + return dictionary; + } } } \ No newline at end of file diff --git a/WebSocketImplementation.cs b/WebSocketImplementation.cs index 77071f5..6327f4e 100644 --- a/WebSocketImplementation.cs +++ b/WebSocketImplementation.cs @@ -1,11 +1,13 @@ #region Related components using System; using System.Net; +using System.Linq; using System.IO; using System.IO.Compression; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; +using System.Collections.Generic; using System.Collections.Concurrent; using Microsoft.Extensions.Logging; using net.vieapps.Components.WebSockets.Exceptions; @@ -14,23 +16,24 @@ namespace net.vieapps.Components.WebSockets { - internal class WebSocketImplementation : ManagedWebSocket + public class WebSocketImplementation : ManagedWebSocket { #region Properties readonly Func _recycledStreamFactory; readonly Stream _stream; - readonly IPingPongManager _pingpongManager; - WebSocketState _state; - WebSocketMessageType _continuationFrameMessageType = WebSocketMessageType.Binary; - WebSocketCloseStatus? _closeStatus; - string _closeStatusDescription; - bool _isContinuationFrame, _writting = false; + readonly PingPongManager _pingpongManager; + readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1); readonly string _subProtocol; readonly CancellationTokenSource _processingCTS; readonly ConcurrentQueue> _buffers = new ConcurrentQueue>(); - - public event EventHandler Pong; + readonly ILogger _logger; + WebSocketState _state; + WebSocketMessageType _continuationMessageType = WebSocketMessageType.Binary; + WebSocketCloseStatus? _closeStatus; + string _closeStatusDescription; + bool _isContinuationFrame = false; + bool _pending = false; /// /// Gets the state that indicates the reason why the remote endpoint initiated the close handshake @@ -58,52 +61,55 @@ internal class WebSocketImplementation : ManagedWebSocket protected override bool IncludeExceptionInCloseResponse { get; } #endregion - public WebSocketImplementation(Guid id, bool isClient, Func recycledStreamFactory, Stream stream, WebSocketOptions options, Uri requestUri, EndPoint remoteEndPoint, EndPoint localEndPoint) + internal WebSocketImplementation(Guid id, bool isClient, Func recycledStreamFactory, Stream stream, WebSocketOptions options, Uri requestUri, EndPoint remoteEndPoint, EndPoint localEndPoint, Dictionary headers) { - if (options.KeepAliveInterval.Ticks < 0) - throw new ArgumentException("KeepAliveInterval must be Zero or positive", nameof(options)); - this.ID = id; this.IsClient = isClient; this.IncludeExceptionInCloseResponse = options.IncludeExceptionInCloseResponse; - this.KeepAliveInterval = options.KeepAliveInterval; + this.KeepAliveInterval = options.KeepAliveInterval.Ticks < 0 ? TimeSpan.FromSeconds(60) : options.KeepAliveInterval; this.RequestUri = requestUri; this.RemoteEndPoint = remoteEndPoint; this.LocalEndPoint = localEndPoint; + this.Set("Headers", headers); this._recycledStreamFactory = recycledStreamFactory ?? WebSocketHelper.GetRecyclableMemoryStreamFactory(); this._stream = stream; this._state = WebSocketState.Open; this._subProtocol = options.SubProtocol; this._processingCTS = new CancellationTokenSource(); - - if (this.KeepAliveInterval == TimeSpan.Zero) - Events.Log.KeepAliveIntervalZero(this.ID); - else - this._pingpongManager = new PingPongManager(this, this._processingCTS.Token); + this._pingpongManager = new PingPongManager(this, options, this._processingCTS.Token); + this._logger = Logger.CreateLogger(); } /// /// Puts data on the wire /// - /// + /// /// /// - async Task PutOnTheWireAsync(MemoryStream data, CancellationToken cancellationToken) + async Task PutOnTheWireAsync(MemoryStream stream, CancellationToken cancellationToken) { - // add into queue - this._buffers.Enqueue(data.ToArraySegment()); + // check disposed + if (this.IsDisposed) + { + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.LogWarning($"Object disposed => {this.ID}"); + throw new ObjectDisposedException($"WebSocketImplementation => {this.ID}"); + } - // check pending write operations - if (this._writting) + // add into queue and check pending operations + this._buffers.Enqueue(stream.ToArraySegment()); + if (this._pending) { Events.Log.PendingOperations(this.ID); - Logger.Log(LogLevel.Debug, LogLevel.Warning, $"Pending operations => {this._buffers.Count:#,##0} ({this.ID} @ {this.RemoteEndPoint})"); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.LogWarning($"#{Thread.CurrentThread.ManagedThreadId} Pendings => {this._buffers.Count:#,##0} ({this.ID} @ {this.RemoteEndPoint})"); return; } // put data to wire - this._writting = true; + this._pending = true; + await this._lock.WaitAsync(cancellationToken).ConfigureAwait(false); try { while (this._buffers.Count > 0) @@ -116,7 +122,8 @@ async Task PutOnTheWireAsync(MemoryStream data, CancellationToken cancellationTo } finally { - this._writting = false; + this._pending = false; + this._lock.Release(); } } @@ -145,27 +152,27 @@ public override async Task ReceiveAsync(ArraySegment ReceiveAsync(ArraySegment(buffer.Array, buffer.Offset, frame.Count), cts.Token).ConfigureAwait(false); + await this._pingpongManager.SendPongAsync(buffer.Take(frame.Count).ToArray()).ConfigureAwait(false); break; case WebSocketOpCode.Pong: - this.Pong?.Invoke(this, new PongEventArgs(new ArraySegment(buffer.Array, frame.Count, buffer.Offset))); + this._pingpongManager.OnPong(buffer.Take(frame.Count).ToArray()); break; case WebSocketOpCode.Text: // continuation frames will follow, record the message type as text if (!frame.IsFinBitSet) - this._continuationFrameMessageType = WebSocketMessageType.Text; + this._continuationMessageType = WebSocketMessageType.Text; return new WebSocketReceiveResult(frame.Count, WebSocketMessageType.Text, frame.IsFinBitSet); case WebSocketOpCode.Binary: // continuation frames will follow, record the message type as binary if (!frame.IsFinBitSet) - this._continuationFrameMessageType = WebSocketMessageType.Binary; + this._continuationMessageType = WebSocketMessageType.Binary; return new WebSocketReceiveResult(frame.Count, WebSocketMessageType.Binary, frame.IsFinBitSet); case WebSocketOpCode.Continuation: - return new WebSocketReceiveResult(frame.Count, this._continuationFrameMessageType, frame.IsFinBitSet); + return new WebSocketReceiveResult(frame.Count, this._continuationMessageType, frame.IsFinBitSet); default: var ex = new NotSupportedException($"Unknown WebSocket op-code: {frame.OpCode}"); @@ -211,6 +218,8 @@ public override async Task ReceiveAsync(ArraySegment {ex.Message}"); throw ex; } } @@ -252,24 +261,18 @@ async Task RespondToCloseFrameAsync(WebSocketFrame frame return new WebSocketReceiveResult(frame.Count, WebSocketMessageType.Close, frame.IsFinBitSet, frame.CloseStatus, frame.CloseStatusDescription); } - /// - /// Called when a Pong frame is received - /// - /// - protected virtual void OnPong(PongEventArgs args) => this.Pong?.Invoke(this, args); - /// /// Calls this when got ping messages (pong payload must be 125 bytes or less, pong should contain the same payload as the ping) /// /// /// /// - async Task SendPongAsync(ArraySegment payload, CancellationToken cancellationToken) + public async Task SendPongAsync(ArraySegment payload, CancellationToken cancellationToken) { // exceeded max length if (payload.Count > 125) { - var ex = new BufferOverflowException($"Max pong message size is 125 bytes, exceeded: {payload.Count}"); + var ex = new BufferOverflowException($"Max PONG message size is 125 bytes, exceeded: {payload.Count}"); await this.CloseOutputTimeoutAsync(WebSocketCloseStatus.ProtocolError, ex.Message, ex).ConfigureAwait(false); throw ex; } @@ -286,7 +289,7 @@ async Task SendPongAsync(ArraySegment payload, CancellationToken cancellat } catch (Exception ex) { - await this.CloseOutputTimeoutAsync(WebSocketCloseStatus.EndpointUnavailable, "Unable to send Pong response", ex).ConfigureAwait(false); + await this.CloseOutputTimeoutAsync(WebSocketCloseStatus.EndpointUnavailable, "Unable to send PONG response", ex).ConfigureAwait(false); throw; } } @@ -300,7 +303,7 @@ async Task SendPongAsync(ArraySegment payload, CancellationToken cancellat public async Task SendPingAsync(ArraySegment payload, CancellationToken cancellationToken) { if (payload.Count > 125) - throw new BufferOverflowException($"Max ping message size is 125 bytes, exceeded: {payload.Count}"); + throw new BufferOverflowException($"Max PING message size is 125 bytes, exceeded: {payload.Count}"); if (this._state == WebSocketState.Open) using (var stream = this._recycledStreamFactory()) @@ -341,13 +344,14 @@ public override async Task SendAsync(ArraySegment buffer, WebSocketMessage } // send - using (var stream = this._recycledStreamFactory()) - { - stream.Write(opCode, buffer, endOfMessage, this.IsClient); - Events.Log.SendingFrame(this.ID, opCode, endOfMessage, buffer.Count, false); - await this.PutOnTheWireAsync(stream, cancellationToken).ConfigureAwait(false); - this._isContinuationFrame = !endOfMessage; - } + if (this._state == WebSocketState.Open) + using (var stream = this._recycledStreamFactory()) + { + stream.Write(opCode, buffer, endOfMessage, this.IsClient); + Events.Log.SendingFrame(this.ID, opCode, endOfMessage, buffer.Count, false); + await this.PutOnTheWireAsync(stream, cancellationToken).ConfigureAwait(false); + this._isContinuationFrame = !endOfMessage; + } } /// @@ -428,16 +432,25 @@ public override void Abort() this._processingCTS.Cancel(); } - internal override Task DisposeAsync(WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable", CancellationToken cancellationToken = default(CancellationToken), Action onCompleted = null) + internal override Task DisposeAsync(WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable", CancellationToken cancellationToken = default(CancellationToken), Action onDisposed = null) => base.DisposeAsync(closeStatus, closeStatusDescription, cancellationToken, () => { this.Close(); - onCompleted?.Invoke(); + try + { + onDisposed?.Invoke(); + } + catch { } + try + { + this._lock.Dispose(); + } + catch { } }); internal override void Close() { - if (!this._disposing && !this._disposed) + if (!this.IsDisposing && !this.IsDisposed) { this._processingCTS.Cancel(); this._processingCTS.Dispose(); diff --git a/WebSocketOptions.cs b/WebSocketOptions.cs index 9633d70..8c72eff 100644 --- a/WebSocketOptions.cs +++ b/WebSocketOptions.cs @@ -12,26 +12,24 @@ public class WebSocketOptions /// Gets or sets how often to send ping requests to the remote endpoint /// /// - /// This is done to prevent proxy servers from closing your connection - /// The default is TimeSpan.Zero meaning that it is disabled. + /// This is done to prevent proxy servers from closing your connection, the default is TimeSpan.Zero meaning that it is disabled. /// WebSocket servers usually send ping messages so it is not normally necessary for the client to send them (hence the TimeSpan.Zero default) - /// You can manually control ping pong messages using the PingPongManager class. - /// If you do that it is advisible to set this KeepAliveInterval to zero + /// You can manually control ping pong messages using the PingPongManager class. If you do that it is advisible to set this KeepAliveInterval to zero. /// public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.Zero; /// - /// Gets or sets the sub-protocol (Sec-WebSocket-Protocol) + /// Gets or Sets the sub-protocol (Sec-WebSocket-Protocol) /// public string SubProtocol { get; set; } /// - /// Gets or sets the extensions (Sec-WebSocket-Extensions) + /// Gets or Sets the extensions (Sec-WebSocket-Extensions) /// public string Extensions { get; set; } /// - /// Gets or sets state to send a message immediately or not + /// Gets or Sets state to send a message immediately or not /// /// /// Set to true to send a message immediately with the least amount of latency (typical usage for chat) @@ -42,12 +40,12 @@ public class WebSocketOptions public bool NoDelay { get; set; } = true; /// - /// Gets or sets the additional headers + /// Gets or Sets the additional headers /// public Dictionary AdditionalHeaders { get; set; } = new Dictionary(StringComparer.OrdinalIgnoreCase); /// - /// Gets or sets the state to include the full exception (with stack trace) in the close response when an exception is encountered and the WebSocket connection is closed + /// Gets or Sets the state to include the full exception (with stack trace) in the close response when an exception is encountered and the WebSocket connection is closed /// /// /// The default is false @@ -55,11 +53,26 @@ public class WebSocketOptions public bool IncludeExceptionInCloseResponse { get; set; } = false; /// - /// Gets or sets whether remote certificate errors should be ignored + /// Gets or Sets whether remote certificate errors should be ignored /// /// /// The default is false /// public bool IgnoreCertificateErrors { get; set; } = false; + + /// + /// Gets or Sets the function to prepare the custom 'PING' playload to send a 'PING' message + /// + public Func GetPingPayload { get; set; } + + /// + /// Gets or Sets the function to prepare the custom 'PONG' playload to response to a 'PING' message + /// + public Func GetPongPayload { get; set; } + + /// + /// Gets or Sets the action to fire when a 'PONG' message has been sent + /// + public Action OnPong { get; set; } } } \ No newline at end of file diff --git a/WebSocketWrapper.cs b/WebSocketWrapper.cs index 5369258..17e09bb 100644 --- a/WebSocketWrapper.cs +++ b/WebSocketWrapper.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using System.Collections.Concurrent; +using System.Collections.Generic; using Microsoft.Extensions.Logging; using net.vieapps.Components.Utility; #endregion @@ -17,7 +18,9 @@ internal class WebSocketWrapper : ManagedWebSocket #region Properties readonly System.Net.WebSockets.WebSocket _websocket = null; readonly ConcurrentQueue, WebSocketMessageType, bool>> _buffers = new ConcurrentQueue, WebSocketMessageType, bool>>(); - bool _sending = false; + readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1); + readonly ILogger _logger; + bool _pending = false; /// /// Gets the state that indicates the reason why the remote endpoint initiated the close handshake @@ -45,13 +48,15 @@ internal class WebSocketWrapper : ManagedWebSocket protected override bool IncludeExceptionInCloseResponse { get; } = false; #endregion - public WebSocketWrapper(System.Net.WebSockets.WebSocket websocket, Uri requestUri, EndPoint remoteEndPoint = null, EndPoint localEndPoint = null) + public WebSocketWrapper(System.Net.WebSockets.WebSocket websocket, Uri requestUri, EndPoint remoteEndPoint, EndPoint localEndPoint, Dictionary headers) { this._websocket = websocket; + this._logger = Logger.CreateLogger(); this.ID = Guid.NewGuid(); this.RequestUri = requestUri; this.RemoteEndPoint = remoteEndPoint; this.LocalEndPoint = localEndPoint; + this.Set("Headers", headers); } /// @@ -73,17 +78,27 @@ public override Task ReceiveAsync(ArraySegment buf /// public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - // add into queue and check pending write operations + // check disposed + if (this.IsDisposed) + { + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.LogWarning($"Object disposed => {this.ID}"); + throw new ObjectDisposedException($"WebSocketWrapper => {this.ID}"); + } + + // add into queue and check pending operations this._buffers.Enqueue(new Tuple, WebSocketMessageType, bool>(buffer, messageType, endOfMessage)); - if (this._sending) + if (this._pending) { Events.Log.PendingOperations(this.ID); - Logger.Log(LogLevel.Debug, LogLevel.Warning, $"Pending operations => {this._buffers.Count:#,##0} ({this.ID} @ {this.RemoteEndPoint})"); + if (this._logger.IsEnabled(LogLevel.Debug)) + this._logger.LogWarning($"#{Thread.CurrentThread.ManagedThreadId} Pendings => {this._buffers.Count:#,##0} ({this.ID} @ {this.RemoteEndPoint})"); return; } // put data to wire - this._sending = true; + this._pending = true; + await this._lock.WaitAsync(cancellationToken).ConfigureAwait(false); try { while (this.State == WebSocketState.Open && this._buffers.Count > 0) @@ -96,7 +111,8 @@ public override async Task SendAsync(ArraySegment buffer, WebSocketMessage } finally { - this._sending = false; + this._pending = false; + this._lock.Release(); } } @@ -126,20 +142,26 @@ public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string c public override void Abort() => this._websocket.Abort(); - internal override Task DisposeAsync(WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable", CancellationToken cancellationToken = default(CancellationToken), Action onCompleted = null) + internal override Task DisposeAsync(WebSocketCloseStatus closeStatus = WebSocketCloseStatus.EndpointUnavailable, string closeStatusDescription = "Service is unavailable", CancellationToken cancellationToken = default(CancellationToken), Action onDisposed = null) => base.DisposeAsync(closeStatus, closeStatusDescription, cancellationToken, () => { this.Close(); - onCompleted?.Invoke(); + try + { + onDisposed?.Invoke(); + } + catch { } + try + { + this._lock.Dispose(); + } + catch { } }); internal override void Close() { - if (!this._disposing && !this._disposed) - { - if ("System.Net.WebSockets.ManagedWebSocket".Equals($"{this._websocket.GetType()}")) - this._websocket.Dispose(); - } + if (!this.IsDisposing && !this.IsDisposed && "System.Net.WebSockets.ManagedWebSocket".Equals($"{this._websocket.GetType()}")) + this._websocket.Dispose(); } ~WebSocketWrapper()