Skip to content

Commit 5835178

Browse files
committed
Add bidiBase, refactor bidiRpc and add AsyncBidi
1 parent fdcd903 commit 5835178

File tree

4 files changed

+125
-79
lines changed

4 files changed

+125
-79
lines changed

google/api_core/bidi.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import time
2323

2424
from google.api_core import exceptions
25+
from google.api_core.bidi_base import BidiRpcBase
2526

2627
_LOGGER = logging.getLogger(__name__)
2728
_BIDIRECTIONAL_CONSUMER_NAME = "Thread-ConsumeBidirectionalStream"
@@ -201,7 +202,7 @@ def __repr__(self):
201202
)
202203

203204

204-
class BidiRpc(object):
205+
class BidiRpc(BidiRpcBase):
205206
"""A helper for consuming a bi-directional streaming RPC.
206207
207208
This maps gRPC's built-in interface which uses a request iterator and a
@@ -240,35 +241,9 @@ class BidiRpc(object):
240241
the request.
241242
"""
242243

243-
def __init__(self, start_rpc, initial_request=None, metadata=None):
244-
self._start_rpc = start_rpc
245-
self._initial_request = initial_request
246-
self._rpc_metadata = metadata
247-
self._request_queue = queue_module.Queue()
248-
self._request_generator = None
249-
self._is_active = False
250-
self._callbacks = []
251-
self.call = None
252-
253-
def add_done_callback(self, callback):
254-
"""Adds a callback that will be called when the RPC terminates.
255-
256-
This occurs when the RPC errors or is successfully terminated.
257-
258-
Args:
259-
callback (Callable[[grpc.Future], None]): The callback to execute.
260-
It will be provided with the same gRPC future as the underlying
261-
stream which will also be a :class:`grpc.Call`.
262-
"""
263-
self._callbacks.append(callback)
264-
265-
def _on_call_done(self, future):
266-
# This occurs when the RPC errors or is successfully terminated.
267-
# Note that grpc's "future" here can also be a grpc.RpcError.
268-
# See note in https://github.com/grpc/grpc/issues/10885#issuecomment-302651331
269-
# that `grpc.RpcError` is also `grpc.call`.
270-
for callback in self._callbacks:
271-
callback(future)
244+
def _create_queue(self):
245+
"""Create a queue for requests."""
246+
return queue_module.Queue()
272247

273248
def open(self):
274249
"""Opens the stream."""
@@ -352,11 +327,6 @@ def is_active(self):
352327
"""bool: True if this stream is currently open and active."""
353328
return self.call is not None and self.call.is_active()
354329

355-
@property
356-
def pending_requests(self):
357-
"""int: Returns an estimate of the number of queued requests."""
358-
return self._request_queue.qsize()
359-
360330

361331
def _never_terminate(future_or_error):
362332
"""By default, no errors cause BiDi termination."""

google/api_core/bidi_async.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2020, Google LLC
1+
# Copyright 2024, Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
1818
import logging
1919

2020
from google.api_core import exceptions
21+
from google.api_core.bidi_base import BidiRpcBase
2122

2223
_LOGGER = logging.getLogger(__name__)
2324

@@ -69,7 +70,7 @@ class _AsyncRequestQueueGenerator:
6970
request.
7071
"""
7172

72-
def __init__(self, queue: asyncio.Queue, period: float = 1, initial_request=None):
73+
def __init__(self, queue: asyncio.Queue, period: float = 0.1, initial_request=None):
7374
self._queue = queue
7475
self._period = period
7576
self._initial_request = initial_request
@@ -122,7 +123,7 @@ async def __aiter__(self):
122123
yield item
123124

124125

125-
class AsyncBidiRpc:
126+
class AsyncBidiRpc(BidiRpcBase):
126127
"""A helper for consuming a async bi-directional streaming RPC.
127128
128129
This maps gRPC's built-in interface which uses a request iterator and a
@@ -161,35 +162,9 @@ class AsyncBidiRpc:
161162
the request.
162163
"""
163164

164-
def __init__(self, start_rpc, initial_request=None, metadata=None):
165-
self._start_rpc = start_rpc
166-
self._initial_request = initial_request
167-
self._rpc_metadata = metadata
168-
self._request_queue = asyncio.Queue()
169-
self._request_generator = None
170-
self._callbacks = []
171-
self.call = None
172-
self._loop = asyncio.get_event_loop()
173-
174-
def add_done_callback(self, callback):
175-
"""Adds a callback that will be called when the RPC terminates.
176-
177-
This occurs when the RPC errors or is successfully terminated.
178-
179-
Args:
180-
callback (Callable[[grpc.Future], None]): The callback to execute.
181-
It will be provided with the same gRPC future as the underlying
182-
stream which will also be a :class:`grpc.aio.Call`.
183-
"""
184-
self._callbacks.append(callback)
185-
186-
def _on_call_done(self, future):
187-
# This occurs when the RPC errors or is successfully terminated.
188-
# Note that grpc's "future" here can also be a grpc.RpcError.
189-
# See note in https://github.com/grpc/grpc/issues/10885#issuecomment-302651331
190-
# that `grpc.RpcError` is also `grpc.aio.Call`.
191-
for callback in self._callbacks:
192-
callback(future)
165+
def _create_queue(self):
166+
"""Create a queue for requests."""
167+
return asyncio.Queue()
193168

194169
async def open(self):
195170
"""Opens the stream."""
@@ -268,8 +243,3 @@ async def recv(self):
268243
def is_active(self):
269244
"""bool: True if this stream is currently open and active."""
270245
return self.call is not None and not self.call.done()
271-
272-
@property
273-
def pending_requests(self):
274-
"""int: Returns an estimate of the number of queued requests."""
275-
return self._request_queue.qsize()

google/api_core/bidi_base.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2024, Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# You may obtain a copy of the License at
5+
# https://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
"""Base class for Bi-directional streaming RPC helpers."""
14+
15+
16+
class BidiRpcBase:
17+
"""A base helper for consuming a bi-directional streaming RPC.
18+
19+
This maps gRPC's built-in interface which uses a request iterator and a
20+
response iterator into a socket-like :func:`send` and :func:`recv`. This
21+
is a more useful pattern for long-running or asymmetric streams (streams
22+
where there is not a direct correlation between the requests and
23+
responses).
24+
25+
This does *not* retry the stream on errors.
26+
27+
Args:
28+
start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
29+
start the RPC.
30+
initial_request (Union[protobuf.Message,
31+
Callable[[], protobuf.Message]]): The initial request to
32+
yield. This is useful if an initial request is needed to start the
33+
stream.
34+
metadata (Sequence[Tuple(str, str)]): RPC metadata to include in
35+
the request.
36+
"""
37+
38+
def __init__(self, start_rpc, initial_request=None, metadata=None):
39+
self._start_rpc = start_rpc
40+
self._initial_request = initial_request
41+
self._rpc_metadata = metadata
42+
self._request_queue = self._create_queue()
43+
self._request_generator = None
44+
self._callbacks = []
45+
self.call = None
46+
47+
def _create_queue(self):
48+
"""Create a queue for requests."""
49+
raise NotImplementedError("Not implemented in base class")
50+
51+
def add_done_callback(self, callback):
52+
"""Adds a callback that will be called when the RPC terminates.
53+
54+
This occurs when the RPC errors or is successfully terminated.
55+
56+
Args:
57+
callback (Callable[[grpc.Future], None]): The callback to execute.
58+
It will be provided with the same gRPC future as the underlying
59+
stream which will also be a :class:`grpc.aio.Call`.
60+
"""
61+
self._callbacks.append(callback)
62+
63+
def _on_call_done(self, future):
64+
# This occurs when the RPC errors or is successfully terminated.
65+
# Note that grpc's "future" here can also be a grpc.RpcError.
66+
# See note in https://github.com/grpc/grpc/issues/10885#issuecomment-302651331
67+
# that `grpc.RpcError` is also `grpc.aio.Call`.
68+
for callback in self._callbacks:
69+
callback(future)
70+
71+
def open(self):
72+
"""Opens the stream."""
73+
raise NotImplementedError("Not implemented in base class")
74+
75+
def close(self):
76+
"""Closes the stream."""
77+
raise NotImplementedError("Not implemented in base class")
78+
79+
def send(self, request):
80+
"""Queue a message to be sent on the stream."""
81+
raise NotImplementedError("Not implemented in base class")
82+
83+
def recv(self):
84+
"""Wait for a message to be returned from the stream."""
85+
raise NotImplementedError("Not implemented in base class")
86+
87+
@property
88+
def is_active(self):
89+
"""bool: True if this stream is currently open and active."""
90+
raise NotImplementedError("Not implemented in base class")
91+
92+
@property
93+
def pending_requests(self):
94+
"""int: Returns an estimate of the number of queued requests."""
95+
return self._request_queue.qsize()

tests/asyncio/test_bidi_async.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
import asyncio
1616

17+
from unittest import mock
18+
1719
try:
18-
from unittest import mock
20+
from unittest.mock import AsyncMock
1921
except ImportError: # pragma: NO COVER
20-
import mock # type: ignore
22+
from asyncmock import AsyncMock
2123

2224
import pytest
2325

@@ -138,9 +140,9 @@ async def test_exit_with_stop(self):
138140
def make_async_rpc():
139141
"""Makes a mock async RPC used to test Bidi classes."""
140142
call = mock.create_autospec(aio.StreamStreamCall, instance=True)
141-
rpc = mock.AsyncMock(spec=aio.StreamStreamMultiCallable)
143+
rpc = AsyncMock()
142144

143-
async def rpc_side_effect(request, metadata=None):
145+
def rpc_side_effect(request, metadata=None):
144146
call.done.return_value = False
145147
return call
146148

@@ -151,7 +153,7 @@ def cancel_side_effect():
151153
return True
152154

153155
call.cancel.side_effect = cancel_side_effect
154-
call.read = mock.AsyncMock()
156+
call.read = AsyncMock()
155157

156158
return rpc, call
157159

@@ -167,7 +169,6 @@ async def read(self):
167169
raise self.exception
168170

169171

170-
@pytest.mark.asyncio
171172
class TestAsyncBidiRpc:
172173
def test_initial_state(self):
173174
bidi_rpc = bidi_async.AsyncBidiRpc(None)
@@ -182,6 +183,7 @@ def test_done_callbacks(self):
182183

183184
callback.assert_called_once_with(mock.sentinel.future)
184185

186+
@pytest.mark.asyncio
185187
async def test_metadata(self):
186188
rpc, call = make_async_rpc()
187189
bidi_rpc = bidi_async.AsyncBidiRpc(rpc, metadata=mock.sentinel.A)
@@ -192,6 +194,7 @@ async def test_metadata(self):
192194
rpc.assert_awaited_once()
193195
assert rpc.call_args.kwargs["metadata"] == mock.sentinel.A
194196

197+
@pytest.mark.asyncio
195198
async def test_open(self):
196199
rpc, call = make_async_rpc()
197200
bidi_rpc = bidi_async.AsyncBidiRpc(rpc)
@@ -202,6 +205,7 @@ async def test_open(self):
202205
assert bidi_rpc.is_active
203206
call.add_done_callback.assert_called_once_with(bidi_rpc._on_call_done)
204207

208+
@pytest.mark.asyncio
205209
async def test_open_error_already_open(self):
206210
rpc, _ = make_async_rpc()
207211
bidi_rpc = bidi_async.AsyncBidiRpc(rpc)
@@ -211,6 +215,7 @@ async def test_open_error_already_open(self):
211215
with pytest.raises(ValueError):
212216
await bidi_rpc.open()
213217

218+
@pytest.mark.asyncio
214219
async def test_close(self):
215220
rpc, call = make_async_rpc()
216221
bidi_rpc = bidi_async.AsyncBidiRpc(rpc)
@@ -228,10 +233,12 @@ async def test_close(self):
228233
assert bidi_rpc._initial_request is None
229234
assert not bidi_rpc._callbacks
230235

236+
@pytest.mark.asyncio
231237
async def test_close_no_rpc(self):
232238
bidi_rpc = bidi_async.AsyncBidiRpc(None)
233239
await bidi_rpc.close()
234240

241+
@pytest.mark.asyncio
235242
async def test_send(self):
236243
rpc, call = make_async_rpc()
237244
bidi_rpc = bidi_async.AsyncBidiRpc(rpc)
@@ -242,12 +249,14 @@ async def test_send(self):
242249
assert bidi_rpc.pending_requests == 1
243250
assert await bidi_rpc._request_queue.get() is mock.sentinel.request
244251

252+
@pytest.mark.asyncio
245253
async def test_send_not_open(self):
246254
bidi_rpc = bidi_async.AsyncBidiRpc(None)
247255

248256
with pytest.raises(ValueError):
249257
await bidi_rpc.send(mock.sentinel.request)
250258

259+
@pytest.mark.asyncio
251260
async def test_send_dead_rpc(self):
252261
error = ValueError()
253262
bidi_rpc = bidi_async.AsyncBidiRpc(None)
@@ -258,15 +267,17 @@ async def test_send_dead_rpc(self):
258267

259268
assert exc_info.value == error
260269

270+
@pytest.mark.asyncio
261271
async def test_recv(self):
262272
bidi_rpc = bidi_async.AsyncBidiRpc(None)
263273
bidi_rpc.call = mock.create_autospec(aio.Call, instance=True)
264-
bidi_rpc.call.read = mock.AsyncMock(return_value=mock.sentinel.response)
274+
bidi_rpc.call.read = AsyncMock(return_value=mock.sentinel.response)
265275

266276
response = await bidi_rpc.recv()
267277

268278
assert response == mock.sentinel.response
269279

280+
@pytest.mark.asyncio
270281
async def test_recv_not_open(self):
271282
bidi_rpc = bidi_async.AsyncBidiRpc(None)
272283

0 commit comments

Comments
 (0)