Skip to content

Commit da1862e

Browse files
committed
Add tests
Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>
1 parent 5fddbb0 commit da1862e

File tree

2 files changed

+111
-38
lines changed

2 files changed

+111
-38
lines changed

vulnerabilities/api_v2.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#
99

1010

11+
from django.db.models import Prefetch
1112
from django_filters import rest_framework as filters
1213
from drf_spectacular.utils import OpenApiParameter
1314
from drf_spectacular.utils import extend_schema
@@ -20,8 +21,6 @@
2021
from rest_framework.response import Response
2122
from rest_framework.reverse import reverse
2223

23-
from vulnerabilities.api import PackageFilterSet
24-
from vulnerabilities.api import VulnerabilitySeveritySerializer
2524
from vulnerabilities.models import Package
2625
from vulnerabilities.models import Vulnerability
2726
from vulnerabilities.models import VulnerabilityReference
@@ -198,9 +197,8 @@ def get_affected_by_vulnerabilities(self, obj):
198197
"""
199198
Return a dictionary with vulnerabilities as keys and their details, including fixed_by_packages.
200199
"""
201-
vulnerabilities = obj.affected_by_vulnerabilities.prefetch_related("fixed_by_packages")
202200
result = {}
203-
for vuln in vulnerabilities:
201+
for vuln in getattr(obj, "prefetched_affected_vulnerabilities", []):
204202
fixed_by_package = vuln.fixed_by_packages.first()
205203
purl = None
206204
if fixed_by_package:
@@ -247,7 +245,13 @@ class PackageV2FilterSet(filters.FilterSet):
247245

248246

249247
class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
250-
queryset = Package.objects.all()
248+
queryset = Package.objects.all().prefetch_related(
249+
Prefetch(
250+
"affected_by_vulnerabilities",
251+
queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"),
252+
to_attr="prefetched_affected_vulnerabilities",
253+
)
254+
)
251255
serializer_class = PackageV2Serializer
252256
filter_backends = (filters.DjangoFilterBackend,)
253257
filterset_class = PackageV2FilterSet

vulnerabilities/tests/test_api_v2.py

Lines changed: 102 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
99

10+
from django.db.models import Prefetch
1011
from django.urls import reverse
1112
from packageurl import PackageURL
1213
from rest_framework import status
@@ -67,6 +68,8 @@ def test_list_vulnerabilities(self):
6768
"""
6869
url = reverse("vulnerability-v2-list")
6970
response = self.client.get(url, format="json")
71+
with self.assertNumQueries(5):
72+
response = self.client.get(url, format="json")
7073
self.assertEqual(response.status_code, status.HTTP_200_OK)
7174
self.assertIn("results", response.data)
7275
self.assertIn("vulnerabilities", response.data["results"])
@@ -80,7 +83,8 @@ def test_retrieve_vulnerability_detail(self):
8083
Test retrieving vulnerability details by vulnerability_id.
8184
"""
8285
url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-1234"})
83-
response = self.client.get(url, format="json")
86+
with self.assertNumQueries(8):
87+
response = self.client.get(url, format="json")
8488
self.assertEqual(response.status_code, status.HTTP_200_OK)
8589
self.assertEqual(response.data["vulnerability_id"], "VCID-1234")
8690
self.assertEqual(response.data["summary"], "Test vulnerability 1")
@@ -93,7 +97,8 @@ def test_filter_vulnerability_by_vulnerability_id(self):
9397
Test filtering vulnerabilities by vulnerability_id.
9498
"""
9599
url = reverse("vulnerability-v2-list")
96-
response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json")
100+
with self.assertNumQueries(4):
101+
response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json")
97102
self.assertEqual(response.status_code, status.HTTP_200_OK)
98103
self.assertEqual(response.data["vulnerability_id"], "VCID-1234")
99104

@@ -102,7 +107,8 @@ def test_filter_vulnerability_by_alias(self):
102107
Test filtering vulnerabilities by alias.
103108
"""
104109
url = reverse("vulnerability-v2-list")
105-
response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json")
110+
with self.assertNumQueries(5):
111+
response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json")
106112
self.assertEqual(response.status_code, status.HTTP_200_OK)
107113
self.assertIn("results", response.data)
108114
self.assertIn("vulnerabilities", response.data["results"])
@@ -116,7 +122,8 @@ def test_filter_vulnerabilities_multiple_ids(self):
116122
Test filtering vulnerabilities by multiple vulnerability_ids.
117123
"""
118124
url = reverse("vulnerability-v2-list")
119-
response = self.client.get(
125+
with self.assertNumQueries(5):
126+
response = self.client.get(
120127
url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json"
121128
)
122129
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -127,7 +134,8 @@ def test_filter_vulnerabilities_multiple_aliases(self):
127134
Test filtering vulnerabilities by multiple aliases.
128135
"""
129136
url = reverse("vulnerability-v2-list")
130-
response = self.client.get(
137+
with self.assertNumQueries(5):
138+
response = self.client.get(
131139
url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json"
132140
)
133141
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -139,7 +147,8 @@ def test_invalid_vulnerability_id(self):
139147
Should return 404 Not Found.
140148
"""
141149
url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-9999"})
142-
response = self.client.get(url, format="json")
150+
with self.assertNumQueries(5):
151+
response = self.client.get(url, format="json")
143152
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
144153

145154
def test_get_url_in_serializer(self):
@@ -207,7 +216,8 @@ def test_list_packages(self):
207216
Should return a list of packages with their details and associated vulnerabilities.
208217
"""
209218
url = reverse("package-v2-list")
210-
response = self.client.get(url, format="json")
219+
with self.assertNumQueries(31):
220+
response = self.client.get(url, format="json")
211221
self.assertEqual(response.status_code, status.HTTP_200_OK)
212222
self.assertIn("results", response.data)
213223
self.assertIn("packages", response.data["results"])
@@ -228,7 +238,8 @@ def test_filter_packages_by_purl(self):
228238
Test filtering packages by one or more PURLs.
229239
"""
230240
url = reverse("package-v2-list")
231-
response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json")
241+
with self.assertNumQueries(19):
242+
response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json")
232243
self.assertEqual(response.status_code, status.HTTP_200_OK)
233244
self.assertEqual(len(response.data["results"]["packages"]), 1)
234245
self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2")
@@ -238,7 +249,8 @@ def test_filter_packages_by_affected_vulnerability(self):
238249
Test filtering packages by affected_by_vulnerability.
239250
"""
240251
url = reverse("package-v2-list")
241-
response = self.client.get(url, {"affected_by_vulnerability": "VCID-1234"}, format="json")
252+
with self.assertNumQueries(19):
253+
response = self.client.get(url, {"affected_by_vulnerability": "VCID-1234"}, format="json")
242254
self.assertEqual(response.status_code, status.HTTP_200_OK)
243255
self.assertEqual(len(response.data["results"]["packages"]), 1)
244256
self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2")
@@ -248,29 +260,59 @@ def test_filter_packages_by_fixing_vulnerability(self):
248260
Test filtering packages by fixing_vulnerability.
249261
"""
250262
url = reverse("package-v2-list")
251-
response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json")
263+
with self.assertNumQueries(18):
264+
response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json")
252265
self.assertEqual(response.status_code, status.HTTP_200_OK)
253266
self.assertEqual(len(response.data["results"]["packages"]), 1)
254267
self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:npm/lodash@4.17.20")
255268

256269
def test_package_serializer_fields(self):
257270
"""
258-
Test that the PackageV2Serializer returns the correct fields.
271+
Test that the PackageV2Serializer returns the correct fields and formats them correctly.
259272
"""
273+
# Fetch the package
260274
package = Package.objects.get(package_url="pkg:pypi/django@3.2")
275+
276+
# Ensure prefetched data is available for the serializer
277+
package = (
278+
Package.objects.filter(package_url="pkg:pypi/django@3.2")
279+
.prefetch_related(
280+
Prefetch(
281+
"affected_by_vulnerabilities",
282+
queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"),
283+
to_attr="prefetched_affected_vulnerabilities",
284+
)
285+
)
286+
.first()
287+
)
288+
289+
# Serialize the package
261290
serializer = PackageV2Serializer(package)
262291
data = serializer.data
292+
293+
# Verify the presence of required fields
263294
self.assertIn("purl", data)
264295
self.assertIn("affected_by_vulnerabilities", data)
265296
self.assertIn("fixing_vulnerabilities", data)
266297
self.assertIn("next_non_vulnerable_version", data)
267298
self.assertIn("latest_non_vulnerable_version", data)
299+
self.assertIn("risk_score", data)
300+
301+
# Verify field values
268302
self.assertEqual(data["purl"], "pkg:pypi/django@3.2")
269-
self.assertEqual(
270-
data["affected_by_vulnerabilities"],
271-
{"VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None}},
272-
)
273-
self.assertEqual(data["fixing_vulnerabilities"], [])
303+
self.assertEqual(data["next_non_vulnerable_version"], None)
304+
self.assertEqual(data["latest_non_vulnerable_version"], None)
305+
self.assertEqual(data["risk_score"], None)
306+
307+
# Verify affected_by_vulnerabilities structure
308+
expected_affected_by_vulnerabilities = {
309+
"VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None}
310+
}
311+
self.assertEqual(data["affected_by_vulnerabilities"], expected_affected_by_vulnerabilities)
312+
313+
# Verify fixing_vulnerabilities structure
314+
expected_fixing_vulnerabilities = []
315+
self.assertEqual(data["fixing_vulnerabilities"], expected_fixing_vulnerabilities)
274316

275317
def test_list_packages_pagination(self):
276318
"""
@@ -303,7 +345,8 @@ def test_invalid_vulnerability_filter(self):
303345
Should return an empty list.
304346
"""
305347
url = reverse("package-v2-list")
306-
response = self.client.get(url, {"affected_by_vulnerability": "VCID-9999"}, format="json")
348+
with self.assertNumQueries(4):
349+
response = self.client.get(url, {"affected_by_vulnerability": "VCID-9999"}, format="json")
307350
self.assertEqual(response.status_code, status.HTTP_200_OK)
308351
self.assertEqual(len(response.data["results"]["packages"]), 0)
309352

@@ -313,15 +356,27 @@ def test_invalid_purl_filter(self):
313356
Should return an empty list.
314357
"""
315358
url = reverse("package-v2-list")
316-
response = self.client.get(url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json")
359+
with self.assertNumQueries(4):
360+
response = self.client.get(url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json")
317361
self.assertEqual(response.status_code, status.HTTP_200_OK)
318362
self.assertEqual(len(response.data["results"]["packages"]), 0)
319363

320364
def test_get_affected_by_vulnerabilities(self):
321365
"""
322366
Test the get_affected_by_vulnerabilities method in the serializer.
323367
"""
324-
package = Package.objects.get(package_url="pkg:pypi/django@3.2")
368+
package = (
369+
Package.objects.filter(package_url="pkg:pypi/django@3.2")
370+
.prefetch_related(
371+
Prefetch(
372+
"affected_by_vulnerabilities",
373+
queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"),
374+
to_attr="prefetched_affected_vulnerabilities",
375+
)
376+
)
377+
.first()
378+
)
379+
325380
serializer = PackageV2Serializer()
326381
vulnerabilities = serializer.get_affected_by_vulnerabilities(package)
327382
self.assertEqual(
@@ -345,7 +400,8 @@ def test_bulk_lookup_with_valid_purls(self):
345400
"""
346401
url = reverse("package-v2-bulk-lookup")
347402
data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]}
348-
response = self.client.post(url, data, format="json")
403+
with self.assertNumQueries(28):
404+
response = self.client.post(url, data, format="json")
349405
self.assertEqual(response.status_code, status.HTTP_200_OK)
350406
self.assertIn("packages", response.data)
351407
self.assertIn("vulnerabilities", response.data)
@@ -369,7 +425,8 @@ def test_bulk_lookup_with_invalid_purls(self):
369425
"""
370426
url = reverse("package-v2-bulk-lookup")
371427
data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]}
372-
response = self.client.post(url, data, format="json")
428+
with self.assertNumQueries(4):
429+
response = self.client.post(url, data, format="json")
373430
self.assertEqual(response.status_code, status.HTTP_200_OK)
374431
# Since the packages don't exist, the response should be empty
375432
self.assertEqual(len(response.data["packages"]), 0)
@@ -382,7 +439,8 @@ def test_bulk_lookup_with_empty_purls(self):
382439
"""
383440
url = reverse("package-v2-bulk-lookup")
384441
data = {"purls": []}
385-
response = self.client.post(url, data, format="json")
442+
with self.assertNumQueries(3):
443+
response = self.client.post(url, data, format="json")
386444
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
387445
self.assertIn("error", response.data)
388446
self.assertIn("message", response.data)
@@ -395,7 +453,8 @@ def test_bulk_search_with_valid_purls(self):
395453
"""
396454
url = reverse("package-v2-bulk-search")
397455
data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]}
398-
response = self.client.post(url, data, format="json")
456+
with self.assertNumQueries(28):
457+
response = self.client.post(url, data, format="json")
399458
self.assertEqual(response.status_code, status.HTTP_200_OK)
400459
self.assertIn("packages", response.data)
401460
self.assertIn("vulnerabilities", response.data)
@@ -422,7 +481,8 @@ def test_bulk_search_with_purl_only_true(self):
422481
"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"],
423482
"purl_only": True,
424483
}
425-
response = self.client.post(url, data, format="json")
484+
with self.assertNumQueries(17):
485+
response = self.client.post(url, data, format="json")
426486
self.assertEqual(response.status_code, status.HTTP_200_OK)
427487
# Since purl_only=True, response should be a list of PURLs
428488
self.assertIsInstance(response.data, list)
@@ -448,7 +508,8 @@ def test_bulk_search_with_plain_purl_true(self):
448508
"purls": ["pkg:pypi/django@3.2", "pkg:pypi/django@3.2?extension=tar.gz"],
449509
"plain_purl": True,
450510
}
451-
response = self.client.post(url, data, format="json")
511+
with self.assertNumQueries(16):
512+
response = self.client.post(url, data, format="json")
452513
self.assertEqual(response.status_code, status.HTTP_200_OK)
453514
self.assertIn("packages", response.data)
454515
self.assertIn("vulnerabilities", response.data)
@@ -468,7 +529,8 @@ def test_bulk_search_with_purl_only_and_plain_purl_true(self):
468529
"purl_only": True,
469530
"plain_purl": True,
470531
}
471-
response = self.client.post(url, data, format="json")
532+
with self.assertNumQueries(11):
533+
response = self.client.post(url, data, format="json")
472534
self.assertEqual(response.status_code, status.HTTP_200_OK)
473535
# Response should be a list of plain PURLs
474536
self.assertIsInstance(response.data, list)
@@ -483,7 +545,8 @@ def test_bulk_search_with_invalid_purls(self):
483545
"""
484546
url = reverse("package-v2-bulk-search")
485547
data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]}
486-
response = self.client.post(url, data, format="json")
548+
with self.assertNumQueries(4):
549+
response = self.client.post(url, data, format="json")
487550
self.assertEqual(response.status_code, status.HTTP_200_OK)
488551
# Since the packages don't exist, the response should be empty
489552
self.assertEqual(len(response.data["packages"]), 0)
@@ -496,7 +559,8 @@ def test_bulk_search_with_empty_purls(self):
496559
"""
497560
url = reverse("package-v2-bulk-search")
498561
data = {"purls": []}
499-
response = self.client.post(url, data, format="json")
562+
with self.assertNumQueries(3):
563+
response = self.client.post(url, data, format="json")
500564
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
501565
self.assertIn("error", response.data)
502566
self.assertIn("message", response.data)
@@ -507,7 +571,8 @@ def test_all_vulnerable_packages(self):
507571
Test the 'all' endpoint that returns all vulnerable package URLs.
508572
"""
509573
url = reverse("package-v2-all")
510-
response = self.client.get(url, format="json")
574+
with self.assertNumQueries(4):
575+
response = self.client.get(url, format="json")
511576
self.assertEqual(response.status_code, status.HTTP_200_OK)
512577
# Since package1 is vulnerable, it should be returned
513578
expected_purls = ["pkg:pypi/django@3.2"]
@@ -520,7 +585,8 @@ def test_lookup_with_valid_purl(self):
520585
"""
521586
url = reverse("package-v2-lookup")
522587
data = {"purl": "pkg:pypi/django@3.2"}
523-
response = self.client.post(url, data, format="json")
588+
with self.assertNumQueries(12):
589+
response = self.client.post(url, data, format="json")
524590
self.assertEqual(response.status_code, status.HTTP_200_OK)
525591
self.assertEqual(1, len(response.data))
526592
self.assertIn("purl", response.data[0])
@@ -542,7 +608,8 @@ def test_lookup_with_invalid_purl(self):
542608
"""
543609
url = reverse("package-v2-lookup")
544610
data = {"purl": "pkg:pypi/nonexistent@1.0.0"}
545-
response = self.client.post(url, data, format="json")
611+
with self.assertNumQueries(4):
612+
response = self.client.post(url, data, format="json")
546613
self.assertEqual(response.status_code, status.HTTP_200_OK)
547614
# No packages or vulnerabilities should be returned
548615
self.assertEqual(len(response.data), 0)
@@ -554,7 +621,8 @@ def test_lookup_with_missing_purl(self):
554621
"""
555622
url = reverse("package-v2-lookup")
556623
data = {}
557-
response = self.client.post(url, data, format="json")
624+
with self.assertNumQueries(3):
625+
response = self.client.post(url, data, format="json")
558626
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
559627
self.assertIn("error", response.data)
560628
self.assertIn("message", response.data)
@@ -567,7 +635,8 @@ def test_lookup_with_invalid_purl_format(self):
567635
"""
568636
url = reverse("package-v2-lookup")
569637
data = {"purl": "invalid_purl_format"}
570-
response = self.client.post(url, data, format="json")
638+
with self.assertNumQueries(4):
639+
response = self.client.post(url, data, format="json")
571640
self.assertEqual(response.status_code, status.HTTP_200_OK)
572641
# No packages or vulnerabilities should be returned
573642
self.assertEqual(len(response.data), 0)

0 commit comments

Comments
 (0)