 DEFAULT_RECOMPRESS_BATCH_SIZE = 100
 DEFAULT_BATCH_SIZE = 1000
 MAX_RECOMPRESS_BATCH_BYTES = 100 * 1024 * 1024  # 100 MB
+MAX_RECOMPRESS_PARALLEL_BYTES = 500 * 1024 * 1024  # 500 MB


 class Resolution(enum.Enum):
@@ -39,7 +40,7 @@ class Resolution(enum.Enum):


 # https://stackoverflow.com/questions/73395864/how-do-i-wait-when-all-threadpoolexecutor-threads-are-busy
-class AvailableThreadPoolExecutor(ThreadPoolExecutor):
+class RecompressThreadPoolExecutor(ThreadPoolExecutor):
     """ThreadPoolExecutor that keeps track of the number of available workers.

     Refs:
@@ -50,7 +51,7 @@ def __init__(
         self, max_workers=None, thread_name_prefix="", initializer=None, initargs=()
     ):
         super().__init__(max_workers, thread_name_prefix, initializer, initargs)
-        self._running_worker_futures: set[Future] = set()
+        self._running_worker_futures: dict[Future, int] = {}

     @property
     def available_workers(self) -> int:
@@ -69,16 +70,21 @@ def wait_for_available_worker(self, timeout: "float | None" = None) -> None:

         start_time = time.monotonic()
         while True:
-            if self.available_workers > 0:
+            if (
+                self.available_workers > 0
+                and sum(self._running_worker_futures.values())
+                < MAX_RECOMPRESS_PARALLEL_BYTES
+            ):
                 return
             if timeout is not None and time.monotonic() - start_time > timeout:
                 raise TimeoutError
             time.sleep(0.1)

     def submit(self, fn, /, *args, **kwargs):
+        size = sum(args[0].values())
         f = super().submit(fn, *args, **kwargs)
-        self._running_worker_futures.add(f)
-        f.add_done_callback(self._running_worker_futures.remove)
+        self._running_worker_futures[f] = size
+        f.add_done_callback(self._running_worker_futures.pop)
         return f


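The reworked executor now gates on two limits at once: free worker slots and the total payload bytes held by running futures. A minimal usage sketch, assuming the caller passes the sha256-to-size dict as the first positional argument (that is what submit() sums via args[0].values()); the batch contents below are made up for illustration:

executor = RecompressThreadPoolExecutor(max_workers=4)

# Hypothetical batch: sha256 hex digest -> object size in bytes.
batch = {"aa" * 32: 1_000_000, "bb" * 32: 2_500_000}

# Blocks until a worker is idle AND the bytes tracked for running futures
# are below MAX_RECOMPRESS_PARALLEL_BYTES; submit() itself does not check
# the byte cap, so callers are expected to wait first.
executor.wait_for_available_worker()

future = executor.submit(recompress_batch, batch, dry_run=True)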
@@ -174,14 +180,14 @@ def overwrite_parallel(compressed_raw_mime_by_sha256: "dict[str, bytes]") -> None:


 def recompress_batch(
-    recompress_sha256s: "set[str]", *, dry_run=True, compression_level: int = 3
+    recompress_sha256s: "dict[str, int]", *, dry_run=True, compression_level: int = 3
 ) -> None:
     if not recompress_sha256s:
         return

     data_by_sha256 = {
         data_sha256: data
-        for data_sha256, data in download_parallel(recompress_sha256s)
+        for data_sha256, data in download_parallel(set(recompress_sha256s))
         if data is not None and not data.startswith(blockstore.ZSTD_MAGIC_NUMBER_PREFIX)
     }

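Since recompress_sha256s is now a mapping rather than a set, set(recompress_sha256s) simply collects its keys for the download step; the sizes ride along only so the executor can account for them. A rough call-shape example with placeholder digests:

recompress_sha256s = {
    "aa" * 32: 48_213,      # sha256 -> stored size in bytes
    "bb" * 32: 1_204_992,
}
recompress_batch(recompress_sha256s, dry_run=True, compression_level=3)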
@@ -306,7 +312,7 @@ def shutdown(signum, frame):
     assert batch_size > 0
     assert recompress_batch_size > 0

-    recompress_executor = AvailableThreadPoolExecutor(
+    recompress_executor = RecompressThreadPoolExecutor(
         max_workers=recompress_executor_workers
     )

@@ -322,7 +328,7 @@ def shutdown(signum, frame):
         max_size,
     )

-    recompress_sha256s = set()
+    recompress_sha256s: dict[str, int] = {}
     recompress_bytes = 0

     max_id = None
@@ -358,7 +364,7 @@ def shutdown(signum, frame):
             print(*print_arguments)

             if resolution is Resolution.RECOMPRESS:
-                recompress_sha256s.add(message.data_sha256)
+                recompress_sha256s[message.data_sha256] = message.size
                 recompress_bytes += message.size

                 if (
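The hunk cuts off at the flush condition, so what follows is only a guess at how the accumulated batch is handed to the executor; a sketch under the assumption that the loop drains recompress_sha256s once either the count or byte threshold is reached (the real condition and keyword plumbing are not shown in the diff):

if (
    len(recompress_sha256s) >= recompress_batch_size
    or recompress_bytes >= MAX_RECOMPRESS_BATCH_BYTES
):
    recompress_executor.wait_for_available_worker()
    recompress_executor.submit(
        recompress_batch,
        recompress_sha256s,  # positional, so submit() can sum the sizes
        dry_run=dry_run,  # assumed to be threaded through from the caller
        compression_level=compression_level,
    )
    recompress_sha256s = {}
    recompress_bytes = 0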