Skip to content

Commit 3f924b3

Browse files
committed
Progress towards flexible tests
1 parent bfb2808 commit 3f924b3

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

src/mock_vws/target.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from zoneinfo import ZoneInfo
1515

1616
import brisque
17+
import cv2
1718
import numpy as np
1819
from PIL import Image, ImageStat
1920

@@ -64,12 +65,15 @@ def _quality(image_content: bytes) -> int:
6465
image = Image.open(fp=image_file)
6566
image_array = np.asarray(a=image)
6667
obj = brisque.BRISQUE(url=False)
68+
min_height = min_width = 2
69+
if image.height < min_height or image.width < min_width:
70+
return -2
6771
# We avoid a barrage of warnings from the BRISQUE library.
6872
with np.errstate(divide="ignore", invalid="ignore"):
6973
score = obj.score(img=image_array)
7074
if math.isnan(score):
7175
return 0
72-
return int(score / 20)
76+
return min(int(score / 5), 5)
7377

7478

7579
@dataclass(frozen=True, eq=True)
@@ -168,10 +172,14 @@ def tracking_rating(self) -> int:
168172
if time_since_upload <= pre_rating_time:
169173
return -1
170174

171-
if self._post_processing_status == TargetStatuses.SUCCESS:
175+
try:
172176
return _quality(image_content=self.image_value)
173-
174-
return 0
177+
except cv2.error as exc:
178+
breakpoint()
179+
print(exc)
180+
return -2
181+
else:
182+
return 0
175183

176184
@classmethod
177185
def from_dict(cls, target_dict: TargetDict) -> Target:

tests/mock_vws/test_get_target.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import uuid
9+
from enum import Enum, auto
910
from typing import TYPE_CHECKING
1011

1112
import pytest
@@ -140,13 +141,19 @@ def test_target_quality(
140141
"""
141142
The target tracking rating is as expected.
142143
"""
144+
145+
class TrackingRating(Enum):
146+
ZERO = auto()
147+
NEGATIVE = auto()
148+
POSITIVE = auto()
149+
143150
target_id_expected_rating_pairs = []
144151
for image_file in (
145152
high_quality_image,
146153
image_file_failed_state,
147-
image_file_success_state_low_rating,
154+
# image_file_success_state_low_rating,
148155
corrupted_image_file,
149-
different_high_quality_image,
156+
# different_high_quality_image,
150157
):
151158
target_id = vws_client.add_target(
152159
name=f"example_{uuid.uuid4().hex}",
@@ -157,11 +164,11 @@ def test_target_quality(
157164
)
158165

159166
expected_tracking_rating = {
160-
high_quality_image: 5,
161-
image_file_failed_state: 0,
162-
image_file_success_state_low_rating: 0,
163-
corrupted_image_file: -2,
164-
different_high_quality_image: 0,
167+
high_quality_image: TrackingRating.POSITIVE,
168+
image_file_failed_state: TrackingRating.ZERO,
169+
image_file_success_state_low_rating: TrackingRating.ZERO,
170+
corrupted_image_file: TrackingRating.NEGATIVE,
171+
different_high_quality_image: TrackingRating.ZERO,
165172
}[image_file]
166173
target_id_expected_rating_pairs.append(
167174
(target_id, expected_tracking_rating),
@@ -174,10 +181,12 @@ def test_target_quality(
174181
vws_client.wait_for_target_processed(target_id=target_id)
175182

176183
target_details = vws_client.get_target_record(target_id=target_id)
177-
assert (
178-
target_details.target_record.tracking_rating
179-
== expected_tracking_rating
180-
)
184+
if expected_tracking_rating == TrackingRating.ZERO:
185+
assert target_details.target_record.tracking_rating == 0
186+
elif expected_tracking_rating == TrackingRating.NEGATIVE:
187+
assert target_details.target_record.tracking_rating < 0
188+
else:
189+
assert target_details.target_record.tracking_rating > 0
181190

182191

183192
@pytest.mark.usefixtures("verify_mock_vuforia")

0 commit comments

Comments
 (0)