Skip to content

Commit d751f63

Browse files
committed
feat: support for async bidi streaming apis
1 parent 2c98385 commit d751f63

File tree

3 files changed

+579
-28
lines changed

3 files changed

+579
-28
lines changed

google/api_core/bidi_async.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
# Copyright 2020, Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Asynchronous bi-directional streaming RPC helpers."""
16+
17+
import asyncio
18+
import logging
19+
20+
from google.api_core import exceptions
21+
22+
_LOGGER = logging.getLogger(__name__)
23+
24+
25+
class _AsyncRequestQueueGenerator:
26+
"""An async helper for sending requests to a gRPC stream from a Queue.
27+
28+
This generator takes requests off a given queue and yields them to gRPC.
29+
30+
This helper is useful when you have an indeterminate, indefinite, or
31+
otherwise open-ended set of requests to send through a request-streaming
32+
(or bidirectional) RPC.
33+
34+
The reason this is necessary is because gRPC takes an async iterator as the
35+
request for request-streaming RPCs. gRPC consumes this iterator to allow
36+
it to block while generating requests for the stream. However, if the
37+
generator blocks indefinitely gRPC will not be able to clean up the task
38+
as it'll be blocked on `anext(iterator)` and not be able to check the
39+
channel status to stop iterating. This helper mitigates that by waiting
40+
on the queue with a timeout and checking the RPC state before yielding.
41+
42+
Finally, it allows for retrying without swapping queues because if it does
43+
pull an item off the queue when the RPC is inactive, it'll immediately put
44+
it back and then exit. This is necessary because yielding the item in this
45+
case will cause gRPC to discard it. In practice, this means that the order
46+
of messages is not guaranteed. If such a thing is necessary it would be
47+
easy to use a priority queue.
48+
49+
Example::
50+
51+
requests = _AsyncRequestQueueGenerator(q)
52+
call = await stub.StreamingRequest(requests)
53+
requests.call = call
54+
55+
async for response in call:
56+
print(response)
57+
await q.put(...)
58+
59+
Args:
60+
queue (asyncio.Queue): The request queue.
61+
period (float): The number of seconds to wait for items from the queue
62+
before checking if the RPC is cancelled. In practice, this
63+
determines the maximum amount of time the request consumption
64+
task will live after the RPC is cancelled.
65+
initial_request (Union[protobuf.Message,
66+
Callable[[], protobuf.Message]]): The initial request to
67+
yield. This is done independently of the request queue to allow for
68+
easily restarting streams that require some initial configuration
69+
request.
70+
"""
71+
72+
def __init__(self, queue: asyncio.Queue, period: float = 1, initial_request=None):
73+
self._queue = queue
74+
self._period = period
75+
self._initial_request = initial_request
76+
self.call = None
77+
78+
def _is_active(self):
79+
# Note: there is a possibility that this starts *before* the call
80+
# property is set. So we have to check if self.call is set before
81+
# seeing if it's active. We need to return True if self.call is None.
82+
# See https://github.com/googleapis/python-api-core/issues/560.
83+
return self.call is None or not self.call.done()
84+
85+
async def __aiter__(self):
86+
if self._initial_request is not None:
87+
if callable(self._initial_request):
88+
yield self._initial_request()
89+
else:
90+
yield self._initial_request
91+
92+
while True:
93+
try:
94+
item = self._queue.get_nowait()
95+
except asyncio.QueueEmpty:
96+
if not self._is_active():
97+
_LOGGER.debug(
98+
"Empty queue and inactive call, exiting request generator."
99+
)
100+
return
101+
else:
102+
# call is still active, keep waiting for queue items.
103+
await asyncio.sleep(self._period)
104+
continue
105+
106+
# The consumer explicitly sent "None", indicating that the request
107+
# should end.
108+
if item is None:
109+
_LOGGER.debug("Cleanly exiting request generator.")
110+
return
111+
112+
if not self._is_active():
113+
# We have an item, but the call is closed. We should put the
114+
# item back on the queue so that the next call can consume it.
115+
await self._queue.put(item)
116+
_LOGGER.debug(
117+
"Inactive call, replacing item on queue and exiting "
118+
"request generator."
119+
)
120+
return
121+
122+
yield item
123+
124+
125+
126+
class AsyncBidiRpc:
127+
"""A helper for consuming a async bi-directional streaming RPC.
128+
129+
This maps gRPC's built-in interface which uses a request iterator and a
130+
response iterator into a socket-like :func:`send` and :func:`recv`. This
131+
is a more useful pattern for long-running or asymmetric streams (streams
132+
where there is not a direct correlation between the requests and
133+
responses).
134+
135+
Example::
136+
137+
initial_request = example_pb2.StreamingRpcRequest(
138+
setting='example')
139+
rpc = AsyncBidiRpc(
140+
stub.StreamingRpc,
141+
initial_request=initial_request,
142+
metadata=[('name', 'value')]
143+
)
144+
145+
await rpc.open()
146+
147+
while rpc.is_active:
148+
print(await rpc.recv())
149+
await rpc.send(example_pb2.StreamingRpcRequest(
150+
data='example'))
151+
152+
This does *not* retry the stream on errors. See :class:`AsyncResumableBidiRpc`.
153+
154+
Args:
155+
start_rpc (grpc.aio.StreamStreamMultiCallable): The gRPC method used to
156+
start the RPC.
157+
initial_request (Union[protobuf.Message,
158+
Callable[[], protobuf.Message]]): The initial request to
159+
yield. This is useful if an initial request is needed to start the
160+
stream.
161+
metadata (Sequence[Tuple(str, str)]): RPC metadata to include in
162+
the request.
163+
"""
164+
165+
def __init__(self, start_rpc, initial_request=None, metadata=None):
166+
self._start_rpc = start_rpc
167+
self._initial_request = initial_request
168+
self._rpc_metadata = metadata
169+
self._request_queue = asyncio.Queue()
170+
self._request_generator = None
171+
self._callbacks = []
172+
self.call = None
173+
self._loop = asyncio.get_event_loop()
174+
175+
def add_done_callback(self, callback):
176+
"""Adds a callback that will be called when the RPC terminates.
177+
178+
This occurs when the RPC errors or is successfully terminated.
179+
180+
Args:
181+
callback (Callable[[grpc.Future], None]): The callback to execute.
182+
It will be provided with the same gRPC future as the underlying
183+
stream which will also be a :class:`grpc.aio.Call`.
184+
"""
185+
self._callbacks.append(callback)
186+
187+
def _on_call_done(self, future):
188+
# This occurs when the RPC errors or is successfully terminated.
189+
# Note that grpc's "future" here can also be a grpc.RpcError.
190+
# See note in https://github.com/grpc/grpc/issues/10885#issuecomment-302651331
191+
# that `grpc.RpcError` is also `grpc.aio.Call`.
192+
for callback in self._callbacks:
193+
callback(future)
194+
195+
async def open(self):
196+
"""Opens the stream."""
197+
if self.is_active:
198+
raise ValueError("Can not open an already open stream.")
199+
200+
request_generator = _AsyncRequestQueueGenerator(
201+
self._request_queue, initial_request=self._initial_request
202+
)
203+
try:
204+
call = await self._start_rpc(request_generator, metadata=self._rpc_metadata)
205+
except exceptions.GoogleAPICallError as exc:
206+
# The original `grpc.RpcError` (which is usually also a `grpc.Call`) is
207+
# available from the ``response`` property on the mapped exception.
208+
self._on_call_done(exc.response)
209+
raise
210+
211+
request_generator.call = call
212+
213+
# TODO: api_core should expose the future interface for wrapped
214+
# callables as well.
215+
if hasattr(call, "_wrapped"): # pragma: NO COVER
216+
call._wrapped.add_done_callback(self._on_call_done)
217+
else:
218+
call.add_done_callback(self._on_call_done)
219+
220+
self._request_generator = request_generator
221+
self.call = call
222+
223+
async def close(self):
224+
"""Closes the stream."""
225+
if self.call is None:
226+
return
227+
228+
await self._request_queue.put(None)
229+
self.call.cancel()
230+
self._request_generator = None
231+
self._initial_request = None
232+
self._callbacks = []
233+
# Don't set self.call to None. Keep it around so that send/recv can
234+
# raise the error.
235+
236+
async def send(self, request):
237+
"""Queue a message to be sent on the stream.
238+
239+
If the underlying RPC has been closed, this will raise.
240+
241+
Args:
242+
request (protobuf.Message): The request to send.
243+
"""
244+
if self.call is None:
245+
raise ValueError("Can not send() on an RPC that has never been open()ed.")
246+
247+
# Don't use self.is_active(), as ResumableBidiRpc will overload it
248+
# to mean something semantically different.
249+
if not self.call.done():
250+
await self._request_queue.put(request)
251+
else:
252+
# calling read should cause the call to raise.
253+
await self.call.read()
254+
255+
async def recv(self):
256+
"""Wait for a message to be returned from the stream.
257+
258+
If the underlying RPC has been closed, this will raise.
259+
260+
Returns:
261+
protobuf.Message: The received message.
262+
"""
263+
if self.call is None:
264+
raise ValueError("Can not recv() on an RPC that has never been open()ed.")
265+
266+
return await self.call.read()
267+
268+
@property
269+
def is_active(self):
270+
"""bool: True if this stream is currently open and active."""
271+
return self.call is not None and not self.call.done()
272+
273+
@property
274+
def pending_requests(self):
275+
"""int: Returns an estimate of the number of queued requests."""
276+
return self._request_queue.qsize()

noxfile.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -180,37 +180,36 @@ def default(session, install_grpc=True, prerelease=False, install_async_rest=Fal
180180
session.run("python", "-c", "import grpc; print(grpc.__version__)")
181181
session.run("python", "-c", "import google.auth; print(google.auth.__version__)")
182182

183-
pytest_args = [
184-
"python",
185-
"-m",
186-
"pytest",
187-
*(
188-
# Helpful for running a single test or testfile.
189-
session.posargs
190-
or [
191-
"--quiet",
192-
"--cov=google.api_core",
193-
"--cov=tests.unit",
194-
"--cov-append",
195-
"--cov-config=.coveragerc",
196-
"--cov-report=",
197-
"--cov-fail-under=0",
198-
# Running individual tests with parallelism enabled is usually not helpful.
199-
"-n=auto",
200-
os.path.join("tests", "unit"),
201-
]
202-
),
203-
]
183+
pytest_command = ["python", "-m", "pytest"]
204184

205-
session.install("asyncmock", "pytest-asyncio")
185+
if session.posargs:
186+
pytest_options = list(session.posargs)
187+
else:
188+
pytest_options = [
189+
"--quiet",
190+
"--cov=google.api_core",
191+
"--cov=tests.unit",
192+
"--cov-append",
193+
"--cov-config=.coveragerc",
194+
"--cov-report=",
195+
"--cov-fail-under=0",
196+
"-n=auto",
197+
os.path.join("tests", "unit"),
198+
"--cov=tests.asyncio",
199+
os.path.join("tests", "asyncio"),
200+
]
206201

207-
# Having positional arguments means the user wants to run specific tests.
208-
# Best not to add additional tests to that list.
209-
if not session.posargs:
210-
pytest_args.append("--cov=tests.asyncio")
211-
pytest_args.append(os.path.join("tests", "asyncio"))
202+
# Temporarily disable coverage for Python 3.12 due to a missing
203+
# _sqlite3 module in some environments. This is a workaround. The proper
204+
# fix is to reinstall Python 3.12 with SQLite support.
205+
# We only do this when not running specific tests (no posargs).
206+
if session.python == "3.12" and not session.posargs:
207+
pytest_options = [opt for opt in pytest_options if not opt.startswith("--cov")]
208+
pytest_command.extend(["-p", "no:cov"])
209+
210+
session.install("asyncmock", "pytest-asyncio")
212211

213-
session.run(*pytest_args)
212+
session.run(*pytest_command, *pytest_options)
214213

215214

216215
@nox.session(python=PYTHON_VERSIONS)

0 commit comments

Comments
 (0)