Skip to content

Commit b7a6059

Browse files
committed
Support max num results
1 parent ef81a9b commit b7a6059

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

src/vws/query.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,16 @@ def __init__(
3232
def query(
3333
self,
3434
image: io.BytesIO,
35+
max_num_results: int = 1,
3536
) -> str:
3637
"""
3738
TODO docstring
3839
"""
3940
image_content = image.getvalue()
40-
body = {'image': ('image.jpeg', image_content, 'image/jpeg')}
41+
body = {
42+
'image': ('image.jpeg', image_content, 'image/jpeg'),
43+
'max_num_results': (None, max_num_results, 'text/plain'),
44+
}
4145
date = rfc_1123_date()
4246
request_path = '/v1/query'
4347
content, content_type_header = encode_multipart_formdata(body)

tests/test_query.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import io
6+
import uuid
67

78
from vws import VWS, CloudRecoService
89

@@ -42,6 +43,79 @@ def test_match(
4243
assert matching_target['target_id'] == target_id
4344

4445

46+
class TestMaxNumResults:
47+
"""
48+
Tests for the ``max_num_results`` parameter of ``query``.
49+
"""
50+
51+
def test_default(
52+
self,
53+
client: VWS,
54+
cloud_reco_client: CloudRecoService,
55+
high_quality_image: io.BytesIO,
56+
) -> None:
57+
"""
58+
XXX
59+
"""
60+
target_id = client.add_target(
61+
name=uuid.uuid4().hex,
62+
width=1,
63+
image=high_quality_image,
64+
)
65+
target_id_2 = client.add_target(
66+
name=uuid.uuid4().hex,
67+
width=1,
68+
image=high_quality_image,
69+
)
70+
client.wait_for_target_processed(target_id=target_id)
71+
client.wait_for_target_processed(target_id=target_id_2)
72+
matches = cloud_reco_client.query(image=high_quality_image)
73+
assert len(matches) == 1
74+
75+
def test_custom(
76+
self,
77+
client: VWS,
78+
cloud_reco_client: CloudRecoService,
79+
high_quality_image: io.BytesIO,
80+
) -> None:
81+
"""
82+
XXX
83+
"""
84+
target_id = client.add_target(
85+
name=uuid.uuid4().hex,
86+
width=1,
87+
image=high_quality_image,
88+
)
89+
target_id_2 = client.add_target(
90+
name=uuid.uuid4().hex,
91+
width=1,
92+
image=high_quality_image,
93+
)
94+
client.wait_for_target_processed(target_id=target_id)
95+
client.wait_for_target_processed(target_id=target_id_2)
96+
matches = cloud_reco_client.query(image=high_quality_image, max_num_results=2)
97+
assert len(matches) == 2
98+
99+
100+
101+
def test_foo(self):
102+
pass
103+
# target_ids = set([])
104+
# for i in range(15):
105+
# target_id = client.add_target(
106+
# name=uuid.uuid4().hex,
107+
# width=1,
108+
# image=high_quality_image,
109+
# )
110+
# target_ids.add(target_id)
111+
#
112+
# for target_id in target_ids:
113+
# client.wait_for_target_processed(target_id=target_id)
114+
#
115+
# matching_targets = cloud_reco_client.query(image=high_quality_image)
116+
# assert len(matching_targets) == 1
117+
118+
45119
# TODO test custom base URL
46120
# TODO test bad credentials
47121
# TODO test options - max_num_results + include_target_data

0 commit comments

Comments
 (0)