src/diffusers/models/transformers/transformer_glm_image.py (101 changes: 82 additions & 19 deletions)

@@ -143,41 +143,98 @@ def forward(


class GlmImageLayerKVCache:
-    """KV cache for GlmImage model."""
+    """KV cache for GlmImage model.
+
+    Supports per-sample caching for batch processing where each sample
+    may have different condition images.
+    """

    def __init__(self):
-        self.k_cache = None
-        self.v_cache = None
+        # List of (k_cache, v_cache) tuples, one per batch sample
+        self.k_caches: List[Optional[torch.Tensor]] = []
+        self.v_caches: List[Optional[torch.Tensor]] = []
+        self.mode: Optional[str] = None  # "write", "read", "skip"
+        self.current_sample_idx: int = 0  # Current sample index for writing

    def store(self, k: torch.Tensor, v: torch.Tensor):
-        if self.k_cache is None:
-            self.k_cache = k
-            self.v_cache = v
+        """Store KV cache for the current sample."""
+        # k, v shape: (1, seq_len, num_heads, head_dim)
+        if len(self.k_caches) <= self.current_sample_idx:
+            # First time storing for this sample
+            self.k_caches.append(k)
+            self.v_caches.append(v)
        else:
-            self.k_cache = torch.cat([self.k_cache, k], dim=1)
-            self.v_cache = torch.cat([self.v_cache, v], dim=1)
+            # Append to existing cache for this sample (multiple condition images)
+            self.k_caches[self.current_sample_idx] = torch.cat(
+                [self.k_caches[self.current_sample_idx], k], dim=1
+            )
+            self.v_caches[self.current_sample_idx] = torch.cat(
+                [self.v_caches[self.current_sample_idx], v], dim=1
+            )

    def get(self, k: torch.Tensor, v: torch.Tensor):
-        if self.k_cache.shape[0] != k.shape[0]:
-            k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
-            v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
+        """Get combined KV cache for all samples in the batch.
+
+        Args:
+            k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
+            v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)
+
+        Returns:
+            Combined key and value tensors with cached values prepended.
+        """
+        batch_size = k.shape[0]
+        num_cached_samples = len(self.k_caches)
+
+        if num_cached_samples == 0:
+            return k, v
+
+        if num_cached_samples == 1:
+            # Single cache, expand for all batch samples (shared condition images)
+            k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
+            v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
+        elif num_cached_samples == batch_size:
+            # Per-sample cache, concatenate along batch dimension
+            k_cache_expanded = torch.cat(self.k_caches, dim=0)
+            v_cache_expanded = torch.cat(self.v_caches, dim=0)
        else:
-            k_cache_expanded = self.k_cache
-            v_cache_expanded = self.v_cache
+            # Mismatch: try to handle by repeating the caches
+            # This handles cases like num_images_per_prompt > 1
+            repeat_factor = batch_size // num_cached_samples
+            if batch_size % num_cached_samples == 0:
+                k_cache_list = []
+                v_cache_list = []
+                for i in range(num_cached_samples):
+                    k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
+                    v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
+                k_cache_expanded = torch.cat(k_cache_list, dim=0)
+                v_cache_expanded = torch.cat(v_cache_list, dim=0)
Comment on lines +204 to +210
Copilot AI commented on Jan 21, 2026:

The KV cache expansion logic doesn't match the repeat_interleave semantics used for prior_token_ids at line 943 of the pipeline. When expanding caches for num_images_per_prompt > 1, the current code concatenates expanded caches sequentially (cache_0 repeated N times, then cache_1 repeated N times), but repeat_interleave produces an interleaved pattern (cache_0, cache_0, cache_1, cache_1). This mismatch will cause incorrect cache retrieval. The expansion should use repeat_interleave instead of expand + cat, or use a similar interleaving pattern.

Suggested change:
-                k_cache_list = []
-                v_cache_list = []
-                for i in range(num_cached_samples):
-                    k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
-                    v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
-                k_cache_expanded = torch.cat(k_cache_list, dim=0)
-                v_cache_expanded = torch.cat(v_cache_list, dim=0)
+                # Use repeat_interleave semantics to align with prior_token_ids expansion
+                k_caches_stacked = torch.stack(self.k_caches, dim=0)
+                v_caches_stacked = torch.stack(self.v_caches, dim=0)
+                k_cache_expanded = k_caches_stacked.repeat_interleave(repeat_factor, dim=0)
+                v_cache_expanded = v_caches_stacked.repeat_interleave(repeat_factor, dim=0)
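
For reference, the snippet below is a standalone sketch (not part of the PR) of the batch ordering that repeat_interleave produces along dim 0, which is the ordering the comment asks the cache expansion to match. The shapes are illustrative assumptions (each per-sample cache is taken to be (1, seq_len, num_heads, head_dim)), and torch.cat is used rather than torch.stack so the expanded cache stays 4-D.

# Standalone sketch: batch ordering when expanding per-sample caches for
# num_images_per_prompt > 1. Shapes and names here are illustrative only.
import torch

seq_len, num_heads, head_dim = 6, 2, 4
repeat_factor = 2  # e.g. num_images_per_prompt

# Two per-sample caches (one per prompt), each of shape (1, seq_len, num_heads, head_dim).
k_caches = [
    torch.zeros(1, seq_len, num_heads, head_dim),  # sample 0
    torch.ones(1, seq_len, num_heads, head_dim),   # sample 1
]

# torch.cat keeps the result 4-D: (num_cached_samples, seq_len, num_heads, head_dim).
combined = torch.cat(k_caches, dim=0)

# repeat_interleave duplicates each sample consecutively along the batch dim,
# giving the order [sample_0, sample_0, sample_1, sample_1].
expanded = combined.repeat_interleave(repeat_factor, dim=0)
print(expanded.shape)        # torch.Size([4, 6, 2, 4])
print(expanded[:, 0, 0, 0])  # tensor([0., 0., 1., 1.])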

+            else:
+                raise ValueError(
+                    f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
+                    f"Batch size must be a multiple of the number of cached samples."
+                )

-        k_cache = torch.cat([k_cache_expanded, k], dim=1)
-        v_cache = torch.cat([v_cache_expanded, v], dim=1)
-        return k_cache, v_cache
+        k_combined = torch.cat([k_cache_expanded, k], dim=1)
+        v_combined = torch.cat([v_cache_expanded, v], dim=1)
+        return k_combined, v_combined

    def clear(self):
-        self.k_cache = None
-        self.v_cache = None
+        self.k_caches = []
+        self.v_caches = []
+        self.mode = None
+        self.current_sample_idx = 0
+
+    def next_sample(self):
+        """Move to the next sample for writing."""
+        self.current_sample_idx += 1


class GlmImageKVCache:
-    """Container for all layers' KV caches."""
+    """Container for all layers' KV caches.
+
+    Supports per-sample caching for batch processing where each sample
+    may have different condition images.
+    """

    def __init__(self, num_layers: int):
        self.num_layers = num_layers

@@ -192,6 +249,12 @@ def set_mode(self, mode: Optional[str]):
        for cache in self.caches:
            cache.mode = mode

+    def next_sample(self):
+        """Move to the next sample for writing. Call this after processing
+        all condition images for one batch sample."""
+        for cache in self.caches:
+            cache.next_sample()
+
    def clear(self):
        for cache in self.caches:
            cache.clear()
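
As a quick illustration of the per-sample flow added in this diff (store one sample's condition-image K/V, call next_sample(), then read back for the whole batch), here is a toy sketch. The shapes are arbitrary, and it assumes GlmImageLayerKVCache from the hunk above is importable; the real pipeline drives these calls through the transformer's attention layers rather than directly.

# Toy walkthrough of the per-sample cache; shapes are arbitrary illustrations.
import torch

cache = GlmImageLayerKVCache()

# Write phase: one slot per batch sample (e.g. per prompt's condition images).
cache.mode = "write"
cache.store(torch.zeros(1, 5, 2, 4), torch.zeros(1, 5, 2, 4))  # sample 0
cache.next_sample()
cache.store(torch.ones(1, 5, 2, 4), torch.ones(1, 5, 2, 4))    # sample 1
cache.next_sample()

# Read phase: cached K/V are prepended per sample for the current batch.
cache.mode = "read"
k = v = torch.randn(2, 3, 2, 4)  # batch_size=2 matches the two cached samples
k_comb, v_comb = cache.get(k, v)
print(k_comb.shape)              # torch.Size([2, 8, 2, 4]): 5 cached + 3 current tokens
cache.clear()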