Skip to content

Commit 3e1ff2e

Browse files
committed
Allow the public methods to take file-like objects, not just io.BytesIO
1 parent 246c9b3 commit 3e1ff2e

File tree

8 files changed

+146
-114
lines changed

8 files changed

+146
-114
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Changelog
44
Next
55
----
66

7+
* Support file-like objects in every method which accepts a file.
8+
79
2023.03.05
810
------------
911

README.md

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ shutil.copy(image, new_image)
5151
<!--pytest-codeblocks:cont-->
5252

5353
```python
54-
import io
5554
import pathlib
5655

5756
from vws import VWS, CloudRecoService
@@ -73,17 +72,15 @@ name = 'my_image_name'
7372

7473
image = pathlib.Path('high_quality_image.jpg')
7574
with image.open(mode='rb') as my_image_file:
76-
my_image = io.BytesIO(my_image_file.read())
77-
78-
target_id = vws_client.add_target(
79-
name=name,
80-
width=1,
81-
image=my_image,
82-
active_flag=True,
83-
application_metadata=None,
84-
)
85-
vws_client.wait_for_target_processed(target_id=target_id)
86-
matching_targets = cloud_reco_client.query(image=my_image)
75+
target_id = vws_client.add_target(
76+
name=name,
77+
width=1,
78+
image=my_image_file,
79+
active_flag=True,
80+
application_metadata=None,
81+
)
82+
vws_client.wait_for_target_processed(target_id=target_id)
83+
matching_targets = cloud_reco_client.query(image=my_image_file)
8784

8885
assert matching_targets[0].target_id == target_id
8986
```

docs/source/index.rst

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ See the :doc:`api-reference` for full usage details.
4343

4444
.. testcode:: simple
4545

46-
import io
4746
import pathlib
4847

4948
from vws import VWS, CloudRecoService
@@ -65,17 +64,15 @@ See the :doc:`api-reference` for full usage details.
6564

6665
image = pathlib.Path('high_quality_image.jpg')
6766
with image.open(mode='rb') as my_image_file:
68-
my_image = io.BytesIO(my_image_file.read())
69-
70-
target_id = vws_client.add_target(
71-
name=name,
72-
width=1,
73-
image=my_image,
74-
active_flag=True,
75-
application_metadata=None,
76-
)
77-
vws_client.wait_for_target_processed(target_id=target_id)
78-
matching_targets = cloud_reco_client.query(image=my_image)
67+
target_id = vws_client.add_target(
68+
name=name,
69+
width=1,
70+
image=my_image_file,
71+
active_flag=True,
72+
application_metadata=None,
73+
)
74+
vws_client.wait_for_target_processed(target_id=target_id)
75+
matching_targets = cloud_reco_client.query(image=my_image_file)
7976

8077
assert matching_targets[0].target_id == target_id
8178

@@ -109,7 +106,6 @@ To write unit tests for code which uses this library, without using your Vuforia
109106

110107
.. testcode:: testing
111108

112-
import io
113109
import pathlib
114110

115111
from mock_vws.database import VuforiaDatabase
@@ -131,15 +127,13 @@ To write unit tests for code which uses this library, without using your Vuforia
131127

132128
image = pathlib.Path('high_quality_image.jpg')
133129
with image.open(mode='rb') as my_image_file:
134-
my_image = io.BytesIO(my_image_file.read())
135-
136-
target_id = vws_client.add_target(
137-
name="example_image_name",
138-
width=1,
139-
image=my_image,
140-
application_metadata=None,
141-
active_flag=True,
142-
)
130+
target_id = vws_client.add_target(
131+
name="example_image_name",
132+
width=1,
133+
image=my_image_file,
134+
application_metadata=None,
135+
active_flag=True,
136+
)
143137

144138
.. testcleanup:: testing
145139

src/vws/query.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import datetime
88
from http import HTTPStatus
9-
from typing import TYPE_CHECKING
9+
from typing import BinaryIO
1010
from urllib.parse import urljoin
1111

1212
import requests
@@ -27,8 +27,13 @@
2727
from vws.include_target_data import CloudRecoIncludeTargetData
2828
from vws.reports import QueryResult, TargetData
2929

30-
if TYPE_CHECKING:
31-
import io
30+
31+
def _get_image_data(image: BinaryIO) -> bytes:
32+
original_tell = image.tell()
33+
image.seek(0)
34+
image_data = image.read()
35+
image.seek(original_tell)
36+
return image_data
3237

3338

3439
class CloudRecoService:
@@ -54,7 +59,7 @@ def __init__(
5459

5560
def query(
5661
self,
57-
image: io.BytesIO,
62+
image: BinaryIO,
5863
max_num_results: int = 1,
5964
include_target_data: CloudRecoIncludeTargetData = (
6065
CloudRecoIncludeTargetData.TOP
@@ -96,7 +101,7 @@ def query(
96101
Returns:
97102
An ordered list of target details of matching targets.
98103
"""
99-
image_content = image.getvalue()
104+
image_content = _get_image_data(image=image)
100105
body = {
101106
"image": ("image.jpeg", image_content, "image/jpeg"),
102107
"max_num_results": (None, int(max_num_results), "text/plain"),

src/vws/vws.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
from datetime import date
1010
from time import sleep
11-
from typing import TYPE_CHECKING
11+
from typing import TYPE_CHECKING, BinaryIO
1212
from urllib.parse import urljoin
1313

1414
import requests
@@ -51,6 +51,14 @@
5151
import io
5252

5353

54+
def _get_image_data(image: BinaryIO) -> bytes:
55+
original_tell = image.tell()
56+
image.seek(0)
57+
image_data = image.read()
58+
image.seek(original_tell)
59+
return image_data
60+
61+
5462
def _target_api_request(
5563
server_access_key: str,
5664
server_secret_key: str,
@@ -204,7 +212,7 @@ def add_target(
204212
self,
205213
name: str,
206214
width: int | float,
207-
image: io.BytesIO,
215+
image: BinaryIO,
208216
application_metadata: str | None,
209217
*,
210218
active_flag: bool,
@@ -255,7 +263,7 @@ def add_target(
255263
occurred". This has been seen to happen when the given name
256264
includes a bad character.
257265
"""
258-
image_data = image.getvalue()
266+
image_data = _get_image_data(image=image)
259267
image_data_encoded = base64.b64encode(image_data).decode("ascii")
260268

261269
data = {
@@ -644,7 +652,7 @@ def update_target(
644652
data["width"] = width
645653

646654
if image is not None:
647-
image_data = image.getvalue()
655+
image_data = _get_image_data(image=image)
648656
image_data_encoded = base64.b64encode(image_data).decode("ascii")
649657
data["image"] = image_data_encoded
650658

tests/conftest.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING
7+
import io
8+
from typing import TYPE_CHECKING, BinaryIO
89

910
import pytest
1011
from mock_vws import MockVWS
1112
from mock_vws.database import VuforiaDatabase
1213
from vws import VWS, CloudRecoService
1314

1415
if TYPE_CHECKING:
15-
from collections.abc import Iterator
16+
from collections.abc import Generator
17+
from pathlib import Path
1618

1719

1820
@pytest.fixture(name="_mock_database")
19-
def mock_database() -> Iterator[VuforiaDatabase]:
21+
def mock_database() -> Generator[VuforiaDatabase, None, None]:
2022
"""
2123
Yield a mock ``VuforiaDatabase``.
2224
"""
@@ -29,7 +31,7 @@ def mock_database() -> Iterator[VuforiaDatabase]:
2931
@pytest.fixture()
3032
def vws_client(_mock_database: VuforiaDatabase) -> VWS:
3133
"""
32-
Yield a VWS client which connects to a mock database.
34+
A VWS client which connects to a mock database.
3335
"""
3436
return VWS(
3537
server_access_key=_mock_database.server_access_key,
@@ -40,9 +42,33 @@ def vws_client(_mock_database: VuforiaDatabase) -> VWS:
4042
@pytest.fixture()
4143
def cloud_reco_client(_mock_database: VuforiaDatabase) -> CloudRecoService:
4244
"""
43-
Yield a ``CloudRecoService`` client which connects to a mock database.
45+
A ``CloudRecoService`` client which connects to a mock database.
4446
"""
4547
return CloudRecoService(
4648
client_access_key=_mock_database.client_access_key,
4749
client_secret_key=_mock_database.client_secret_key,
4850
)
51+
52+
53+
@pytest.fixture()
54+
def image_file(
55+
high_quality_image: io.BytesIO,
56+
tmp_path: Path,
57+
) -> Generator[io.BufferedRandom, None, None]:
58+
"""An image file object."""
59+
file = tmp_path / "image.jpg"
60+
file.touch()
61+
with file.open("r+b") as fileobj:
62+
buffer = high_quality_image.getvalue()
63+
fileobj.write(buffer)
64+
yield fileobj
65+
66+
67+
@pytest.fixture(params=["high_quality_image", "image_file"])
68+
def image(
69+
request: pytest.FixtureRequest,
70+
) -> BinaryIO:
71+
"""An image in any of the types that the API accepts."""
72+
result = request.getfixturevalue(request.param)
73+
assert isinstance(result, io.BytesIO | io.BufferedRandom)
74+
return result

0 commit comments

Comments
 (0)