diff --git a/src/BaristaLabs.ChromeDevTools.RemoteInterface/CodeGen/ProtocolGenerator.cs b/src/BaristaLabs.ChromeDevTools.RemoteInterface/CodeGen/ProtocolGenerator.cs index 0edf571..60fc6f8 100644 --- a/src/BaristaLabs.ChromeDevTools.RemoteInterface/CodeGen/ProtocolGenerator.cs +++ b/src/BaristaLabs.ChromeDevTools.RemoteInterface/CodeGen/ProtocolGenerator.cs @@ -131,9 +131,11 @@ private Dictionary GetTypesInDomain(ICollection GetTypesInDomain(ICollection m_logger; - private readonly ConcurrentDictionary>> m_eventHandlers = new ConcurrentDictionary>>(); + private readonly ConcurrentDictionary>> m_eventHandlers = new ConcurrentDictionary>>(); private readonly ConcurrentDictionary m_eventTypeMap = new ConcurrentDictionary(); private ActionBlock m_messageQueue; @@ -94,13 +94,13 @@ /// /// /// - public async Task> SendCommand(TCommand command, CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) + public async Task> SendCommand(TCommand command, string sessionId = "", CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) where TCommand : ICommand { if (command == null) throw new ArgumentNullException(nameof(command)); - var result = await SendCommand(command.CommandName, JToken.FromObject(command), cancellationToken, millisecondsTimeout, throwExceptionIfResponseNotReceived); + var result = await SendCommand(command.CommandName, JToken.FromObject(command), sessionId, cancellationToken, millisecondsTimeout, throwExceptionIfResponseNotReceived); if (result == null) return null; @@ -121,14 +121,14 @@ /// /// /// - public async Task SendCommand(TCommand command, CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) + public async Task SendCommand(TCommand command, string sessionId = "", CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) where TCommand : ICommand where TCommandResponse : ICommandResponse { if (command == null) throw new ArgumentNullException(nameof(command)); - var result = await SendCommand(command.CommandName, JToken.FromObject(command), cancellationToken, millisecondsTimeout, throwExceptionIfResponseNotReceived); + var result = await SendCommand(command.CommandName, JToken.FromObject(command), sessionId, cancellationToken, millisecondsTimeout, throwExceptionIfResponseNotReceived); if (result == null) return default(TCommandResponse); @@ -146,10 +146,13 @@ /// /// [DebuggerStepThrough] - public async Task SendCommand(string commandName, JToken @params, CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) + public async Task SendCommand(string commandName, JToken @params, string sessionId = "", CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) { + sessionId = sessionId ?? ""; + var message = new { + sessionId, id = Interlocked.Increment(ref m_currentCommandId), method = commandName, @params = @params @@ -160,7 +163,7 @@ await OpenSessionConnection(cancellationToken); - LogTrace("Sending {id} {method}: {params}", message.id, message.method, @params.ToString()); + LogTrace("Sending: \nSessionId: {sessionId} \nID: {id} \nCommand Name: {method} \nCommand Params: {params}", sessionId, message.id, message.method, @params.ToString()); var contents = JsonConvert.SerializeObject(message); @@ -181,7 +184,7 @@ if (!String.IsNullOrWhiteSpace(errorData)) exceptionMessage = $"{exceptionMessage} - {errorData}"; - LogTrace("Recieved Error Response {id}: {message} {data}", message.id, message, errorData); + LogTrace("Recieved Error Response: \nID: {id} \nRequest Message: {message} \nError Data: {data} \nException Message: {exceptionMessage}", message.id, message, errorData, exceptionMessage); throw new CommandResponseException(exceptionMessage) { Code = m_lastResponse.Result.Value("code") @@ -195,9 +198,11 @@ /// /// Event to subscribe to /// - public void Subscribe(Action eventCallback) + public void Subscribe(Action eventCallback, string sessionId = "") where TEvent : IEvent { + sessionId = sessionId ?? ""; + if (eventCallback == null) throw new ArgumentNullException(nameof(eventCallback)); @@ -212,7 +217,14 @@ }); var callbackWrapper = new Action(obj => eventCallback((TEvent)obj)); - m_eventHandlers.AddOrUpdate(eventName, + + var keyObj = new + { + SessionId = sessionId, + EventName = eventName + }; + + m_eventHandlers.AddOrUpdate(keyObj, (m) => new ConcurrentBag>(new[] { callbackWrapper }), (m, currentBag) => { @@ -225,15 +237,24 @@ { if (m_sessionSocket.State != WebSocketState.Open) { + m_openEvent.Reset(); m_sessionSocket.Open(); await Task.Run(() => m_openEvent.Wait(cancellationToken)); } } - private void RaiseEvent(string methodName, JToken eventData) + private void RaiseEvent(string methodName, JToken eventData, string sessionId = "") { - if (m_eventHandlers.TryGetValue(methodName, out ConcurrentBag> bag)) + sessionId = sessionId ?? ""; + + var keyObj = new + { + SessionId = sessionId, + EventName = methodName + }; + + if (m_eventHandlers.TryGetValue(keyObj, out ConcurrentBag> bag)) { if (!EventTypeMap.TryGetTypeForMethodName(methodName, out Type eventType)) throw new InvalidOperationException($"Unknown {methodName} does not correspond to a known event type."); @@ -279,8 +300,11 @@ { var method = methodProperty.Value(); var eventData = messageObject["params"]; - LogTrace("Recieved Event {method}: {params}", method, eventData.ToString()); - RaiseEvent(method, eventData); + var sessionId = messageObject.TryGetValue("sessionId", out JToken sessionIdProperty) + ? sessionIdProperty.Value() + : ""; + LogTrace("Recieved Event: \nEvent Name: {method} \nParams: {params} \nSessionId: {sessionId}", method, eventData.ToString(), sessionId); + RaiseEvent(method, eventData, sessionId); return; } diff --git a/src/ChromeDevToolsGeneratorCLI/Templates/domain.hbs b/src/ChromeDevToolsGeneratorCLI/Templates/domain.hbs index 0c0b9ff..3d2d53e 100644 --- a/src/ChromeDevToolsGeneratorCLI/Templates/domain.hbs +++ b/src/ChromeDevToolsGeneratorCLI/Templates/domain.hbs @@ -28,9 +28,9 @@ /// /// {{xml-code-comment Description 2}} /// - public async Task<{{dehumanize Name}}CommandResponse> {{dehumanize Name}}({{dehumanize Name}}Command command{{#if NoParameters}} = null{{/if}}, CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) + public async Task<{{dehumanize Name}}CommandResponse> {{dehumanize Name}}({{dehumanize Name}}Command command{{#if NoParameters}} = null{{/if}}, string sessionId = null, CancellationToken cancellationToken = default(CancellationToken), int? millisecondsTimeout = null, bool throwExceptionIfResponseNotReceived = true) { - return await m_session.SendCommand<{{dehumanize Name}}Command, {{dehumanize Name}}CommandResponse>(command{{#if NoParameters}} ?? new {{dehumanize Name}}Command(){{/if}}, cancellationToken, millisecondsTimeout, throwExceptionIfResponseNotReceived); + return await m_session.SendCommand<{{dehumanize Name}}Command, {{dehumanize Name}}CommandResponse>(command{{#if NoParameters}} ?? new {{dehumanize Name}}Command(){{/if}}, sessionId, cancellationToken, millisecondsTimeout, throwExceptionIfResponseNotReceived); } {{/each}} @@ -38,9 +38,9 @@ /// /// {{xml-code-comment Description 2}} /// - public void SubscribeTo{{dehumanize Name}}Event(Action<{{dehumanize Name}}Event> eventCallback) + public void SubscribeTo{{dehumanize Name}}Event(Action<{{dehumanize Name}}Event> eventCallback, string sessionId = "") { - m_session.Subscribe(eventCallback); + m_session.Subscribe(eventCallback, sessionId); } {{/each}} }