distorch_2.py (20 changes: 13 additions & 7 deletions)

@@ -19,7 +19,13 @@
 from .device_utils import get_device_list, soft_empty_cache_multigpu
 from .model_management_mgpu import multigpu_memory_log, force_full_system_cleanup
 
+def bc_unpack_block(block_list):
+    """Backward compatible support for new block scheme in comfy/model_patcher.
+
+    New blocks: (module_offload_mem), module_size, module_name, module_object, params
+    """
+    return [[None, *block] if len(block) == 4 else block for block in block_list]
 
 def register_patched_safetensor_modelpatcher():
     """Register and patch the ModelPatcher for distributed safetensor loading"""
     from comfy.model_patcher import wipe_lowvram_weight, move_weight_functions
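Note for reviewers: a minimal, self-contained sketch of what the shim normalizes (the sizes and block names below are invented for illustration, not real ComfyUI modules):

    # Copied from the diff above: old 4-tuples gain a leading None for module_offload_mem
    def bc_unpack_block(block_list):
        return [[None, *block] if len(block) == 4 else block for block in block_list]

    blocks = [
        (1024, "blocks.0.attn", object(), []),      # old scheme: size, name, object, params
        (512, 2048, "blocks.1.mlp", object(), []),  # new scheme: offload_mem prepended
    ]
    for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(blocks):
        print(module_offload_mem, module_size, module_name)
    # -> None 1024 blocks.0.attn
    # -> 512 2048 blocks.1.mlp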
@@ -242,7 +248,7 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
     high_precision_loras = getattr(self.model, "_distorch_high_precision_loras", True)
     loading = self._load_list()
     loading.sort(reverse=True)
-    for module_size, module_name, module_object, params in loading:
+    for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(loading):
         if not unpatch_weights and hasattr(module_object, "comfy_patched_weights") and module_object.comfy_patched_weights == True:
             block_target_device = device_assignments['block_assignments'].get(module_name, device_to)
             current_module_device = None
@@ -321,7 +327,7 @@ def _extract_clip_head_blocks(raw_block_list, compute_device):
     head_memory = 0
     block_assignments = {}
 
-    for module_size, module_name, module_object, params in raw_block_list:
+    for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(raw_block_list):
         if any(kw in module_name.lower() for kw in head_keywords):
             head_blocks.append((module_size, module_name, module_object, params))
             block_assignments[module_name] = compute_device
@@ -423,7 +429,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
     total_memory = 0
 
     raw_block_list = model_patcher._load_list()
-    total_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
+    total_memory = sum(module_size for _, module_size, _, _, _ in bc_unpack_block(raw_block_list))
 
     MIN_BLOCK_THRESHOLD = total_memory * 0.0001
     logger.debug(f"[MultiGPU DisTorch V2] Total model memory: {total_memory} bytes")
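For scale, a rough back-of-the-envelope with an assumed model size (numbers not from this PR):

    total_memory = 12 * 1024**3                  # hypothetical 12 GiB model
    MIN_BLOCK_THRESHOLD = total_memory * 0.0001  # ~1.3 MB; blocks smaller than this
                                                 # are later filtered out of block_list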
@@ -441,7 +447,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
 
     # Build all_blocks list for summary (using full raw_block_list)
     all_blocks = []
-    for module_size, module_name, module_object, params in raw_block_list:
+    for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(raw_block_list):
         block_type = type(module_object).__name__
         # Populate summary dictionaries
         block_summary[block_type] = block_summary.get(block_type, 0) + 1
@@ -450,7 +456,7 @@ def analyze_safetensor_loading(model_patcher, allocations_string, is_clip=False)
 
     # Use distributable blocks for actual allocation (for CLIP, this excludes heads)
     distributable_all_blocks = []
-    for module_size, module_name, module_object, params in distributable_raw:
+    for module_offload_mem, module_size, module_name, module_object, params in bc_unpack_block(distributable_raw):
         distributable_all_blocks.append((module_name, module_object, type(module_object).__name__, module_size))
 
     block_list = [b for b in distributable_all_blocks if b[3] >= MIN_BLOCK_THRESHOLD]
@@ -581,7 +587,7 @@ def parse_memory_string(mem_str):
 def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
     """Convert byte allocation string (e.g. 'cuda:1,4gb;cpu,*') to fractional VRAM allocation string respecting device order and byte quotas."""
     raw_block_list = model_patcher._load_list()
-    total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
+    total_model_memory = sum(module_size for _, module_size, _, _, _ in bc_unpack_block(raw_block_list))
     remaining_model_bytes = total_model_memory
 
     # Use a list of tuples to preserve the user-defined order
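A hand-worked example of the byte-quota semantics described by the docstring (assuming explicit quotas are honored in order, '*' absorbs the remainder, and 'gb' is read as GiB; all of that is inferred from the docstring, not shown in this hunk):

    # byte_str = 'cuda:1,4gb;cpu,*' applied to an assumed 10 GiB model
    total_model_memory = 10 * 1024**3
    cuda1_bytes = 4 * 1024**3                     # explicit quota for cuda:1
    cpu_bytes = total_model_memory - cuda1_bytes  # '*' device takes the remaining ~6 GiB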
@@ -640,7 +646,7 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
 def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str):
     """Convert ratio allocation string (e.g. 'cuda:0,25%;cpu,75%') describing model split to fractional VRAM allocation string."""
     raw_block_list = model_patcher._load_list()
-    total_model_memory = sum(module_size for module_size, _, _, _ in raw_block_list)
+    total_model_memory = sum(module_size for _, module_size, _, _, _ in bc_unpack_block(raw_block_list))
 
     raw_ratios = {}
     for allocation in ratio_str.split(';'):
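Finally, a small illustration of the ratio parse itself, mirroring the ratio_str.split(';') loop above (the conversion from model-split ratios to per-device VRAM fractions is assumed and not shown here):

    ratio_str = 'cuda:0,25%;cpu,75%'
    raw_ratios = {}
    for allocation in ratio_str.split(';'):
        device, pct = allocation.split(',')
        raw_ratios[device] = float(pct.rstrip('%')) / 100.0
    # raw_ratios == {'cuda:0': 0.25, 'cpu': 0.75}
    # i.e. 25% of the model's bytes on cuda:0, 75% on cpu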