Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions ninja/testing/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
from abc import ABC, abstractmethod
from json import dumps as json_dumps
from json import loads as json_loads
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
override,
)
from unittest.mock import Mock
from urllib.parse import urljoin

Expand All @@ -11,6 +25,8 @@
from ninja.responses import NinjaJSONEncoder
from ninja.responses import Response as HttpResponse

ResponseT = TypeVar("ResponseT")


def build_absolute_uri(location: Optional[str] = None) -> str:
base = "http://testlocation/"
Expand All @@ -23,7 +39,7 @@ def build_absolute_uri(location: Optional[str] = None) -> str:

# TODO: this should be changed
# maybe add here urlconf object and add urls from here
class NinjaClientBase:
class NinjaClientBase(Generic[ResponseT], ABC):
__test__ = False # <- skip pytest

def __init__(
Expand All @@ -38,7 +54,7 @@ def __init__(

def get(
self, path: str, data: Optional[Dict] = None, **request_params: Any
) -> "NinjaResponse":
) -> ResponseT:
return self.request("GET", path, data, **request_params)

def post(
Expand All @@ -47,7 +63,7 @@ def post(
data: Optional[Dict] = None,
json: Any = None,
**request_params: Any,
) -> "NinjaResponse":
) -> ResponseT:
return self.request("POST", path, data, json, **request_params)

def patch(
Expand All @@ -56,7 +72,7 @@ def patch(
data: Optional[Dict] = None,
json: Any = None,
**request_params: Any,
) -> "NinjaResponse":
) -> ResponseT:
return self.request("PATCH", path, data, json, **request_params)

def put(
Expand All @@ -65,7 +81,7 @@ def put(
data: Optional[Dict] = None,
json: Any = None,
**request_params: Any,
) -> "NinjaResponse":
) -> ResponseT:
return self.request("PUT", path, data, json, **request_params)

def delete(
Expand All @@ -74,7 +90,7 @@ def delete(
data: Optional[Dict] = None,
json: Any = None,
**request_params: Any,
) -> "NinjaResponse":
) -> ResponseT:
return self.request("DELETE", path, data, json, **request_params)

def request(
Expand All @@ -84,7 +100,7 @@ def request(
data: Optional[Dict] = None,
json: Any = None,
**request_params: Any,
) -> "NinjaResponse":
) -> ResponseT:
if json is not None:
request_params["body"] = json_dumps(json, cls=NinjaJSONEncoder)
if data is None:
Expand All @@ -100,7 +116,10 @@ def request(
**request_params.get("COOKIES", {}),
}
func, request, kwargs = self._resolve(method, path, data, request_params)
return self._call(func, request, kwargs) # type: ignore
return self._call(func, request, kwargs)

@abstractmethod
def _call(self, func: Callable, request: Mock, kwargs: Dict) -> ResponseT: ...

@property
def urls(self) -> List:
Expand Down Expand Up @@ -186,12 +205,14 @@ def _build_request(
return request


class TestClient(NinjaClientBase):
class TestClient(NinjaClientBase["NinjaResponse"]):
@override
def _call(self, func: Callable, request: Mock, kwargs: Dict) -> "NinjaResponse":
return NinjaResponse(func(request, **kwargs))


class TestAsyncClient(NinjaClientBase):
class TestAsyncClient(NinjaClientBase[Awaitable["NinjaResponse"]]):
@override
async def _call(
self, func: Callable, request: Mock, kwargs: Dict
) -> "NinjaResponse":
Expand All @@ -206,7 +227,7 @@ def __init__(self, http_response: Union[HttpResponse, StreamingHttpResponse]):
if self.streaming:
self.content = b"".join(http_response.streaming_content) # type: ignore
else:
self.content = http_response.content
self.content = cast(HttpResponse, http_response).content
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy complained about StreamingHttpResponse not having an attribute content.

self._data = None

def json(self) -> Any:
Expand Down