Skip to content

Commit 1cb6f44

Browse files
committed
Fix method return types of TestAsyncClient
1 parent fe25a67 commit 1cb6f44

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

ninja/testing/client.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
from json import dumps as json_dumps
22
from json import loads as json_loads
3-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3+
from typing import (
4+
Any,
5+
Awaitable,
6+
Callable,
7+
Dict,
8+
Generic,
9+
List,
10+
Optional,
11+
Tuple,
12+
TypeVar,
13+
Union,
14+
cast,
15+
)
416
from unittest.mock import Mock
517
from urllib.parse import urljoin
618

@@ -11,6 +23,8 @@
1123
from ninja.responses import NinjaJSONEncoder
1224
from ninja.responses import Response as HttpResponse
1325

26+
ResponseT = TypeVar("ResponseT")
27+
1428

1529
def build_absolute_uri(location: Optional[str] = None) -> str:
1630
base = "http://testlocation/"
@@ -23,7 +37,7 @@ def build_absolute_uri(location: Optional[str] = None) -> str:
2337

2438
# TODO: this should be changed
2539
# maybe add here urlconf object and add urls from here
26-
class NinjaClientBase:
40+
class NinjaClientBase(Generic[ResponseT]):
2741
__test__ = False # <- skip pytest
2842

2943
def __init__(
@@ -38,7 +52,7 @@ def __init__(
3852

3953
def get(
4054
self, path: str, data: Optional[Dict] = None, **request_params: Any
41-
) -> "NinjaResponse":
55+
) -> ResponseT:
4256
return self.request("GET", path, data, **request_params)
4357

4458
def post(
@@ -47,7 +61,7 @@ def post(
4761
data: Optional[Dict] = None,
4862
json: Any = None,
4963
**request_params: Any,
50-
) -> "NinjaResponse":
64+
) -> ResponseT:
5165
return self.request("POST", path, data, json, **request_params)
5266

5367
def patch(
@@ -56,7 +70,7 @@ def patch(
5670
data: Optional[Dict] = None,
5771
json: Any = None,
5872
**request_params: Any,
59-
) -> "NinjaResponse":
73+
) -> ResponseT:
6074
return self.request("PATCH", path, data, json, **request_params)
6175

6276
def put(
@@ -65,7 +79,7 @@ def put(
6579
data: Optional[Dict] = None,
6680
json: Any = None,
6781
**request_params: Any,
68-
) -> "NinjaResponse":
82+
) -> ResponseT:
6983
return self.request("PUT", path, data, json, **request_params)
7084

7185
def delete(
@@ -74,7 +88,7 @@ def delete(
7488
data: Optional[Dict] = None,
7589
json: Any = None,
7690
**request_params: Any,
77-
) -> "NinjaResponse":
91+
) -> ResponseT:
7892
return self.request("DELETE", path, data, json, **request_params)
7993

8094
def request(
@@ -84,7 +98,7 @@ def request(
8498
data: Optional[Dict] = None,
8599
json: Any = None,
86100
**request_params: Any,
87-
) -> "NinjaResponse":
101+
) -> ResponseT:
88102
if json is not None:
89103
request_params["body"] = json_dumps(json, cls=NinjaJSONEncoder)
90104
if data is None:
@@ -147,10 +161,12 @@ def _build_request(
147161
request.META = request_params.pop("META", {"REMOTE_ADDR": "127.0.0.1"})
148162
request.FILES = request_params.pop("FILES", {})
149163

150-
request.META.update({
151-
f"HTTP_{k.replace('-', '_')}": v
152-
for k, v in request_params.pop("headers", {}).items()
153-
})
164+
request.META.update(
165+
{
166+
f"HTTP_{k.replace('-', '_')}": v
167+
for k, v in request_params.pop("headers", {}).items()
168+
}
169+
)
154170

155171
request.headers = HttpHeaders(request.META)
156172

@@ -186,12 +202,12 @@ def _build_request(
186202
return request
187203

188204

189-
class TestClient(NinjaClientBase):
205+
class TestClient(NinjaClientBase["NinjaResponse"]):
190206
def _call(self, func: Callable, request: Mock, kwargs: Dict) -> "NinjaResponse":
191207
return NinjaResponse(func(request, **kwargs))
192208

193209

194-
class TestAsyncClient(NinjaClientBase):
210+
class TestAsyncClient(NinjaClientBase[Awaitable["NinjaResponse"]]):
195211
async def _call(
196212
self, func: Callable, request: Mock, kwargs: Dict
197213
) -> "NinjaResponse":
@@ -206,7 +222,7 @@ def __init__(self, http_response: Union[HttpResponse, StreamingHttpResponse]):
206222
if self.streaming:
207223
self.content = b"".join(http_response.streaming_content) # type: ignore
208224
else:
209-
self.content = http_response.content
225+
self.content = cast(HttpResponse, http_response).content
210226
self._data = None
211227

212228
def json(self) -> Any:

0 commit comments

Comments
 (0)