diff --git a/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java b/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java index 89107d91e..3a680fb2a 100644 --- a/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java +++ b/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java @@ -31,6 +31,7 @@ import io.ably.lib.types.ClientOptions; import io.ably.lib.types.ConnectionDetails; import io.ably.lib.types.ErrorInfo; +import io.ably.lib.types.Param; import io.ably.lib.types.ProtocolMessage; import io.ably.lib.types.ProtocolSerializer; import io.ably.lib.util.Log; @@ -857,6 +858,23 @@ public void requestState(StateIndication state) { requestState(null, state); } + /** + * Get query params representing the current authentication method and credentials. + */ + Param[] getAuthParams() throws AblyException { + return ably.auth.getAuthParams(); + } + + /** + * Determines if the given WebSocketTransport instance is the currently active transport. + * + * @param transport the WebSocketTransport instance to check against the active transport + * @return true if the provided transport is the currently active transport, false otherwise + */ + boolean isActiveTransport(WebSocketTransport transport) { + return transport == this.transport; + } + private synchronized void requestState(ITransport transport, StateIndication stateIndication) { Log.v(TAG, "requestState(): requesting " + stateIndication.state + "; id = " + connection.id); addAction(new AsynchronousStateChangeAction(transport, stateIndication)); @@ -2002,7 +2020,7 @@ private boolean isFatalError(ErrorInfo err) { private ErrorInfo stateError; private ConnectParams pendingConnect; private boolean suppressRetry; /* for tests only; modified via reflection */ - private ITransport transport; + private volatile ITransport transport; private long suspendTime; public long msgSerial; private long lastActivity; diff --git a/lib/src/main/java/io/ably/lib/transport/WebSocketTransport.java b/lib/src/main/java/io/ably/lib/transport/WebSocketTransport.java index 69ce91a34..6d7c087f0 100644 --- a/lib/src/main/java/io/ably/lib/transport/WebSocketTransport.java +++ b/lib/src/main/java/io/ably/lib/transport/WebSocketTransport.java @@ -52,8 +52,11 @@ public class WebSocketTransport implements ITransport { private ConnectListener connectListener; private WebSocketClient webSocketClient; private final WebSocketEngine webSocketEngine; + private WebSocketHandler webSocketHandler; private boolean activityCheckTurnedOff = false; + private boolean connectHasBeenCalled = false; + /****************** * protected constructor ******************/ @@ -94,22 +97,26 @@ private static WebSocketEngine createWebSocketEngine(TransportParams params) { * ITransport methods ******************/ + /** + * Connect is called once when we create transport; + * after transport is closed, we never call `connect` again + */ @Override public void connect(ConnectListener connectListener) { + ensureConnectCalledOnce(); this.connectListener = connectListener; try { boolean isTls = params.options.tls; String wsScheme = isTls ? "wss://" : "ws://"; wsUri = wsScheme + params.host + ':' + params.port + "/"; - Param[] authParams = connectionManager.ably.auth.getAuthParams(); + Param[] authParams = connectionManager.getAuthParams(); Param[] connectParams = params.getConnectParams(authParams); if (connectParams.length > 0) wsUri = HttpUtils.encodeParams(wsUri, connectParams); Log.d(TAG, "connect(); wsUri = " + wsUri); - synchronized (this) { - webSocketClient = this.webSocketEngine.create(wsUri, new WebSocketHandler(this::receive)); - } + webSocketHandler = new WebSocketHandler(this::receive); + webSocketClient = this.webSocketEngine.create(wsUri, webSocketHandler); webSocketClient.connect(); } catch (AblyException e) { Log.e(TAG, "Unexpected exception attempting connection; wsUri = " + wsUri, e); @@ -120,14 +127,36 @@ public void connect(ConnectListener connectListener) { } } + /** + * `connect()` can't be called more than once + */ + private synchronized void ensureConnectCalledOnce() { + if (connectHasBeenCalled) throw new IllegalStateException("WebSocketTransport is already initialized"); + connectHasBeenCalled = true; + } + @Override public void close() { Log.d(TAG, "close()"); - synchronized (this) { - if (webSocketClient != null) { - webSocketClient.close(); - webSocketClient = null; - } + // Take local snapshots of the shared references. Callback threads (e.g., onClose) + // may concurrently set these fields to null. + // + // Intentionally avoid synchronizing here: + // - The WebSocket library may invoke our WebSocketHandler while holding its own + // internal locks. + // - If close() also acquired a lock on WebSocketTransport, we could invert the + // lock order and create a circular wait (deadlock): close() waits for the WS + // library to release its lock, while the WS library waits for a lock on + // WebSocketTransport. + final WebSocketClient client = webSocketClient; + final WebSocketHandler handler = webSocketHandler; + if (client != null && handler != null) { + // Record activity so the activity timer remains armed. If a graceful close + // stalls, the timer can detect inactivity and force-cancel the socket. + handler.flagActivity(); + client.close(); + } else { + Log.w(TAG, "close() called on uninitialized or already closed transport"); } } @@ -191,6 +220,11 @@ public String toString() { public String getURL() { return wsUri; } + + private boolean isActiveTransport() { + return connectionManager.isActiveTransport(this); + } + //interface to transfer Protocol message from websocket interface WebSocketReceiver { void onMessage(ProtocolMessage protocolMessage) throws AblyException; @@ -217,9 +251,14 @@ class WebSocketHandler implements WebSocketListener { * WsClient private members ***************************/ - private Timer timer = new Timer(); - private TimerTask activityTimerTask = null; - private long lastActivityTime; + private final Timer timer = new Timer(); + private volatile TimerTask activityTimerTask = null; + private volatile long lastActivityTime; + + /** + * Monitor for activity timer events + */ + private final Object activityTimerMonitor = new Object(); WebSocketHandler(WebSocketReceiver receiver) { this.receiver = receiver; @@ -318,66 +357,68 @@ public void onOldJavaVersionDetected(Throwable throwable) { Log.w(TAG, "Error when trying to set SSL parameters, most likely due to an old Java API version", throwable); } - private synchronized void dispose() { - /* dispose timer */ - try { - timer.cancel(); - timer = null; - } catch (IllegalStateException e) { - } + private void dispose() { + timer.cancel(); } - private synchronized void flagActivity() { - lastActivityTime = System.currentTimeMillis(); - connectionManager.setLastActivity(lastActivityTime); - if (activityTimerTask == null && connectionManager.maxIdleInterval != 0 && !activityCheckTurnedOff) { - /* No timer currently running because previously there was no - * maxIdleInterval configured, but now there is a - * maxIdleInterval configured. Call checkActivity so a timer - * gets started. This happens when flagActivity gets called - * just after processing the connect message that configures - * maxIdleInterval. */ - checkActivity(); + private void flagActivity() { + if (isActiveTransport()) { + lastActivityTime = System.currentTimeMillis(); + connectionManager.setLastActivity(lastActivityTime); } + + if (connectionManager.maxIdleInterval == 0) { + Log.v(TAG, "checkActivity: turned off because maxIdleInterval is 0"); + return; + } + + if (activityCheckTurnedOff) { + Log.v(TAG, "checkActivity: turned off for test purpose"); + return; + } + + checkActivity(); } - private synchronized void checkActivity() { + private void checkActivity() { long timeout = getActivityTimeout(); + if (timeout == 0) { Log.v(TAG, "checkActivity: infinite timeout"); return; } - // Check if timer already running - if (activityTimerTask != null) { - return; - } + // prevent going to the synchronized block if the timer is active + if (activityTimerTask != null) return; - // Start the activity timer task - startActivityTimer(timeout + 100); + synchronized (activityTimerMonitor) { + // Check if timer already running + if (activityTimerTask == null) { + // Start the activity timer task + startActivityTimer(timeout + 100); + } + } } - private synchronized void startActivityTimer(long timeout) { - if (activityTimerTask == null) { - schedule((activityTimerTask = new TimerTask() { - public void run() { - try { - onActivityTimerExpiry(); - } catch (Throwable t) { - Log.e(TAG, "Unexpected exception in activity timer handler", t); - } + private void startActivityTimer(long timeout) { + activityTimerTask = new TimerTask() { + public void run() { + try { + onActivityTimerExpiry(); + } catch (Exception exception) { + Log.e(TAG, "Unexpected exception in activity timer handler", exception); + webSocketClient.cancel(ABNORMAL_CLOSE, "Activity timer closed unexpectedly"); } - }), timeout); - } + } + }; + schedule(activityTimerTask, timeout); } - private synchronized void schedule(TimerTask task, long delay) { - if (timer != null) { - try { - timer.schedule(task, delay); - } catch (IllegalStateException ise) { - Log.e(TAG, "Unexpected exception scheduling activity timer", ise); - } + private void schedule(TimerTask task, long delay) { + try { + timer.schedule(task, delay); + } catch (IllegalStateException ise) { + Log.w(TAG, "Timer has already has been canceled", ise); } } @@ -392,7 +433,7 @@ private void onActivityTimerExpiry() { return; } - synchronized (this) { + synchronized (activityTimerMonitor) { activityTimerTask = null; // Otherwise, we've had some activity, restart the timer for the next timeout Log.v(TAG, "onActivityTimerExpiry: ok"); @@ -401,7 +442,7 @@ private void onActivityTimerExpiry() { } private long getActivityTimeout() { - return connectionManager.maxIdleInterval + connectionManager.ably.options.realtimeRequestTimeout; + return connectionManager.maxIdleInterval + params.options.realtimeRequestTimeout; } } diff --git a/lib/src/test/java/io/ably/lib/test/common/Helpers.java b/lib/src/test/java/io/ably/lib/test/common/Helpers.java index 5b0f328c8..724e818af 100644 --- a/lib/src/test/java/io/ably/lib/test/common/Helpers.java +++ b/lib/src/test/java/io/ably/lib/test/common/Helpers.java @@ -969,6 +969,16 @@ public static boolean equalNullableStrings(String one, String two) { return (one == null) ? (two == null) : one.equals(two); } + public static void setPrivateField(Object object, String fieldName, Object value) { + try { + Field connectionStateField = object.getClass().getDeclaredField(fieldName); + connectionStateField.setAccessible(true); + connectionStateField.set(object, value); + } catch (Exception e) { + fail("Failed accessing " + fieldName + " with error " + e); + } + } + public static class RawHttpRequest { public String id; public URL url; diff --git a/lib/src/test/java/io/ably/lib/transport/WebSocketTransportTest.java b/lib/src/test/java/io/ably/lib/transport/WebSocketTransportTest.java new file mode 100644 index 000000000..c9baad554 --- /dev/null +++ b/lib/src/test/java/io/ably/lib/transport/WebSocketTransportTest.java @@ -0,0 +1,153 @@ +package io.ably.lib.transport; + +import io.ably.lib.network.WebSocketClient; +import io.ably.lib.network.WebSocketEngine; +import io.ably.lib.network.WebSocketListener; +import io.ably.lib.test.common.Helpers; +import io.ably.lib.test.util.EmptyPlatformAgentProvider; +import io.ably.lib.transport.ITransport.TransportParams; +import io.ably.lib.types.ClientOptions; +import io.ably.lib.types.Param; +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for WebSocketTransport, specifically testing activity timer behavior + * when WebSocket close operations get stuck or fail to trigger onClose handlers. + */ +public class WebSocketTransportTest { + + private ConnectionManager mockConnectionManager; + + private WebSocketEngine mockEngine; + + private WebSocketTransport transport; + + private WebSocketClient mockWebSocketClient; + + private TransportParams transportParams; + + @Before + public void setUp() throws Exception { + mockConnectionManager = mock(ConnectionManager.class); + mockEngine = mock(WebSocketEngine.class); + mockWebSocketClient = mock(WebSocketClient.class); + when(mockEngine.isPingListenerSupported()).thenReturn(true); + when(mockEngine.create(any(), any())).thenReturn(mockWebSocketClient); + when(mockConnectionManager.getAuthParams()).thenReturn(new Param[]{}); + + mockConnectionManager.maxIdleInterval = 10; + + // Setup transport params + transportParams = new TransportParams(new ClientOptions(), new EmptyPlatformAgentProvider()); + transportParams.host = "realtime.ably.io"; + transportParams.port = 443; + transportParams.options.realtimeRequestTimeout = 10; + } + + private WebSocketTransport createWebSocketTransport() { + WebSocketTransport transport = new WebSocketTransport(transportParams, mockConnectionManager); + Helpers.setPrivateField(transport, "webSocketEngine", mockEngine); + return transport; + } + + @Test + public void throwExceptionsIfConnectCalledTwice() { + final WebSocketTransport transport = createWebSocketTransport(); + ITransport.ConnectListener connectListener = mock(ITransport.ConnectListener.class); + transport.connect(connectListener); + assertThrows(IllegalStateException.class, () -> + transport.connect(connectListener) + ); + } + + @Test + public void shouldCallCancelIfNotClosedGracefully() { + AtomicReference webSocketListenerRef = new AtomicReference<>(); + + when(mockEngine.create(any(), any())).thenAnswer(invocation -> { + webSocketListenerRef.set(invocation.getArgumentAt(1, WebSocketListener.class)); + return mockWebSocketClient; + }); + + doAnswer(invocation -> { + webSocketListenerRef.get().onClose( + invocation.getArgumentAt(0, Integer.class), + invocation.getArgumentAt(1, String.class) + ); + return null; + }).when(mockWebSocketClient).cancel(anyInt(), anyString()); + + final WebSocketTransport transport = createWebSocketTransport(); + ITransport.ConnectListener connectListener = mock(ITransport.ConnectListener.class); + transport.connect(connectListener); + transport.close(); + // check that we tried to close gracefully + verify(mockWebSocketClient).close(); + // check that we closed forcibly at the end + verify(mockWebSocketClient, timeout(1_000)).cancel(eq(1006), anyString()); + // verify that we call listener at the end + verify(connectListener).onTransportUnavailable(eq(transport), any()); + } + + /** + * `onClose` can be called twice, e.g. from activity timer force close and from manual `close()` + * It shouldn't result in any exceptions + */ + @Test + public void shouldNotThrowExceptionIfSeveralCloseEventsHappened() { + AtomicReference listenerRef = new AtomicReference<>(); + + when(mockEngine.create(any(), any())).thenAnswer(invocation -> { + listenerRef.set(invocation.getArgumentAt(1, WebSocketListener.class)); + return mockWebSocketClient; + }); + + + final WebSocketTransport transport = createWebSocketTransport(); + ITransport.ConnectListener connectListener = mock(ITransport.ConnectListener.class); + transport.connect(connectListener); + + listenerRef.get().onClose(1000, "OK"); + listenerRef.get().onClose(1006, "Abnormal close"); + + verify(connectListener, times(2)).onTransportUnavailable(eq(transport), any()); + } + + /** + * Calling `close()` on transport triggers the activity timer. + * Test checks that if it has been disposed it won't do anything. + */ + @Test + public void shouldNotThrowExceptionIfCloseCalledOnAlreadyClosedTransport() { + AtomicReference listenerRef = new AtomicReference<>(); + + when(mockEngine.create(any(), any())).thenAnswer(invocation -> { + listenerRef.set(invocation.getArgumentAt(1, WebSocketListener.class)); + return mockWebSocketClient; + }); + + + final WebSocketTransport transport = createWebSocketTransport(); + ITransport.ConnectListener connectListener = mock(ITransport.ConnectListener.class); + transport.connect(connectListener); + + listenerRef.get().onClose(1006, "Abnormal close"); + transport.close(); + verify(connectListener, timeout(1_000)).onTransportUnavailable(eq(transport), any()); + } +}