Skip to content

Commit 6655769

Browse files
authored
Merge pull request #386 from splitio/async-sse-class
Async sse class
2 parents c584967 + dfb411e commit 6655769

File tree

2 files changed

+259
-12
lines changed

2 files changed

+259
-12
lines changed

splitio/push/sse.py

Lines changed: 146 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,46 @@
11
"""Low-level SSE Client."""
22
import logging
33
import socket
4+
import abc
5+
import urllib
46
from collections import namedtuple
57
from http.client import HTTPConnection, HTTPSConnection
68
from urllib.parse import urlparse
9+
import pytest
710

11+
from splitio.optional.loaders import asyncio, aiohttp
12+
from splitio.api.client import HttpClientException
813

914
_LOGGER = logging.getLogger(__name__)
1015

11-
1216
SSE_EVENT_ERROR = 'error'
1317
SSE_EVENT_MESSAGE = 'message'
14-
18+
_DEFAULT_HEADERS = {'accept': 'text/event-stream'}
19+
_EVENT_SEPARATORS = set([b'\n', b'\r\n'])
20+
_DEFAULT_ASYNC_TIMEOUT = 300
1521

1622
SSEEvent = namedtuple('SSEEvent', ['event_id', 'event', 'retry', 'data'])
1723

1824

1925
__ENDING_CHARS = set(['\n', ''])
2026

27+
def _get_request_parameters(url, extra_headers):
28+
"""
29+
Parse URL and headers
30+
31+
:param url: url to connect to
32+
:type url: str
33+
34+
:param extra_headers: additional headers
35+
:type extra_headers: dict[str, str]
36+
37+
:returns: processed URL and Headers
38+
:rtype: str, dict
39+
"""
40+
url = urlparse(url)
41+
headers = _DEFAULT_HEADERS.copy()
42+
headers.update(extra_headers if extra_headers is not None else {})
43+
return url, headers
2144

2245
class EventBuilder(object):
2346
"""Event builder class."""
@@ -46,12 +69,19 @@ def build(self):
4669
return SSEEvent(self._lines.get('id'), self._lines.get('event'),
4770
self._lines.get('retry'), self._lines.get('data'))
4871

72+
class SSEClientBase(object, metaclass=abc.ABCMeta):
73+
"""Worker template."""
4974

50-
class SSEClient(object):
51-
"""SSE Client implementation."""
75+
@abc.abstractmethod
76+
def start(self, url, extra_headers, timeout): # pylint:disable=protected-access
77+
"""Connect and start listening for events."""
5278

53-
_DEFAULT_HEADERS = {'accept': 'text/event-stream'}
54-
_EVENT_SEPARATORS = set([b'\n', b'\r\n'])
79+
@abc.abstractmethod
80+
def shutdown(self):
81+
"""Shutdown the current connection."""
82+
83+
class SSEClient(SSEClientBase):
84+
"""SSE Client implementation."""
5585

5686
def __init__(self, callback):
5787
"""
@@ -81,7 +111,7 @@ def _read_events(self):
81111
elif line.startswith(b':'): # comment. Skip
82112
_LOGGER.debug("skipping sse comment")
83113
continue
84-
elif line in self._EVENT_SEPARATORS:
114+
elif line in _EVENT_SEPARATORS:
85115
event = event_builder.build()
86116
_LOGGER.debug("dispatching event: %s", event)
87117
self._event_callback(event)
@@ -117,9 +147,7 @@ def start(self, url, extra_headers=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT)
117147
raise RuntimeError('Client already started.')
118148

119149
self._shutdown_requested = False
120-
url = urlparse(url)
121-
headers = self._DEFAULT_HEADERS.copy()
122-
headers.update(extra_headers if extra_headers is not None else {})
150+
url, headers = _get_request_parameters(url, extra_headers)
123151
self._conn = (HTTPSConnection(url.hostname, url.port, timeout=timeout)
124152
if url.scheme == 'https'
125153
else HTTPConnection(url.hostname, port=url.port, timeout=timeout))
@@ -139,3 +167,111 @@ def shutdown(self):
139167

140168
self._shutdown_requested = True
141169
self._conn.sock.shutdown(socket.SHUT_RDWR)
170+
171+
class SSEClientAsync(SSEClientBase):
172+
"""SSE Client implementation."""
173+
174+
def __init__(self, callback):
175+
"""
176+
Construct an SSE client.
177+
178+
:param callback: function to call when an event is received
179+
:type callback: callable
180+
"""
181+
self._conn = None
182+
self._event_callback = callback
183+
self._shutdown_requested = False
184+
185+
async def _read_events(self, response):
186+
"""
187+
Read events from the supplied connection.
188+
189+
:returns: True if the connection was ended by us. False if it was closed by the serve.
190+
:rtype: bool
191+
"""
192+
try:
193+
event_builder = EventBuilder()
194+
while not self._shutdown_requested:
195+
line = await response.readline()
196+
if line is None or len(line) <= 0: # connection ended
197+
break
198+
elif line.startswith(b':'): # comment. Skip
199+
_LOGGER.debug("skipping sse comment")
200+
continue
201+
elif line in _EVENT_SEPARATORS:
202+
event = event_builder.build()
203+
_LOGGER.debug("dispatching event: %s", event)
204+
await self._event_callback(event)
205+
event_builder = EventBuilder()
206+
else:
207+
event_builder.process_line(line)
208+
except asyncio.CancelledError:
209+
_LOGGER.debug("Cancellation request, proceeding to cancel.")
210+
raise
211+
except Exception: # pylint:disable=broad-except
212+
_LOGGER.debug('sse connection ended.')
213+
_LOGGER.debug('stack trace: ', exc_info=True)
214+
finally:
215+
await self._conn.close()
216+
self._conn = None # clear so it can be started again
217+
218+
return self._shutdown_requested
219+
220+
async def start(self, url, extra_headers=None, timeout=_DEFAULT_ASYNC_TIMEOUT): # pylint:disable=protected-access
221+
"""
222+
Connect and start listening for events.
223+
224+
:param url: url to connect to
225+
:type url: str
226+
227+
:param extra_headers: additional headers
228+
:type extra_headers: dict[str, str]
229+
230+
:param timeout: connection & read timeout
231+
:type timeout: float
232+
233+
:returns: True if the connection was ended by us. False if it was closed by the serve.
234+
:rtype: bool
235+
"""
236+
_LOGGER.debug("Async SSEClient Started")
237+
if self._conn is not None:
238+
raise RuntimeError('Client already started.')
239+
240+
self._shutdown_requested = False
241+
url = urlparse(url)
242+
headers = _DEFAULT_HEADERS.copy()
243+
headers.update(extra_headers if extra_headers is not None else {})
244+
parsed_url = urllib.parse.urljoin(url[0] + "://" + url[1], url[2])
245+
params=url[4]
246+
try:
247+
self._conn = aiohttp.connector.TCPConnector()
248+
async with aiohttp.client.ClientSession(
249+
connector=self._conn,
250+
headers=headers,
251+
timeout=aiohttp.ClientTimeout(timeout)
252+
) as self._session:
253+
reader = await self._session.request(
254+
"GET",
255+
parsed_url,
256+
params=params
257+
)
258+
return await self._read_events(reader.content)
259+
except aiohttp.ClientError as exc: # pylint: disable=broad-except
260+
_LOGGER.error(str(exc))
261+
raise HttpClientException('http client is throwing exceptions') from exc
262+
263+
async def shutdown(self):
264+
"""Shutdown the current connection."""
265+
_LOGGER.debug("Async SSEClient Shutdown")
266+
if self._conn is None:
267+
_LOGGER.warning("no sse connection has been started on this SSEClient instance. Ignoring")
268+
return
269+
270+
if self._shutdown_requested:
271+
_LOGGER.warning("shutdown already requested")
272+
return
273+
274+
self._shutdown_requested = True
275+
sock = self._session.connector._loop._ssock
276+
sock.shutdown(socket.SHUT_RDWR)
277+
await self._conn.close()

tests/push/test_sse.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import time
44
import threading
55
import pytest
6-
from splitio.push.sse import SSEClient, SSEEvent
7-
from tests.helpers.mockserver import SSEMockServer
86

7+
from splitio.push.sse import SSEClient, SSEEvent, SSEClientAsync
8+
from splitio.optional.loaders import asyncio
9+
from tests.helpers.mockserver import SSEMockServer
910

1011
class SSEClientTests(object):
1112
"""SSEClient test cases."""
@@ -123,3 +124,113 @@ def runner():
123124
]
124125

125126
assert client._conn is None
127+
128+
class SSEClientAsyncTests(object):
129+
"""SSEClient test cases."""
130+
131+
# @pytest.mark.asyncio
132+
async def test_sse_client_disconnects(self):
133+
"""Test correct initialization. Client ends the connection."""
134+
server = SSEMockServer()
135+
server.start()
136+
137+
events = []
138+
async def callback(event):
139+
"""Callback."""
140+
events.append(event)
141+
142+
client = SSEClientAsync(callback)
143+
144+
async def connect_split_sse_client():
145+
await client.start('http://127.0.0.1:' + str(server.port()))
146+
147+
self._client_task = asyncio.gather(connect_split_sse_client())
148+
server.publish({'id': '1'})
149+
server.publish({'id': '2', 'event': 'message', 'data': 'abc'})
150+
server.publish({'id': '3', 'event': 'message', 'data': 'def'})
151+
server.publish({'id': '4', 'event': 'message', 'data': 'ghi'})
152+
await asyncio.sleep(1)
153+
await client.shutdown()
154+
self._client_task.cancel()
155+
await asyncio.sleep(1)
156+
157+
assert events == [
158+
SSEEvent('1', None, None, None),
159+
SSEEvent('2', 'message', None, 'abc'),
160+
SSEEvent('3', 'message', None, 'def'),
161+
SSEEvent('4', 'message', None, 'ghi')
162+
]
163+
assert client._conn is None
164+
server.publish(server.GRACEFUL_REQUEST_END)
165+
server.stop()
166+
167+
async def test_sse_server_disconnects(self):
168+
"""Test correct initialization. Server ends connection."""
169+
server = SSEMockServer()
170+
server.start()
171+
172+
events = []
173+
async def callback(event):
174+
"""Callback."""
175+
events.append(event)
176+
177+
client = SSEClientAsync(callback)
178+
179+
async def start_client():
180+
await client.start('http://127.0.0.1:' + str(server.port()))
181+
182+
asyncio.gather(start_client())
183+
server.publish({'id': '1'})
184+
server.publish({'id': '2', 'event': 'message', 'data': 'abc'})
185+
server.publish({'id': '3', 'event': 'message', 'data': 'def'})
186+
server.publish({'id': '4', 'event': 'message', 'data': 'ghi'})
187+
server.publish(server.GRACEFUL_REQUEST_END)
188+
189+
await asyncio.sleep(1)
190+
server.stop()
191+
await asyncio.sleep(1)
192+
193+
assert events == [
194+
SSEEvent('1', None, None, None),
195+
SSEEvent('2', 'message', None, 'abc'),
196+
SSEEvent('3', 'message', None, 'def'),
197+
SSEEvent('4', 'message', None, 'ghi')
198+
]
199+
200+
assert client._conn is None
201+
202+
async def test_sse_server_disconnects_abruptly(self):
203+
"""Test correct initialization. Server ends connection."""
204+
server = SSEMockServer()
205+
server.start()
206+
207+
events = []
208+
async def callback(event):
209+
"""Callback."""
210+
events.append(event)
211+
212+
client = SSEClientAsync(callback)
213+
214+
async def runner():
215+
"""SSE client runner thread."""
216+
await client.start('http://127.0.0.1:' + str(server.port()))
217+
218+
client_task = asyncio.gather(runner())
219+
220+
server.publish({'id': '1'})
221+
server.publish({'id': '2', 'event': 'message', 'data': 'abc'})
222+
server.publish({'id': '3', 'event': 'message', 'data': 'def'})
223+
server.publish({'id': '4', 'event': 'message', 'data': 'ghi'})
224+
await asyncio.sleep(1)
225+
server.publish(server.VIOLENT_REQUEST_END)
226+
server.stop()
227+
await asyncio.sleep(1)
228+
229+
assert events == [
230+
SSEEvent('1', None, None, None),
231+
SSEEvent('2', 'message', None, 'abc'),
232+
SSEEvent('3', 'message', None, 'def'),
233+
SSEEvent('4', 'message', None, 'ghi')
234+
]
235+
236+
assert client._conn is None

0 commit comments

Comments
 (0)