diff --git a/ninja/testing/client.py b/ninja/testing/client.py index 759722fff..3409b8c66 100644 --- a/ninja/testing/client.py +++ b/ninja/testing/client.py @@ -1,6 +1,20 @@ +from abc import ABC, abstractmethod from json import dumps as json_dumps from json import loads as json_loads -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, + override, +) from unittest.mock import Mock from urllib.parse import urljoin @@ -11,6 +25,8 @@ from ninja.responses import NinjaJSONEncoder from ninja.responses import Response as HttpResponse +ResponseT = TypeVar("ResponseT") + def build_absolute_uri(location: Optional[str] = None) -> str: base = "http://testlocation/" @@ -23,7 +39,7 @@ def build_absolute_uri(location: Optional[str] = None) -> str: # TODO: this should be changed # maybe add here urlconf object and add urls from here -class NinjaClientBase: +class NinjaClientBase(Generic[ResponseT], ABC): __test__ = False # <- skip pytest def __init__( @@ -38,7 +54,7 @@ def __init__( def get( self, path: str, data: Optional[Dict] = None, **request_params: Any - ) -> "NinjaResponse": + ) -> ResponseT: return self.request("GET", path, data, **request_params) def post( @@ -47,7 +63,7 @@ def post( data: Optional[Dict] = None, json: Any = None, **request_params: Any, - ) -> "NinjaResponse": + ) -> ResponseT: return self.request("POST", path, data, json, **request_params) def patch( @@ -56,7 +72,7 @@ def patch( data: Optional[Dict] = None, json: Any = None, **request_params: Any, - ) -> "NinjaResponse": + ) -> ResponseT: return self.request("PATCH", path, data, json, **request_params) def put( @@ -65,7 +81,7 @@ def put( data: Optional[Dict] = None, json: Any = None, **request_params: Any, - ) -> "NinjaResponse": + ) -> ResponseT: return self.request("PUT", path, data, json, **request_params) def delete( @@ -74,7 +90,7 @@ def delete( data: Optional[Dict] = None, json: Any = None, **request_params: Any, - ) -> "NinjaResponse": + ) -> ResponseT: return self.request("DELETE", path, data, json, **request_params) def request( @@ -84,7 +100,7 @@ def request( data: Optional[Dict] = None, json: Any = None, **request_params: Any, - ) -> "NinjaResponse": + ) -> ResponseT: if json is not None: request_params["body"] = json_dumps(json, cls=NinjaJSONEncoder) if data is None: @@ -100,7 +116,10 @@ def request( **request_params.get("COOKIES", {}), } func, request, kwargs = self._resolve(method, path, data, request_params) - return self._call(func, request, kwargs) # type: ignore + return self._call(func, request, kwargs) + + @abstractmethod + def _call(self, func: Callable, request: Mock, kwargs: Dict) -> ResponseT: ... @property def urls(self) -> List: @@ -186,12 +205,14 @@ def _build_request( return request -class TestClient(NinjaClientBase): +class TestClient(NinjaClientBase["NinjaResponse"]): + @override def _call(self, func: Callable, request: Mock, kwargs: Dict) -> "NinjaResponse": return NinjaResponse(func(request, **kwargs)) -class TestAsyncClient(NinjaClientBase): +class TestAsyncClient(NinjaClientBase[Awaitable["NinjaResponse"]]): + @override async def _call( self, func: Callable, request: Mock, kwargs: Dict ) -> "NinjaResponse": @@ -206,7 +227,7 @@ def __init__(self, http_response: Union[HttpResponse, StreamingHttpResponse]): if self.streaming: self.content = b"".join(http_response.streaming_content) # type: ignore else: - self.content = http_response.content + self.content = cast(HttpResponse, http_response).content self._data = None def json(self) -> Any: