44"""
55
66import asyncio
7+ import signal
78from typing import Any , Dict
89
9- from ...http_client import ClientSession
10+ from ...http_client import AsyncClientSession , ClientSession
1011from .rp_job import get_job , handle_job
1112from .rp_logger import RunPodLogger
1213from .worker_state import JobsQueue , JobsProgress
@@ -36,26 +37,91 @@ class JobScaler:
3637 Job Scaler. This class is responsible for scaling the number of concurrent requests.
3738 """
3839
def __init__(self, config: Dict[str, Any]):
    """
    Set up the scaler from the worker configuration.

    Args:
        config: Worker configuration. The optional "concurrency_modifier"
            entry is a callable that takes the current concurrency and
            returns the new one; when it is missing or None the module
            default is used. The whole dict is kept on the instance and
            later passed to the job handler.
    """
    concurrency_modifier = config.get("concurrency_modifier")
    self.concurrency_modifier = (
        _default_concurrency_modifier
        if concurrency_modifier is None
        else concurrency_modifier
    )

    # Set once the worker has been asked to stop; read by is_alive().
    self._shutdown_event = asyncio.Event()
    self.current_concurrency = 1
    self.config = config
50+
def start(self):
    """
    Install shutdown signal handlers and run the worker until it stops.

    SIGTERM and SIGINT are routed to handle_shutdown so the worker can
    shut down gracefully, which is typically needed when it runs in a
    container. This call blocks: it drives self.run() with asyncio.run()
    and returns only once that coroutine has finished.
    """
    try:
        # Register signal handlers for graceful shutdown.
        signal.signal(signal.SIGTERM, self.handle_shutdown)
        signal.signal(signal.SIGINT, self.handle_shutdown)
    except ValueError:
        # signal.signal() raises ValueError outside the main thread;
        # keep going without handlers rather than failing to start.
        log.warning("Signal handling is only supported in the main thread.")

    # Start the main loop and run until the worker is signalled to shut down.
    asyncio.run(self.run())
67+
def handle_shutdown(self, signum, frame):
    """
    Signal handler that asks the worker to shut down.

    Registered for SIGTERM and SIGINT by start(). It logs the signal and
    sets the shutdown event via kill_worker(), so loops that check
    is_alive() stop on their next check.

    Args:
        signum: The signal number that was received.
        frame: The current stack frame (unused).
    """
    log.debug(f"Received shutdown signal: {signum}.")
    self.kill_worker()
82+
async def run(self):
    """
    Run the job-fetching and job-processing loops until the worker stops.

    One AsyncClientSession is opened and shared by both loops; it is
    closed when this coroutine leaves the `async with` block. On any exit
    path the worker is marked as shut down, unfinished tasks are
    cancelled and awaited, and cleanup() is called.
    """
    async with AsyncClientSession() as session:
        # Create tasks for getting and running jobs.
        tasks = [
            asyncio.create_task(self.get_jobs(session)),
            asyncio.create_task(self.run_jobs(session)),
        ]

        try:
            # Run both loops concurrently and wait for both to finish.
            await asyncio.gather(*tasks)
        except asyncio.CancelledError:  # worker is killed
            log.debug("Worker tasks cancelled.")
        finally:
            # Mark the worker as shut down on every exit path, not only on
            # cancellation: if one loop failed with another exception, the
            # surviving loop must see is_alive() == False, otherwise the
            # gather below could wait on it forever.
            self.kill_worker()

            for task in tasks:
                if not task.done():
                    task.cancel()

            # return_exceptions=True so a cancelled or failed task does not
            # prevent waiting for the others or running cleanup().
            await asyncio.gather(*tasks, return_exceptions=True)
            await self.cleanup()  # Ensure resources are cleaned up
106+
async def cleanup(self):
    """
    Final shutdown hook, awaited by run() after both loops have stopped.

    It currently only logs and yields to the event loop once; no
    resources are released here yet.
    """
    log.debug("Cleaning up resources before shutdown.")
    # TODO: stop heartbeat or close any open connections
    await asyncio.sleep(0)  # Give other tasks a chance to run (optional)
    log.debug("Cleanup complete.")
47113
def is_alive(self):
    """
    Return whether the worker is alive or not.

    The worker is alive until the shutdown event has been set
    (see kill_worker).
    """
    return not self._shutdown_event.is_set()
53119
def kill_worker(self):
    """
    Mark the worker as shutting down.

    Sets the shutdown event, after which is_alive() returns False.
    Calling it more than once has no further effect.
    """
    self._shutdown_event.set()
59125
async def get_jobs(self, session: ClientSession):
    """
    Fetch jobs while the worker is alive and add them to the jobs queue.

    Each iteration recomputes the target concurrency with
    self.concurrency_modifier, works out how many more jobs are needed
    given the jobs already in progress, and requests that many with
    get_job(). Acquired jobs are added to the JobsQueue.

    Args:
        session: HTTP client session passed through to get_job().

    Raises:
        asyncio.CancelledError: re-raised when this task is cancelled, so
            the canceller can see that the task actually stopped.
    """
    while self.is_alive():
        log.debug(f"JobScaler.get_jobs | Jobs in progress: {job_progress.get_job_count()}")

        self.current_concurrency = self.concurrency_modifier(
            self.current_concurrency
        )
        log.debug(f"JobScaler.get_jobs | Concurrency set to: {self.current_concurrency}")

        jobs_needed = self.current_concurrency - job_progress.get_job_count()
        if jobs_needed <= 0:
            log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.")
            await asyncio.sleep(0.1)  # don't go rapidly
            continue

        try:
            # Keep the connection to the blocking call up to 30 seconds
            acquired_jobs = await asyncio.wait_for(
                get_job(session, jobs_needed), timeout=30
            )
        except asyncio.CancelledError:
            # Cancellation must propagate: swallowing it here would keep
            # the loop running after the task was told to stop.
            log.debug("JobScaler.get_jobs | Request was cancelled.")
            raise
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is what wait_for() raises; it is only an
            # alias of the builtin TimeoutError from Python 3.11 onwards.
            log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.")
            continue
        except TypeError as error:
            log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.")
            continue
        except Exception as error:
            log.error(
                f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
            )
            continue

        if not acquired_jobs:
            log.debug("JobScaler.get_jobs | No jobs acquired.")
            await asyncio.sleep(0)  # yield to the event loop before retrying
            continue

        for job in acquired_jobs:
            await job_list.add_job(job)

        log.info(f"Jobs in queue: {job_list.get_job_count()}")
177+
178+ async def run_jobs (self , session : ClientSession ):
101179 """
102180 Retrieve jobs from the jobs queue and process them concurrently.
103181
@@ -111,7 +189,7 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
111189 job = await job_list .get_job ()
112190
113191 # Create a new task for each job and add it to the task list
114- task = asyncio .create_task (self .handle_job (session , config , job ))
192+ task = asyncio .create_task (self .handle_job (session , job ))
115193 tasks .append (task )
116194
117195 # Wait for any job to finish
@@ -131,19 +209,19 @@ async def run_jobs(self, session: ClientSession, config: Dict[str, Any]):
131209 # Ensure all remaining tasks finish before stopping
132210 await asyncio .gather (* tasks )
133211
134- async def handle_job (self , session : ClientSession , config : Dict [ str , Any ], job ):
212+ async def handle_job (self , session : ClientSession , job : dict ):
135213 """
136214 Process an individual job. This function is run concurrently for multiple jobs.
137215 """
138- log .debug (f"Processing job: { job } " )
216+ log .debug (f"JobScaler.handle_job | { job } " )
139217 job_progress .add (job )
140218
141219 try :
142- await handle_job (session , config , job )
220+ await handle_job (session , self . config , job )
143221
144- if config .get ("refresh_worker" , False ):
222+ if self . config .get ("refresh_worker" , False ):
145223 self .kill_worker ()
146-
224+
147225 except Exception as err :
148226 log .error (f"Error handling job: { err } " , job ["id" ])
149227 raise err
0 commit comments