Skip to content

Commit 6d7cc85

Browse files
authored
Merge pull request #37 from sacha-development-stuff/codex/investigate-test-coverage-gaps
Use AsyncMock in OAuth2 provider tests
2 parents 9945699 + a09f355 commit 6d7cc85

File tree

1 file changed

+27
-68
lines changed

1 file changed

+27
-68
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 27 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Iterator
44
from types import MethodType, SimpleNamespace, TracebackType
55
from typing import cast
6+
from unittest.mock import AsyncMock
67

78
import httpx
89
import pytest
@@ -433,18 +434,13 @@ async def test_client_credentials_ensure_token_returns_when_valid() -> None:
433434
provider._current_tokens = OAuthToken(access_token="token")
434435
provider._token_expiry_time = time.time() + 60
435436

436-
request_called = False
437-
438-
async def fake_request_token() -> None:
439-
nonlocal request_called
440-
request_called = True
441-
437+
fake_request_token = AsyncMock()
442438
provider._request_token = fake_request_token # type: ignore[assignment]
443439

444440
await provider.ensure_token()
445441

446442
assert provider._current_tokens is not None
447-
assert not request_called
443+
fake_request_token.assert_not_awaited()
448444

449445

450446
@pytest.mark.anyio
@@ -523,19 +519,16 @@ async def test_token_exchange_request_token(monkeypatch: pytest.MonkeyPatch) ->
523519
storage = InMemoryStorage()
524520
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
525521

526-
async def provide_subject() -> str:
527-
return "subject-token"
528-
529-
async def provide_actor() -> str:
530-
return "actor-token"
522+
subject_supplier = AsyncMock(return_value="subject-token")
523+
actor_supplier = AsyncMock(return_value="actor-token")
531524

532525
provider = TokenExchangeProvider(
533526
"https://api.example.com/service",
534527
client_metadata,
535528
storage,
536-
subject_token_supplier=provide_subject,
529+
subject_token_supplier=subject_supplier,
537530
subject_token_type="access_token",
538-
actor_token_supplier=provide_actor,
531+
actor_token_supplier=actor_supplier,
539532
actor_token_type="jwt",
540533
audience="https://audience.example.com",
541534
resource="https://resource.example.com",
@@ -558,26 +551,25 @@ async def provide_actor() -> str:
558551
assert storage.tokens.access_token == "access-token"
559552
assert provider._current_tokens is storage.tokens
560553
assert provider._token_expiry_time is not None
554+
subject_supplier.assert_awaited_once()
555+
actor_supplier.assert_awaited_once()
561556

562557

563558
@pytest.mark.anyio
564559
async def test_token_exchange_request_token_handles_invalid_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
565560
storage = InMemoryStorage()
566561
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
567562

568-
async def provide_subject() -> str:
569-
return "subject-token"
570-
571-
async def provide_actor() -> str:
572-
return "actor-token"
563+
subject_supplier = AsyncMock(return_value="subject-token")
564+
actor_supplier = AsyncMock(return_value="actor-token")
573565

574566
provider = TokenExchangeProvider(
575567
"https://api.example.com/service",
576568
client_metadata,
577569
storage,
578-
subject_token_supplier=provide_subject,
570+
subject_token_supplier=subject_supplier,
579571
subject_token_type="access_token",
580-
actor_token_supplier=provide_actor,
572+
actor_token_supplier=actor_supplier,
581573
actor_token_type="jwt",
582574
audience="https://audience.example.com",
583575
)
@@ -608,21 +600,20 @@ async def provide_actor() -> str:
608600
assert storage.tokens is not None
609601
assert storage.tokens.access_token == "exchange-token"
610602
assert provider._token_expiry_time is None
603+
subject_supplier.assert_awaited_once()
604+
actor_supplier.assert_awaited_once()
611605

612606

613607
@pytest.mark.anyio
614608
async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None:
615609
storage = InMemoryStorage()
616610
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
617611

618-
async def provide_subject() -> str:
619-
return "subject-token"
620-
621612
provider = TokenExchangeProvider(
622613
"https://api.example.com/service",
623614
client_metadata,
624615
storage,
625-
subject_token_supplier=provide_subject,
616+
subject_token_supplier=AsyncMock(return_value="subject-token"),
626617
)
627618

628619
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
@@ -639,14 +630,11 @@ def test_token_exchange_has_valid_token_checks_expiry() -> None:
639630
storage = InMemoryStorage()
640631
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
641632

642-
async def provide_subject() -> str:
643-
return "subject-token"
644-
645633
provider = TokenExchangeProvider(
646634
"https://api.example.com/service",
647635
client_metadata,
648636
storage,
649-
subject_token_supplier=provide_subject,
637+
subject_token_supplier=AsyncMock(return_value="subject-token"),
650638
)
651639

652640
provider._current_tokens = OAuthToken(access_token="token")
@@ -660,14 +648,11 @@ async def test_token_exchange_validate_token_scopes_returns_when_missing() -> No
660648
storage = InMemoryStorage()
661649
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
662650

663-
async def provide_subject() -> str:
664-
return "subject-token"
665-
666651
provider = TokenExchangeProvider(
667652
"https://api.example.com/service",
668653
client_metadata,
669654
storage,
670-
subject_token_supplier=provide_subject,
655+
subject_token_supplier=AsyncMock(return_value="subject-token"),
671656
)
672657

673658
token = OAuthToken(access_token="token", scope=None)
@@ -680,14 +665,11 @@ async def test_token_exchange_get_or_register_client(monkeypatch: pytest.MonkeyP
680665
storage = InMemoryStorage()
681666
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
682667

683-
async def provide_subject() -> str:
684-
return "subject-token"
685-
686668
provider = TokenExchangeProvider(
687669
"https://api.example.com/service",
688670
client_metadata,
689671
storage,
690-
subject_token_supplier=provide_subject,
672+
subject_token_supplier=AsyncMock(return_value="subject-token"),
691673
)
692674

693675
registration_response = _make_response(200, json_data=_registration_json())
@@ -710,14 +692,11 @@ async def test_token_exchange_initialize_loads_cached_values() -> None:
710692

711693
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
712694

713-
async def provide_subject() -> str:
714-
return "subject-token"
715-
716695
provider = TokenExchangeProvider(
717696
"https://api.example.com/service",
718697
client_metadata,
719698
storage,
720-
subject_token_supplier=provide_subject,
699+
subject_token_supplier=AsyncMock(return_value="subject-token"),
721700
)
722701

723702
await provider.initialize()
@@ -731,14 +710,11 @@ async def test_token_exchange_validate_token_scopes_rejects_extra() -> None:
731710
storage = InMemoryStorage()
732711
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
733712

734-
async def provide_subject() -> str:
735-
return "subject-token"
736-
737713
provider = TokenExchangeProvider(
738714
"https://api.example.com/service",
739715
client_metadata,
740716
storage,
741-
subject_token_supplier=provide_subject,
717+
subject_token_supplier=AsyncMock(return_value="subject-token"),
742718
)
743719

744720
token = OAuthToken(access_token="token", scope="alpha beta")
@@ -752,14 +728,11 @@ async def test_token_exchange_validate_token_scopes_accepts_server_defined() ->
752728
storage = InMemoryStorage()
753729
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope=None)
754730

755-
async def provide_subject() -> str:
756-
return "subject-token"
757-
758731
provider = TokenExchangeProvider(
759732
"https://api.example.com/service",
760733
client_metadata,
761734
storage,
762-
subject_token_supplier=provide_subject,
735+
subject_token_supplier=AsyncMock(return_value="subject-token"),
763736
)
764737

765738
token = OAuthToken(access_token="token", scope="delta")
@@ -772,14 +745,11 @@ async def test_token_exchange_async_auth_flow_handles_401(monkeypatch: pytest.Mo
772745
storage = InMemoryStorage()
773746
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
774747

775-
async def provide_subject() -> str:
776-
return "subject-token"
777-
778748
provider = TokenExchangeProvider(
779749
"https://api.example.com/service",
780750
client_metadata,
781751
storage,
782-
subject_token_supplier=provide_subject,
752+
subject_token_supplier=AsyncMock(return_value="subject-token"),
783753
)
784754

785755
async def fake_initialize() -> None:
@@ -809,14 +779,11 @@ async def test_token_exchange_async_auth_flow_with_cached_token() -> None:
809779
storage = InMemoryStorage()
810780
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
811781

812-
async def provide_subject() -> str:
813-
return "subject-token"
814-
815782
provider = TokenExchangeProvider(
816783
"https://api.example.com/service",
817784
client_metadata,
818785
storage,
819-
subject_token_supplier=provide_subject,
786+
subject_token_supplier=AsyncMock(return_value="subject-token"),
820787
)
821788

822789
provider._current_tokens = OAuthToken(access_token="cached")
@@ -838,31 +805,23 @@ async def test_token_exchange_ensure_token_returns_when_valid() -> None:
838805
storage = InMemoryStorage()
839806
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
840807

841-
async def provide_subject() -> str:
842-
return "subject-token"
843-
844808
provider = TokenExchangeProvider(
845809
"https://api.example.com/service",
846810
client_metadata,
847811
storage,
848-
subject_token_supplier=provide_subject,
812+
subject_token_supplier=AsyncMock(return_value="subject-token"),
849813
)
850814

851815
provider._current_tokens = OAuthToken(access_token="token")
852816
provider._token_expiry_time = time.time() + 60
853817

854-
request_called = False
855-
856-
async def fake_request_token() -> None:
857-
nonlocal request_called
858-
request_called = True
859-
818+
fake_request_token = AsyncMock()
860819
provider._request_token = fake_request_token # type: ignore[assignment]
861820

862821
await provider.ensure_token()
863822

864823
assert provider._current_tokens is not None
865-
assert not request_called
824+
fake_request_token.assert_not_awaited()
866825

867826

868827
@pytest.mark.anyio

0 commit comments

Comments
 (0)