11import os
22import sys
3+ import threading
34from contextlib import contextmanager
4- from contextvars import ContextVar
55from dataclasses import dataclass
66from typing import Any , Callable , Dict , Literal , Optional , Tuple
77
1313from .run import _has_output_iterator_array_type
1414from .version import Version
1515
16- __all__ = ["include" ]
16+ __all__ = ["get_run_state" , "get_run_token" , " include" , "run_state" , "run_token " ]
1717
1818
19- _RUN_STATE : ContextVar [Literal ["load" , "setup" , "run" ] | None ] = ContextVar (
20- "run_state" ,
21- default = None ,
22- )
23- _RUN_TOKEN : ContextVar [str | None ] = ContextVar ("run_token" , default = None )
19+ _run_state : Optional [Literal ["load" , "setup" , "run" ]] = None
20+ _run_token : Optional [str ] = None
21+
22+ _state_stack = []
23+ _token_stack = []
24+
25+ _state_lock = threading .RLock ()
26+ _token_lock = threading .RLock ()
27+
28+
29+ def get_run_state () -> Optional [Literal ["load" , "setup" , "run" ]]:
30+ """
31+ Get the current run state.
32+ """
33+ return _run_state
34+
35+
36+ def get_run_token () -> Optional [str ]:
37+ """
38+ Get the current API token.
39+ """
40+ return _run_token
2441
2542
2643@contextmanager
2744def run_state (state : Literal ["load" , "setup" , "run" ]) -> Any :
2845 """
29- Internal context manager for execution state.
46+ Context manager for setting the current run state.
3047 """
31- s = _RUN_STATE .set (state )
48+ global _run_state
49+
50+ if threading .current_thread () is not threading .main_thread ():
51+ raise RuntimeError ("Only the main thread can modify run state" )
52+
53+ with _state_lock :
54+ _state_stack .append (_run_state )
55+
56+ _run_state = state
57+
3258 try :
3359 yield
3460 finally :
35- _RUN_STATE .reset (s )
61+ with _state_lock :
62+ _run_state = _state_stack .pop ()
3663
3764
3865@contextmanager
3966def run_token (token : str ) -> Any :
4067 """
41- Sets the API token for the current context .
68+ Context manager for setting the current API token .
4269 """
43- t = _RUN_TOKEN .set (token )
70+ global _run_token
71+
72+ if threading .current_thread () is not threading .main_thread ():
73+ raise RuntimeError ("Only the main thread can modify API token" )
74+
75+ with _token_lock :
76+ _token_stack .append (_run_token )
77+
78+ _run_token = token
79+
4480 try :
4581 yield
4682 finally :
47- _RUN_TOKEN .reset (t )
83+ with _token_lock :
84+ _run_token = _token_stack .pop ()
4885
4986
5087def _find_api_token () -> str :
@@ -53,12 +90,11 @@ def _find_api_token() -> str:
5390 print ("Using Replicate API token from environment" , file = sys .stderr )
5491 return token
5592
56- token = _RUN_TOKEN .get ()
57-
58- if not token :
93+ current_token = get_run_token ()
94+ if current_token is None :
5995 raise ValueError ("No run token found" )
6096
61- return token
97+ return current_token
6298
6399
64100@dataclass
@@ -158,7 +194,7 @@ def include(function_ref: str) -> Callable[..., Any]:
158194
159195 This function can only be called at the top level.
160196 """
161- if _RUN_STATE . get () != "load" :
197+ if get_run_state () != "load" :
162198 raise RuntimeError ("You may only call replicate.include at the top level." )
163199
164200 return Function (function_ref )
0 commit comments