# HTTP client primitives and the throttle signal raised by the job endpoint.
from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
# Default job fetcher/handler; overridable on JobScaler only in local-test mode.
from .rp_job import get_job, handle_job
from .rp_logger import RunPodLogger
# JobsProgress tracks in-flight jobs; IS_LOCAL_TEST gates test-only overrides.
from .worker_state import JobsProgress, IS_LOCAL_TEST

log = RunPodLogger()  # module-level logger singleton
job_progress = JobsProgress()  # module-level in-flight job tracker shared by JobScaler
1817
1918
@@ -38,16 +37,50 @@ class JobScaler:
3837 """
3938
4039 def __init__ (self , config : Dict [str , Any ]):
41- concurrency_modifier = config .get ("concurrency_modifier" )
42- if concurrency_modifier is None :
43- self .concurrency_modifier = _default_concurrency_modifier
44- else :
45- self .concurrency_modifier = concurrency_modifier
46-
4740 self ._shutdown_event = asyncio .Event ()
4841 self .current_concurrency = 1
4942 self .config = config
5043
44+ self .jobs_queue = asyncio .Queue (maxsize = self .current_concurrency )
45+
46+ self .concurrency_modifier = _default_concurrency_modifier
47+ self .jobs_fetcher = get_job
48+ self .jobs_fetcher_timeout = 90
49+ self .jobs_handler = handle_job
50+
51+ if concurrency_modifier := config .get ("concurrency_modifier" ):
52+ self .concurrency_modifier = concurrency_modifier
53+
54+ if not IS_LOCAL_TEST :
55+ # below cannot be changed unless local
56+ return
57+
58+ if jobs_fetcher := self .config .get ("jobs_fetcher" ):
59+ self .jobs_fetcher = jobs_fetcher
60+
61+ if jobs_fetcher_timeout := self .config .get ("jobs_fetcher_timeout" ):
62+ self .jobs_fetcher_timeout = jobs_fetcher_timeout
63+
64+ if jobs_handler := self .config .get ("jobs_handler" ):
65+ self .jobs_handler = jobs_handler
66+
67+ async def set_scale (self ):
68+ self .current_concurrency = self .concurrency_modifier (self .current_concurrency )
69+
70+ if self .jobs_queue and (self .current_concurrency == self .jobs_queue .maxsize ):
71+ # no need to resize
72+ return
73+
74+ while self .current_occupancy () > 0 :
75+ # not safe to scale when jobs are in flight
76+ await asyncio .sleep (1 )
77+ continue
78+
79+ self .jobs_queue = asyncio .Queue (maxsize = self .current_concurrency )
80+ log .debug (
81+ f"JobScaler.set_scale | New concurrency set to: { self .current_concurrency } "
82+ )
83+
5184 def start (self ):
5285 """
5386 This is required for the worker to be able to shut down gracefully
@@ -105,6 +138,15 @@ def kill_worker(self):
105138 log .info ("Kill worker." )
106139 self ._shutdown_event .set ()
107140
    def current_occupancy(self) -> int:
        """Return the number of jobs this worker currently holds.

        Sums jobs waiting in the local queue with jobs tracked by the
        module-level JobsProgress set.

        NOTE(review): get_jobs() adds each fetched job to both jobs_queue and
        job_progress, so a job that is queued but not yet picked up appears
        to be counted twice by this sum — confirm whether that over-count is
        intended (it only makes job fetching more conservative).
        """
        current_queue_count = self.jobs_queue.qsize()
        current_progress_count = job_progress.get_job_count()

        log.debug(
            f"JobScaler.status | concurrency: {self.current_concurrency}; queue: {current_queue_count}; progress: {current_progress_count}"
        )
        return current_progress_count + current_queue_count
149+
108150 async def get_jobs (self , session : ClientSession ):
109151 """
110152 Retrieve multiple jobs from the server in batches using blocking requests.
@@ -114,45 +156,42 @@ async def get_jobs(self, session: ClientSession):
114156 Adds jobs to the JobsQueue
115157 """
116158 while self .is_alive ():
117- log .debug ("JobScaler.get_jobs | Starting job acquisition." )
118-
119- self .current_concurrency = self .concurrency_modifier (
120- self .current_concurrency
121- )
122- log .debug (f"JobScaler.get_jobs | current Concurrency set to: { self .current_concurrency } " )
159+ await self .set_scale ()
123160
124- current_progress_count = await job_progress .get_job_count ()
125- log .debug (f"JobScaler.get_jobs | current Jobs in progress: { current_progress_count } " )
126-
127- current_queue_count = job_list .get_job_count ()
128- log .debug (f"JobScaler.get_jobs | current Jobs in queue: { current_queue_count } " )
129-
130- jobs_needed = self .current_concurrency - current_progress_count - current_queue_count
161+ jobs_needed = self .current_concurrency - self .current_occupancy ()
131162 if jobs_needed <= 0 :
132163 log .debug ("JobScaler.get_jobs | Queue is full. Retrying soon." )
133164 await asyncio .sleep (1 ) # don't go rapidly
134165 continue
135166
136167 try :
137- # Keep the connection to the blocking call up to 30 seconds
168+ log .debug ("JobScaler.get_jobs | Starting job acquisition." )
169+
170+ # Keep the connection to the blocking call with timeout
138171 acquired_jobs = await asyncio .wait_for (
139- get_job (session , jobs_needed ), timeout = 30
172+ self .jobs_fetcher (session , jobs_needed ),
173+ timeout = self .jobs_fetcher_timeout ,
140174 )
141175
142176 if not acquired_jobs :
143177 log .debug ("JobScaler.get_jobs | No jobs acquired." )
144178 continue
145179
146180 for job in acquired_jobs :
147- await job_list .add_job (job )
181+ await self .jobs_queue .put (job )
182+ job_progress .add (job )
183+ log .debug ("Job Queued" , job ["id" ])
148184
149- log .info (f"Jobs in queue: { job_list . get_job_count ()} " )
185+ log .info (f"Jobs in queue: { self . jobs_queue . qsize ()} " )
150186
151187 except TooManyRequests :
152- log .debug (f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds." )
188+ log .debug (
189+ f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds."
190+ )
153191 await asyncio .sleep (5 ) # debounce for 5 seconds
154192 except asyncio .CancelledError :
155193 log .debug ("JobScaler.get_jobs | Request was cancelled." )
194+ raise # CancelledError is a BaseException
156195 except TimeoutError :
157196 log .debug ("JobScaler.get_jobs | Job acquisition timed out. Retrying." )
158197 except TypeError as error :
@@ -173,10 +212,10 @@ async def run_jobs(self, session: ClientSession):
173212 """
174213 tasks = [] # Store the tasks for concurrent job processing
175214
176- while self .is_alive () or not job_list .empty ():
215+ while self .is_alive () or not self . jobs_queue .empty ():
177216 # Fetch as many jobs as the concurrency allows
178- while len (tasks ) < self .current_concurrency and not job_list .empty ():
179- job = await job_list . get_job ()
217+ while len (tasks ) < self .current_concurrency and not self . jobs_queue .empty ():
218+ job = await self . jobs_queue . get ()
180219
181220 # Create a new task for each job and add it to the task list
182221 task = asyncio .create_task (self .handle_job (session , job ))
@@ -204,9 +243,9 @@ async def handle_job(self, session: ClientSession, job: dict):
204243 Process an individual job. This function is run concurrently for multiple jobs.
205244 """
206245 try :
207- await job_progress . add ( job )
246+ log . debug ( "Handling Job" , job [ "id" ] )
208247
209- await handle_job (session , self .config , job )
248+ await self . jobs_handler (session , self .config , job )
210249
211250 if self .config .get ("refresh_worker" , False ):
212251 self .kill_worker ()
@@ -216,8 +255,10 @@ async def handle_job(self, session: ClientSession, job: dict):
216255 raise err
217256
218257 finally :
219- # Inform JobsQueue of a task completion
220- job_list .task_done ()
258+ # Inform Queue of a task completion
259+ self . jobs_queue .task_done ()
221260
222261 # Job is no longer in progress
223- await job_progress .remove (job ["id" ])
262+ job_progress .remove (job )
263+
264+ log .debug ("Finished Job" , job ["id" ])
0 commit comments