|
11 | 11 |
|
12 | 12 | from vws import VWS, CloudRecoService |
13 | 13 | from vws.exceptions import MaxNumResultsOutOfRange |
| 14 | +from vws.include_target_data import CloudRecoIncludeTargetData |
14 | 15 |
|
15 | 16 |
|
16 | 17 | class TestQuery: |
@@ -167,3 +168,127 @@ def test_too_many( |
167 | 168 | 'Accepted range is from 1 to 50 (inclusive).' |
168 | 169 | ) |
169 | 170 | assert str(exc.value) == exc.value.response.text == expected_value |
| 171 | + |
| 172 | + |
| 173 | +class TestIncludeTargetData: |
| 174 | + """ |
| 175 | + Tests for the ``include_target_data`` parameter of ``query``. |
| 176 | + """ |
| 177 | + |
| 178 | + def test_default( |
| 179 | + self, |
| 180 | + vws_client: VWS, |
| 181 | + cloud_reco_client: CloudRecoService, |
| 182 | + high_quality_image: io.BytesIO, |
| 183 | + ) -> None: |
| 184 | + """ |
| 185 | + By default, target data is only returned in the top match. |
| 186 | + """ |
| 187 | + target_id = vws_client.add_target( |
| 188 | + name=uuid.uuid4().hex, |
| 189 | + width=1, |
| 190 | + image=high_quality_image, |
| 191 | + ) |
| 192 | + target_id_2 = vws_client.add_target( |
| 193 | + name=uuid.uuid4().hex, |
| 194 | + width=1, |
| 195 | + image=high_quality_image, |
| 196 | + ) |
| 197 | + vws_client.wait_for_target_processed(target_id=target_id) |
| 198 | + vws_client.wait_for_target_processed(target_id=target_id_2) |
| 199 | + top_match, second_match = cloud_reco_client.query( |
| 200 | + image=high_quality_image, |
| 201 | + max_num_results=2, |
| 202 | + ) |
| 203 | + assert 'target_data' in top_match |
| 204 | + assert 'target_data' not in second_match |
| 205 | + |
| 206 | + def test_top( |
| 207 | + self, |
| 208 | + vws_client: VWS, |
| 209 | + cloud_reco_client: CloudRecoService, |
| 210 | + high_quality_image: io.BytesIO, |
| 211 | + ) -> None: |
| 212 | + """ |
| 213 | + When ``CloudRecoIncludeTargetData.TOP`` is given, target data is only |
| 214 | + returned in the top match. |
| 215 | + """ |
| 216 | + target_id = vws_client.add_target( |
| 217 | + name=uuid.uuid4().hex, |
| 218 | + width=1, |
| 219 | + image=high_quality_image, |
| 220 | + ) |
| 221 | + target_id_2 = vws_client.add_target( |
| 222 | + name=uuid.uuid4().hex, |
| 223 | + width=1, |
| 224 | + image=high_quality_image, |
| 225 | + ) |
| 226 | + vws_client.wait_for_target_processed(target_id=target_id) |
| 227 | + vws_client.wait_for_target_processed(target_id=target_id_2) |
| 228 | + top_match, second_match = cloud_reco_client.query( |
| 229 | + image=high_quality_image, |
| 230 | + max_num_results=2, |
| 231 | + include_target_data=CloudRecoIncludeTargetData.TOP, |
| 232 | + ) |
| 233 | + assert 'target_data' in top_match |
| 234 | + assert 'target_data' not in second_match |
| 235 | + |
| 236 | + def test_none( |
| 237 | + self, |
| 238 | + vws_client: VWS, |
| 239 | + cloud_reco_client: CloudRecoService, |
| 240 | + high_quality_image: io.BytesIO, |
| 241 | + ) -> None: |
| 242 | + """ |
| 243 | + When ``CloudRecoIncludeTargetData.NONE`` is given, target data is not |
| 244 | + returned in any match. |
| 245 | + """ |
| 246 | + target_id = vws_client.add_target( |
| 247 | + name=uuid.uuid4().hex, |
| 248 | + width=1, |
| 249 | + image=high_quality_image, |
| 250 | + ) |
| 251 | + target_id_2 = vws_client.add_target( |
| 252 | + name=uuid.uuid4().hex, |
| 253 | + width=1, |
| 254 | + image=high_quality_image, |
| 255 | + ) |
| 256 | + vws_client.wait_for_target_processed(target_id=target_id) |
| 257 | + vws_client.wait_for_target_processed(target_id=target_id_2) |
| 258 | + top_match, second_match = cloud_reco_client.query( |
| 259 | + image=high_quality_image, |
| 260 | + max_num_results=2, |
| 261 | + include_target_data=CloudRecoIncludeTargetData.NONE, |
| 262 | + ) |
| 263 | + assert 'target_data' not in top_match |
| 264 | + assert 'target_data' not in second_match |
| 265 | + |
| 266 | + def test_all( |
| 267 | + self, |
| 268 | + vws_client: VWS, |
| 269 | + cloud_reco_client: CloudRecoService, |
| 270 | + high_quality_image: io.BytesIO, |
| 271 | + ) -> None: |
| 272 | + """ |
| 273 | + When ``CloudRecoIncludeTargetData.ALL`` is given, target data is |
| 274 | + returned in all matches. |
| 275 | + """ |
| 276 | + target_id = vws_client.add_target( |
| 277 | + name=uuid.uuid4().hex, |
| 278 | + width=1, |
| 279 | + image=high_quality_image, |
| 280 | + ) |
| 281 | + target_id_2 = vws_client.add_target( |
| 282 | + name=uuid.uuid4().hex, |
| 283 | + width=1, |
| 284 | + image=high_quality_image, |
| 285 | + ) |
| 286 | + vws_client.wait_for_target_processed(target_id=target_id) |
| 287 | + vws_client.wait_for_target_processed(target_id=target_id_2) |
| 288 | + top_match, second_match = cloud_reco_client.query( |
| 289 | + image=high_quality_image, |
| 290 | + max_num_results=2, |
| 291 | + include_target_data=CloudRecoIncludeTargetData.ALL, |
| 292 | + ) |
| 293 | + assert 'target_data' in top_match |
| 294 | + assert 'target_data' in second_match |
0 commit comments