Skip to content

Commit a0ad12d

Browse files
committed
Progress towards switching from requests_mock to responses
1 parent 54a1dc4 commit a0ad12d

File tree

5 files changed

+53
-42
lines changed

5 files changed

+53
-42
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies = [
4444
"piq",
4545
"pydantic-settings",
4646
"requests",
47-
"requests-mock",
47+
"responses",
4848
"torch",
4949
"torchmetrics",
5050
"tzdata; sys_platform=='win32'",

src/mock_vws/_requests_mock_server/decorators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __enter__(self) -> Self:
132132
method=vws_http_method,
133133
url=compiled_url_pattern,
134134
callback=getattr(self._mock_vws_api, vws_route.route_name),
135+
content_type=None,
135136
)
136137

137138
for vwq_route in self._mock_vwq_api.routes:
@@ -147,6 +148,7 @@ def __enter__(self) -> Self:
147148
method=vwq_http_method,
148149
url=compiled_url_pattern,
149150
callback=getattr(self._mock_vwq_api, vwq_route.route_name),
151+
content_type=None,
150152
)
151153

152154
if self._real_http:

src/mock_vws/_requests_mock_server/mock_web_query_api.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
import email.utils
99
from collections.abc import Callable
10-
from http import HTTPMethod
10+
from http import HTTPMethod, HTTPStatus
11+
12+
from requests.models import PreparedRequest
1113

1214
from mock_vws._mock_common import Route
1315
from mock_vws._query_tools import (
@@ -22,7 +24,7 @@
2224

2325
_ROUTES: set[Route] = set()
2426

25-
_ResponseType = str
27+
_ResponseType = tuple[int, dict[str, str], str]
2628

2729

2830
def route(
@@ -64,11 +66,15 @@ def decorator(
6466
return decorator
6567

6668

67-
def _body_bytes(request: "Request") -> bytes:
69+
def _body_bytes(request: PreparedRequest) -> bytes:
6870
"""
6971
Return the body of a request as bytes.
7072
"""
71-
return request.body or b""
73+
if request.body is None:
74+
return b""
75+
76+
assert isinstance(request.body, bytes)
77+
return request.body
7278

7379

7480
class MockVuforiaWebQueryAPI:
@@ -97,28 +103,26 @@ def __init__(
97103
self._query_match_checker = query_match_checker
98104

99105
@route(path_pattern="/v1/query", http_methods={HTTPMethod.POST})
100-
def query(self, request: "Request", context: "Context") -> _ResponseType:
106+
def query(self, request: PreparedRequest) -> _ResponseType:
101107
"""
102108
Perform an image recognition query.
103109
"""
104110
try:
105111
run_query_validators(
106-
request_path=request.path,
112+
request_path=request.path_url,
107113
request_headers=request.headers,
108114
request_body=_body_bytes(request=request),
109-
request_method=request.method,
115+
request_method=request.method or "",
110116
databases=self._target_manager.databases,
111117
)
112118
except ValidatorError as exc:
113-
context.headers = exc.headers
114-
context.status_code = exc.status_code
115-
return exc.response_text
119+
return exc.status_code, exc.headers, exc.response_text
116120

117121
response_text = get_query_match_response_text(
118122
request_headers=request.headers,
119123
request_body=_body_bytes(request=request),
120-
request_method=request.method,
121-
request_path=request.path,
124+
request_method=request.method or "",
125+
request_path=request.path_url,
122126
databases=self._target_manager.databases,
123127
query_match_checker=self._query_match_checker,
124128
)
@@ -128,11 +132,11 @@ def query(self, request: "Request", context: "Context") -> _ResponseType:
128132
localtime=False,
129133
usegmt=True,
130134
)
131-
context.headers = {
135+
headers = {
132136
"Connection": "keep-alive",
133137
"Content-Type": "application/json",
134138
"Server": "nginx",
135139
"Date": date,
136140
"Content-Length": str(len(response_text)),
137141
}
138-
return response_text
142+
return HTTPStatus.OK, headers, response_text

src/mock_vws/_requests_mock_server/mock_web_services_api.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@
3838

3939
_ROUTES: set[Route] = set()
4040

41-
_ResponseType = str
41+
_ResponseType = tuple[int, dict[str, str], str]
4242

4343

4444
def route(
4545
path_pattern: str,
4646
http_methods: set[HTTPMethod],
47-
) -> Callable[[Callable[..., str]], Callable[..., _ResponseType]]:
47+
) -> Callable[[Callable[..., _ResponseType]], Callable[..., _ResponseType]]:
4848
"""
4949
Register a decorated method so that it can be recognized as a route.
5050
@@ -184,14 +184,14 @@ def add_target(self, request: PreparedRequest) -> _ResponseType:
184184
localtime=False,
185185
usegmt=True,
186186
)
187-
context.status_code = HTTPStatus.CREATED
187+
status_code = HTTPStatus.CREATED
188188
body = {
189189
"transaction_id": uuid.uuid4().hex,
190190
"result_code": ResultCodes.TARGET_CREATED.value,
191191
"target_id": new_target.target_id,
192192
}
193193
body_json = json_dump(body=body)
194-
context.headers = {
194+
headers = {
195195
"Connection": "keep-alive",
196196
"Content-Type": "application/json",
197197
"server": "envoy",
@@ -202,7 +202,7 @@ def add_target(self, request: PreparedRequest) -> _ResponseType:
202202
"x-aws-region": "us-east-2, us-west-2",
203203
"x-content-type-options": "nosniff",
204204
}
205-
return body_json
205+
return status_code, headers, body_json
206206

207207
@route(
208208
path_pattern=f"/targets/{_TARGET_ID_PATTERN}",
@@ -240,9 +240,11 @@ def delete_target(self, request: PreparedRequest) -> _ResponseType:
240240

241241
if target.status == TargetStatuses.PROCESSING.value:
242242
target_processing_exception = TargetStatusProcessingError()
243-
context.headers = target_processing_exception.headers
244-
context.status_code = target_processing_exception.status_code
245-
return target_processing_exception.response_text
243+
return (
244+
target_processing_exception.status_code,
245+
target_processing_exception.headers,
246+
target_processing_exception.response_text,
247+
)
246248

247249
now = datetime.datetime.now(tz=target.upload_date.tzinfo)
248250
new_target = dataclasses.replace(target, delete_date=now)
@@ -280,7 +282,6 @@ def database_summary(self, request: PreparedRequest) -> _ResponseType:
280282
Fake implementation of
281283
https://developer.vuforia.com/library/web-api/cloud-targets-web-services-api#summary-report
282284
"""
283-
breakpoint()
284285
try:
285286
run_services_validators(
286287
request_headers=request.headers,
@@ -516,7 +517,7 @@ def get_duplicates(self, request: PreparedRequest) -> _ResponseType:
516517
"similar_targets": similar_targets,
517518
}
518519
body_json = json_dump(body=body)
519-
context.headers = {
520+
headers = {
520521
"Connection": "keep-alive",
521522
"Content-Length": str(len(body_json)),
522523
"Content-Type": "application/json",
@@ -528,7 +529,7 @@ def get_duplicates(self, request: PreparedRequest) -> _ResponseType:
528529
"x-content-type-options": "nosniff",
529530
}
530531

531-
return body_json
532+
return HTTPStatus.OK, headers, body_json
532533

533534
@route(
534535
path_pattern=f"/targets/{_TARGET_ID_PATTERN}",
@@ -572,9 +573,11 @@ def update_target(self, request: PreparedRequest) -> _ResponseType:
572573

573574
if target.status != TargetStatuses.SUCCESS.value:
574575
exception = TargetStatusNotSuccessError()
575-
context.headers = exception.headers
576-
context.status_code = exception.status_code
577-
return exception.response_text
576+
return (
577+
exception.status_code,
578+
exception.headers,
579+
exception.response_text,
580+
)
578581

579582
request_json: dict[str, Any] = json.loads(s=request.body or b"")
580583
width = request_json.get("width", target.width)
@@ -591,18 +594,22 @@ def update_target(self, request: PreparedRequest) -> _ResponseType:
591594

592595
if "active_flag" in request_json and active_flag is None:
593596
fail_exception = FailError(status_code=HTTPStatus.BAD_REQUEST)
594-
context.headers = fail_exception.headers
595-
context.status_code = fail_exception.status_code
596-
return fail_exception.response_text
597+
return (
598+
fail_exception.status_code,
599+
fail_exception.headers,
600+
fail_exception.response_text,
601+
)
597602

598603
if (
599604
"application_metadata" in request_json
600605
and application_metadata is None
601606
):
602607
fail_exception = FailError(status_code=HTTPStatus.BAD_REQUEST)
603-
context.headers = fail_exception.headers
604-
context.status_code = fail_exception.status_code
605-
return fail_exception.response_text
608+
return (
609+
fail_exception.status_code,
610+
fail_exception.headers,
611+
fail_exception.response_text,
612+
)
606613

607614
gmt = ZoneInfo(key="GMT")
608615
last_modified_date = datetime.datetime.now(tz=gmt)
@@ -636,7 +643,7 @@ def update_target(self, request: PreparedRequest) -> _ResponseType:
636643
"x-aws-region": "us-east-2, us-west-2",
637644
"x-content-type-options": "nosniff",
638645
}
639-
return body_json
646+
return HTTPStatus.OK, headers, body_json
640647

641648
@route(
642649
path_pattern=f"/summary/{_TARGET_ID_PATTERN}",
@@ -658,9 +665,7 @@ def target_summary(self, request: PreparedRequest) -> _ResponseType:
658665
databases=self._target_manager.databases,
659666
)
660667
except ValidatorError as exc:
661-
context.headers = exc.headers
662-
context.status_code = exc.status_code
663-
return exc.response_text
668+
return exc.status_code, exc.headers, exc.response_text
664669

665670
database = get_database_matching_server_keys(
666671
request_headers=request.headers,
@@ -691,7 +696,7 @@ def target_summary(self, request: PreparedRequest) -> _ResponseType:
691696
"previous_month_recos": target.previous_month_recos,
692697
}
693698
body_json = json_dump(body=body)
694-
context.headers = {
699+
headers = {
695700
"Connection": "keep-alive",
696701
"Content-Length": str(len(body_json)),
697702
"Content-Type": "application/json",
@@ -703,4 +708,4 @@ def target_summary(self, request: PreparedRequest) -> _ResponseType:
703708
"x-content-type-options": "nosniff",
704709
}
705710

706-
return body_json
711+
return HTTPStatus.OK, headers, body_json

tests/mock_vws/test_requests_mock_usage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_default() -> None:
8080
By default, the mock stops any requests made with `requests` to
8181
non-Vuforia addresses, but not to mocked Vuforia endpoints.
8282
"""
83-
with MockVWS() as mock:
83+
with MockVWS():
8484
with pytest.raises(
8585
expected_exception=requests.exceptions.ConnectionError
8686
):

0 commit comments

Comments
 (0)