@@ -45,18 +45,19 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
4545class CallbackHandler (BaseHTTPRequestHandler ):
4646 """Simple HTTP handler to capture OAuth callback."""
4747
48- authorization_code = None
49- state = None
50- error = None
48+ def __init__ (self , request , client_address , server , callback_data ):
49+ """Initialize with callback data storage."""
50+ self .callback_data = callback_data
51+ super ().__init__ (request , client_address , server )
5152
5253 def do_GET (self ):
5354 """Handle GET request from OAuth redirect."""
5455 parsed = urlparse (self .path )
5556 query_params = parse_qs (parsed .query )
5657
5758 if "code" in query_params :
58- CallbackHandler . authorization_code = query_params ["code" ][0 ]
59- CallbackHandler . state = query_params .get ("state" , [None ])[0 ]
59+ self . callback_data [ " authorization_code" ] = query_params ["code" ][0 ]
60+ self . callback_data [ " state" ] = query_params .get ("state" , [None ])[0 ]
6061 self .send_response (200 )
6162 self .send_header ("Content-type" , "text/html" )
6263 self .end_headers ()
@@ -70,7 +71,7 @@ def do_GET(self):
7071 </html>
7172 """ )
7273 elif "error" in query_params :
73- CallbackHandler . error = query_params ["error" ][0 ]
74+ self . callback_data [ " error" ] = query_params ["error" ][0 ]
7475 self .send_response (400 )
7576 self .send_header ("Content-type" , "text/html" )
7677 self .end_headers ()
@@ -101,10 +102,26 @@ def __init__(self, port=3000):
101102 self .port = port
102103 self .server = None
103104 self .thread = None
105+ self .callback_data = {
106+ "authorization_code" : None ,
107+ "state" : None ,
108+ "error" : None
109+ }
110+
111+ def _create_handler_with_data (self ):
112+ """Create a handler class with access to callback data."""
113+ callback_data = self .callback_data
114+
115+ class DataCallbackHandler (CallbackHandler ):
116+ def __init__ (self , request , client_address , server ):
117+ super ().__init__ (request , client_address , server , callback_data )
118+
119+ return DataCallbackHandler
104120
105121 def start (self ):
106122 """Start the callback server in a background thread."""
107- self .server = HTTPServer (("localhost" , self .port ), CallbackHandler )
123+ handler_class = self ._create_handler_with_data ()
124+ self .server = HTTPServer (("localhost" , self .port ), handler_class )
108125 self .thread = threading .Thread (target = self .server .serve_forever , daemon = True )
109126 self .thread .start ()
110127 print (f"🖥️ Started callback server on http://localhost:{ self .port } " )
@@ -121,12 +138,16 @@ def wait_for_callback(self, timeout=300):
121138 """Wait for OAuth callback with timeout."""
122139 start_time = time .time ()
123140 while time .time () - start_time < timeout :
124- if CallbackHandler . authorization_code :
125- return CallbackHandler . authorization_code
126- elif CallbackHandler . error :
127- raise Exception (f"OAuth error: { CallbackHandler . error } " )
141+ if self . callback_data [ " authorization_code" ] :
142+ return self . callback_data [ " authorization_code" ]
143+ elif self . callback_data [ " error" ] :
144+ raise Exception (f"OAuth error: { self . callback_data [ ' error' ] } " )
128145 time .sleep (0.1 )
129146 raise Exception ("Timeout waiting for OAuth callback" )
147+
148+ def get_state (self ):
149+ """Get the received state parameter."""
150+ return self .callback_data ["state" ]
130151
131152
132153class SimpleAuthClient :
@@ -153,7 +174,7 @@ async def callback_handler() -> tuple[str, str | None]:
153174 print ("⏳ Waiting for authorization callback..." )
154175 try :
155176 auth_code = callback_server .wait_for_callback (timeout = 300 )
156- return auth_code , CallbackHandler . state
177+ return auth_code , callback_server . get_state ()
157178 finally :
158179 callback_server .stop ()
159180
0 commit comments