diff --git a/distorch_2.py b/distorch_2.py index fea6181..fc5a2fd 100644 --- a/distorch_2.py +++ b/distorch_2.py @@ -19,7 +19,13 @@ from .device_utils import get_device_list, soft_empty_cache_multigpu from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup +def bc_unpack_block(block_list): + """Backward compatible support for new block scheme in comfy/model_patcher. + New blocks: (module_offload_mem), module_size, module_name, module_object, params + """ + return [[None, *block] if len(block) == 4 else block for block in block_list] + def register_patched_safetensor_modelpatcher(): """Register and patch the ModelPatcher for distributed safetensor loading""" from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions @@ -242,7 +248,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p high_precision_loras = getattr(self.model, "_distorch_high_precision_loras", True) loading = self._load_list() loading.sort(reverse=True) - for module_size, module_name, module_object, params in loading: + for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(loading): if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights == True: block_target_device = device_assignments['block_assignments'].get(module_name, device_to) current_module_device = None @@ -321,7 +327,7 @@ def _extract_clip_head_blocks(raw_block_list, compute_device): head_memory = 0 block_assignments = {} - for module_size, module_name, module_object, params in raw_block_list: + for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(raw_block_list): if any(kw in module_name.lower() for kw in head_keywords): head_blocks.append((module_size, module_name, module_object, params)) block_assignments[module_name] = compute_device @@ -423,7 +429,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) total_memory = 0 raw_block_list = model_patcher._load_list() - total_memory = sum(module_size for module_size, _, _, _ in raw_block_list) + total_memory = sum(module_size for _, module_size, _, _, _ in bc_unpack_block(raw_block_list)) MIN_BLOCK_THRESHOLD = total_memory * 0.0001 logger.debug(f"[MultiGPU DisTorch V2] Total model memory: {total_memory} bytes") @@ -441,7 +447,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) # Build all_blocks list for summary (using full raw_block_list) all_blocks = [] - for module_size, module_name, module_object, params in raw_block_list: + for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(raw_block_list): block_type = type(module_object).__name__ # Populate summary dictionaries block_summary[block_type] = block_summary.get(block_type, 0) + 1 @@ -450,7 +456,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False) # Use distributable blocks for actual allocation (for CLIP, this excludes heads) distributable_all_blocks = [] - for module_size, module_name, module_object, params in distributable_raw: + for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(distributable_raw): distributable_all_blocks.append((module_name, module_object, type(module_object).__name__, module_size)) block_list = [b for b in distributable_all_blocks if b[3] >= MIN_BLOCK_THRESHOLD] @@ -581,7 +587,7 @@ def parse_memory_string(mem_str): def calculate_fraction_from_byte_expert_string(model_patcher, byte_str): """Convert byte allocation string (e.g. 'cuda:1,4gb;cpu,*') to fractional VRAM allocation string respecting device order and byte quotas.""" raw_block_list = model_patcher._load_list() - total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list) + total_model_memory = sum(module_size for _, module_size, _, _, _ in bc_unpack_block(raw_block_list)) remaining_model_bytes = total_model_memory # Use a list of tuples to preserve the user-defined order @@ -640,7 +646,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str): def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str): """Convert ratio allocation string (e.g. 'cuda:0,25%;cpu,75%') describing model split to fractional VRAM allocation string.""" raw_block_list = model_patcher._load_list() - total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list) + total_model_memory = sum(module_size for _, module_size, _, _, _ in bc_unpack_block(raw_block_list)) raw_ratios = {} for allocation in ratio_str.split(';'):