33from collections .abc import Iterator
44from types import MethodType , SimpleNamespace , TracebackType
55from typing import cast
6+ from unittest .mock import AsyncMock
67
78import httpx
89import pytest
@@ -433,18 +434,13 @@ async def test_client_credentials_ensure_token_returns_when_valid() -> None:
433434 provider ._current_tokens = OAuthToken (access_token = "token" )
434435 provider ._token_expiry_time = time .time () + 60
435436
436- request_called = False
437-
438- async def fake_request_token () -> None :
439- nonlocal request_called
440- request_called = True
441-
437+ fake_request_token = AsyncMock ()
442438 provider ._request_token = fake_request_token # type: ignore[assignment]
443439
444440 await provider .ensure_token ()
445441
446442 assert provider ._current_tokens is not None
447- assert not request_called
443+ fake_request_token . assert_not_awaited ()
448444
449445
450446@pytest .mark .anyio
@@ -523,19 +519,16 @@ async def test_token_exchange_request_token(monkeypatch: pytest.MonkeyPatch) ->
523519 storage = InMemoryStorage ()
524520 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris (), scope = "alpha" )
525521
526- async def provide_subject () -> str :
527- return "subject-token"
528-
529- async def provide_actor () -> str :
530- return "actor-token"
522+ subject_supplier = AsyncMock (return_value = "subject-token" )
523+ actor_supplier = AsyncMock (return_value = "actor-token" )
531524
532525 provider = TokenExchangeProvider (
533526 "https://api.example.com/service" ,
534527 client_metadata ,
535528 storage ,
536- subject_token_supplier = provide_subject ,
529+ subject_token_supplier = subject_supplier ,
537530 subject_token_type = "access_token" ,
538- actor_token_supplier = provide_actor ,
531+ actor_token_supplier = actor_supplier ,
539532 actor_token_type = "jwt" ,
540533 audience = "https://audience.example.com" ,
541534 resource = "https://resource.example.com" ,
@@ -558,26 +551,25 @@ async def provide_actor() -> str:
558551 assert storage .tokens .access_token == "access-token"
559552 assert provider ._current_tokens is storage .tokens
560553 assert provider ._token_expiry_time is not None
554+ subject_supplier .assert_awaited_once ()
555+ actor_supplier .assert_awaited_once ()
561556
562557
563558@pytest .mark .anyio
564559async def test_token_exchange_request_token_handles_invalid_metadata (monkeypatch : pytest .MonkeyPatch ) -> None :
565560 storage = InMemoryStorage ()
566561 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris (), scope = "alpha" )
567562
568- async def provide_subject () -> str :
569- return "subject-token"
570-
571- async def provide_actor () -> str :
572- return "actor-token"
563+ subject_supplier = AsyncMock (return_value = "subject-token" )
564+ actor_supplier = AsyncMock (return_value = "actor-token" )
573565
574566 provider = TokenExchangeProvider (
575567 "https://api.example.com/service" ,
576568 client_metadata ,
577569 storage ,
578- subject_token_supplier = provide_subject ,
570+ subject_token_supplier = subject_supplier ,
579571 subject_token_type = "access_token" ,
580- actor_token_supplier = provide_actor ,
572+ actor_token_supplier = actor_supplier ,
581573 actor_token_type = "jwt" ,
582574 audience = "https://audience.example.com" ,
583575 )
@@ -608,21 +600,20 @@ async def provide_actor() -> str:
608600 assert storage .tokens is not None
609601 assert storage .tokens .access_token == "exchange-token"
610602 assert provider ._token_expiry_time is None
603+ subject_supplier .assert_awaited_once ()
604+ actor_supplier .assert_awaited_once ()
611605
612606
613607@pytest .mark .anyio
614608async def test_token_exchange_request_token_raises_on_failure (monkeypatch : pytest .MonkeyPatch ) -> None :
615609 storage = InMemoryStorage ()
616610 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
617611
618- async def provide_subject () -> str :
619- return "subject-token"
620-
621612 provider = TokenExchangeProvider (
622613 "https://api.example.com/service" ,
623614 client_metadata ,
624615 storage ,
625- subject_token_supplier = provide_subject ,
616+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
626617 )
627618
628619 provider ._metadata = OAuthMetadata .model_validate (_metadata_json ())
@@ -639,14 +630,11 @@ def test_token_exchange_has_valid_token_checks_expiry() -> None:
639630 storage = InMemoryStorage ()
640631 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
641632
642- async def provide_subject () -> str :
643- return "subject-token"
644-
645633 provider = TokenExchangeProvider (
646634 "https://api.example.com/service" ,
647635 client_metadata ,
648636 storage ,
649- subject_token_supplier = provide_subject ,
637+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
650638 )
651639
652640 provider ._current_tokens = OAuthToken (access_token = "token" )
@@ -660,14 +648,11 @@ async def test_token_exchange_validate_token_scopes_returns_when_missing() -> No
660648 storage = InMemoryStorage ()
661649 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris (), scope = "alpha" )
662650
663- async def provide_subject () -> str :
664- return "subject-token"
665-
666651 provider = TokenExchangeProvider (
667652 "https://api.example.com/service" ,
668653 client_metadata ,
669654 storage ,
670- subject_token_supplier = provide_subject ,
655+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
671656 )
672657
673658 token = OAuthToken (access_token = "token" , scope = None )
@@ -680,14 +665,11 @@ async def test_token_exchange_get_or_register_client(monkeypatch: pytest.MonkeyP
680665 storage = InMemoryStorage ()
681666 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
682667
683- async def provide_subject () -> str :
684- return "subject-token"
685-
686668 provider = TokenExchangeProvider (
687669 "https://api.example.com/service" ,
688670 client_metadata ,
689671 storage ,
690- subject_token_supplier = provide_subject ,
672+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
691673 )
692674
693675 registration_response = _make_response (200 , json_data = _registration_json ())
@@ -710,14 +692,11 @@ async def test_token_exchange_initialize_loads_cached_values() -> None:
710692
711693 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
712694
713- async def provide_subject () -> str :
714- return "subject-token"
715-
716695 provider = TokenExchangeProvider (
717696 "https://api.example.com/service" ,
718697 client_metadata ,
719698 storage ,
720- subject_token_supplier = provide_subject ,
699+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
721700 )
722701
723702 await provider .initialize ()
@@ -731,14 +710,11 @@ async def test_token_exchange_validate_token_scopes_rejects_extra() -> None:
731710 storage = InMemoryStorage ()
732711 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris (), scope = "alpha" )
733712
734- async def provide_subject () -> str :
735- return "subject-token"
736-
737713 provider = TokenExchangeProvider (
738714 "https://api.example.com/service" ,
739715 client_metadata ,
740716 storage ,
741- subject_token_supplier = provide_subject ,
717+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
742718 )
743719
744720 token = OAuthToken (access_token = "token" , scope = "alpha beta" )
@@ -752,14 +728,11 @@ async def test_token_exchange_validate_token_scopes_accepts_server_defined() ->
752728 storage = InMemoryStorage ()
753729 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris (), scope = None )
754730
755- async def provide_subject () -> str :
756- return "subject-token"
757-
758731 provider = TokenExchangeProvider (
759732 "https://api.example.com/service" ,
760733 client_metadata ,
761734 storage ,
762- subject_token_supplier = provide_subject ,
735+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
763736 )
764737
765738 token = OAuthToken (access_token = "token" , scope = "delta" )
@@ -772,14 +745,11 @@ async def test_token_exchange_async_auth_flow_handles_401(monkeypatch: pytest.Mo
772745 storage = InMemoryStorage ()
773746 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
774747
775- async def provide_subject () -> str :
776- return "subject-token"
777-
778748 provider = TokenExchangeProvider (
779749 "https://api.example.com/service" ,
780750 client_metadata ,
781751 storage ,
782- subject_token_supplier = provide_subject ,
752+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
783753 )
784754
785755 async def fake_initialize () -> None :
@@ -809,14 +779,11 @@ async def test_token_exchange_async_auth_flow_with_cached_token() -> None:
809779 storage = InMemoryStorage ()
810780 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
811781
812- async def provide_subject () -> str :
813- return "subject-token"
814-
815782 provider = TokenExchangeProvider (
816783 "https://api.example.com/service" ,
817784 client_metadata ,
818785 storage ,
819- subject_token_supplier = provide_subject ,
786+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
820787 )
821788
822789 provider ._current_tokens = OAuthToken (access_token = "cached" )
@@ -838,31 +805,23 @@ async def test_token_exchange_ensure_token_returns_when_valid() -> None:
838805 storage = InMemoryStorage ()
839806 client_metadata = OAuthClientMetadata (redirect_uris = _redirect_uris ())
840807
841- async def provide_subject () -> str :
842- return "subject-token"
843-
844808 provider = TokenExchangeProvider (
845809 "https://api.example.com/service" ,
846810 client_metadata ,
847811 storage ,
848- subject_token_supplier = provide_subject ,
812+ subject_token_supplier = AsyncMock ( return_value = "subject-token" ) ,
849813 )
850814
851815 provider ._current_tokens = OAuthToken (access_token = "token" )
852816 provider ._token_expiry_time = time .time () + 60
853817
854- request_called = False
855-
856- async def fake_request_token () -> None :
857- nonlocal request_called
858- request_called = True
859-
818+ fake_request_token = AsyncMock ()
860819 provider ._request_token = fake_request_token # type: ignore[assignment]
861820
862821 await provider .ensure_token ()
863822
864823 assert provider ._current_tokens is not None
865- assert not request_called
824+ fake_request_token . assert_not_awaited ()
866825
867826
868827@pytest .mark .anyio
0 commit comments