Conversation

JaredforReal commented Jan 21, 2026

This PR enables batch processing (batch_size > 1) for both text-to-image and image-to-image generation in GlmImagePipeline.

This must be used together with the corresponding Transformers GLM Image batch support.

Changes

pipeline_glm_image.py

  • generate_prior_tokens(): Refactored to support batch processing with multiple prompts/images
    • Accepts List[str] for prompts and List[List[PIL.Image]] for per-prompt condition images
    • Uses left-padding for batch tokenization (compatible with transformers)
    • Properly handles images_per_sample and num_source_images_per_sample from processor
    • Upsamples prior tokens from AR model resolution (d32) to VAE/DiT resolution (d8)
    • Extracts seed from generator and sets torch.manual_seed() for reproducibility
  • get_glyph_texts() / _get_glyph_embeds(): Updated for batch processing
  • encode_prompt(): Simplified the repeat logic using repeat_interleave() (see the sketch after this list)
  • __call__():
    • Removed batch_size == 1 restriction
    • Added homogeneous batch validation (all samples must have same number of condition images)
    • Supports List[List[PIL.Image]] image format for per-prompt condition images
    • Properly splits prior tokens by sample for KV cache writing
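
A minimal sketch of two of these mechanics, using illustrative shapes and names rather than the exact pipeline code:

```python
import torch

# Seed extraction: read the seed back off a user-supplied generator so the
# AR model's sampling is reproducible (illustrative, not the exact code).
generator = torch.Generator().manual_seed(42)
torch.manual_seed(generator.initial_seed())

# repeat_interleave keeps copies of the same prompt adjacent when duplicating
# embeddings for num_images_per_prompt > 1:
prompt_embeds = torch.randn(2, 16, 64)  # (batch_size, seq_len, dim)
prompt_embeds = prompt_embeds.repeat_interleave(3, dim=0)
print(prompt_embeds.shape)  # torch.Size([6, 16, 64]) -> p0, p0, p0, p1, p1, p1
```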

transformer_glm_image.py

  • GlmImageLayerKVCache: Refactored to support per-sample caching (sketched after this list)
    • Stores separate KV caches per batch sample
    • Added next_sample() method to advance sample index during writing
    • get() method handles batch expansion for num_images_per_prompt > 1
  • GlmImageKVCache: Added next_sample() method
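
A simplified sketch of the per-sample scheme; the class name, cache shapes, and write logic here are assumptions for illustration, with only next_sample() and get() taken from the description above:

```python
from typing import List, Tuple
import torch

class PerSampleKVCache:
    """Toy per-sample KV cache with a write cursor (not the actual implementation)."""

    def __init__(self) -> None:
        # One (1, num_heads, seq_len, head_dim) tensor per batch sample.
        self.k_caches: List[torch.Tensor] = []
        self.v_caches: List[torch.Tensor] = []
        self.sample_idx = 0  # sample currently being written

    def write(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # Start a new per-sample cache, or append tokens along the sequence axis.
        if self.sample_idx == len(self.k_caches):
            self.k_caches.append(k)
            self.v_caches.append(v)
        else:
            self.k_caches[self.sample_idx] = torch.cat([self.k_caches[self.sample_idx], k], dim=2)
            self.v_caches[self.sample_idx] = torch.cat([self.v_caches[self.sample_idx], v], dim=2)

    def next_sample(self) -> None:
        # Advance the cursor once all condition images of one sample are cached.
        self.sample_idx += 1

    def get(self, repeat_factor: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        # Expand for num_images_per_prompt > 1, keeping copies of each sample
        # adjacent; assumes a homogeneous batch (equal seq lengths per sample).
        k = torch.cat([c.expand(repeat_factor, -1, -1, -1) for c in self.k_caches], dim=0)
        v = torch.cat([c.expand(repeat_factor, -1, -1, -1) for c in self.v_caches], dim=0)
        return k, v
```

In this toy form, get() produces the adjacent-copies ordering that repeat_interleave would give.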

test_glm_image.py

  • Added 4 new batch processing tests
  • Updated expected slice values

Supported Scenarios

| batch_size | Condition Images | Description |
| ---------- | ---------------- | ----------- |
| 1 | N images | Single prompt with any number of condition images |
| N | 1 image each | Multiple prompts, each with 1 condition image |
| N | M images each | Multiple prompts, each with M condition images (homogeneous) |
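
For illustration, a hypothetical call for the second row of the table; the checkpoint path and image files are placeholders:

```python
import torch
from PIL import Image
from diffusers import GlmImagePipeline

pipe = GlmImagePipeline.from_pretrained("path/to/glm-image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Two prompts, each with one condition image (homogeneous batch).
images = pipe(
    prompt=["A photo of a cat", "A photo of a dog"],
    image=[[Image.open("cat.png")], [Image.open("dog.png")]],
    num_inference_steps=50,
).images
```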

Breaking Changes

None. Legacy image format List[PIL.Image] is still supported and automatically normalized.
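
A plausible shape for that normalization (illustrative; the helper name and validation message are assumptions, not the PR's code):

```python
import PIL.Image

def normalize_image_input(image, batch_size):
    # Legacy format: a flat List[PIL.Image] for a single prompt.
    if isinstance(image, list) and image and isinstance(image[0], PIL.Image.Image):
        image = [image]
    # Batched format: List[List[PIL.Image]], one inner list per prompt.
    if len(image) != batch_size:
        raise ValueError(f"Expected {batch_size} image lists, got {len(image)}.")
    return image
```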

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: JaredforReal <w13431838023@gmail.com>
Copilot AI (Contributor) left a comment

Pull request overview

This PR enables batch processing (batch_size > 1) for the GlmImagePipeline, supporting both text-to-image and image-to-image generation. The implementation adds support for processing multiple prompts simultaneously and allows each prompt to have multiple condition images in a homogeneous batch.

Changes:

  • Refactored generate_prior_tokens() to support batch processing with left-padding and proper handling of multiple prompts/images
  • Updated KV cache implementation to support per-sample caching for batch processing
  • Modified glyph text extraction and embedding generation to handle batched prompts
  • Simplified encode_prompt() logic using repeat_interleave()
  • Added comprehensive batch processing tests

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.

| File | Description |
| ---- | ----------- |
| src/diffusers/pipelines/glm_image/pipeline_glm_image.py | Core pipeline changes enabling batch support with input normalization, AR token generation, and KV cache management for batched condition images |
| src/diffusers/models/transformers/transformer_glm_image.py | Refactored KV cache classes to support per-sample storage and retrieval with batch expansion logic |
| tests/pipelines/glm_image/test_glm_image.py | Added 4 new test methods covering batch processing scenarios and updated expected output values |


Comment on lines +878 to +879

```python
continue
```
Copilot AI commented Jan 21, 2026

When len(prompt_grid_thw) == 0 at line 877, the code continues to the next iteration but never calls kv_caches.next_sample(). This means the cache index won't be incremented for samples without condition images, potentially causing misalignment between the cache indices and batch samples. The next_sample() call should be made even when continuing early.

Suggested change:

```diff
- continue
+ # Even if there are no condition images, advance to the next cache slot
+ kv_caches.next_sample()
+ continue
```

```python
Returns:
    prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
    prior_token_image_ids: Tensor with upsampled source image tokens (or None for t2i)
    source_image_grid_thw: Upsampled grid info for splitting prior_token_image_ids
```

Copilot AI commented Jan 21, 2026

The docstring lists only 3 return values but the function actually returns 4 values: prior_token_ids, prior_token_image_ids, source_image_grid_thw, and num_source_images_per_sample. The docstring should document all return values including the fourth one.

Suggested change:

```diff
- source_image_grid_thw: Upsampled grid info for splitting prior_token_image_ids
+ source_image_grid_thw: Upsampled grid info for splitting prior_token_image_ids
+ num_source_images_per_sample: Number of source images used for each prompt/sample
```

Comment on lines +417 to +418
```python
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
all_prior_token_ids.append(prior_token_ids)
```

Copilot AI commented Jan 21, 2026

The variable name prior_token_ids at line 417 shadows the function's return variable that is assigned at line 420. Inside the loop, this local variable holds upsampled tokens for a single sample, but at line 420, it's used to accumulate all samples. This shadowing reduces code clarity. Consider renaming the loop variable to something like sample_prior_token_ids to avoid confusion.

Suggested change:

```diff
- prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
- all_prior_token_ids.append(prior_token_ids)
+ sample_prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
+ all_prior_token_ids.append(sample_prior_token_ids)
```

Comment on lines 180 to 283
```python
def test_inference_batch_single_identical(self):
    """Test that batch=1 produces consistent results with the same seed."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    # Run twice with same seed
    inputs1 = self.get_dummy_inputs(device, seed=42)
    inputs2 = self.get_dummy_inputs(device, seed=42)

    image1 = pipe(**inputs1).images[0]
    image2 = pipe(**inputs2).images[0]

    self.assertTrue(torch.allclose(image1, image2, atol=1e-4))

def test_inference_batch_multiple_prompts(self):
    """Test batch processing with multiple prompts."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generator = torch.Generator(device=device).manual_seed(42)
    height, width = 32, 32

    inputs = {
        "prompt": ["A photo of a cat", "A photo of a dog"],
        "generator": generator,
        "num_inference_steps": 2,
        "guidance_scale": 1.5,
        "height": height,
        "width": width,
        "max_sequence_length": 16,
        "output_type": "pt",
    }

    images = pipe(**inputs).images

    # Should return 2 images
    self.assertEqual(len(images), 2)
    self.assertEqual(images[0].shape, (3, 32, 32))
    self.assertEqual(images[1].shape, (3, 32, 32))

def test_num_images_per_prompt(self):
    """Test generating multiple images per prompt."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generator = torch.Generator(device=device).manual_seed(42)
    height, width = 32, 32

    inputs = {
        "prompt": "A photo of a cat",
        "generator": generator,
        "num_inference_steps": 2,
        "guidance_scale": 1.5,
        "height": height,
        "width": width,
        "max_sequence_length": 16,
        "output_type": "pt",
        "num_images_per_prompt": 2,
    }

    images = pipe(**inputs).images

    # Should return 2 images for single prompt
    self.assertEqual(len(images), 2)
    self.assertEqual(images[0].shape, (3, 32, 32))
    self.assertEqual(images[1].shape, (3, 32, 32))

def test_batch_with_num_images_per_prompt(self):
    """Test batch prompts with num_images_per_prompt > 1."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generator = torch.Generator(device=device).manual_seed(42)
    height, width = 32, 32

    inputs = {
        "prompt": ["A photo of a cat", "A photo of a dog"],
        "generator": generator,
        "num_inference_steps": 2,
        "guidance_scale": 1.5,
        "height": height,
        "width": width,
        "max_sequence_length": 16,
        "output_type": "pt",
        "num_images_per_prompt": 2,
    }

    images = pipe(**inputs).images

    # Should return 4 images (2 prompts × 2 images per prompt)
    self.assertEqual(len(images), 4)
```

Copilot AI commented Jan 21, 2026

The new tests only cover text-to-image batch processing scenarios. Given that this PR adds batch support for both text-to-image and image-to-image generation (as stated in the description), there should be tests covering image-to-image scenarios with batched condition images, such as multiple prompts each with one or more condition images. This would help validate the KV cache handling and prior token splitting logic for image-to-image generation.

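A sketch of what such a test might look like, reusing the dummy-component helpers from the existing tests; the method name and the condition-image setup are illustrative, not part of the PR:

```python
def test_batch_image_to_image_multiple_prompts(self):
    """Illustrative i2i batch test: two prompts, one condition image each."""
    import PIL.Image

    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generator = torch.Generator(device=device).manual_seed(42)
    cond = PIL.Image.new("RGB", (32, 32))  # dummy condition image

    images = pipe(
        prompt=["A photo of a cat", "A photo of a dog"],
        image=[[cond], [cond]],  # one condition image per prompt
        generator=generator,
        num_inference_steps=2,
        guidance_scale=1.5,
        height=32,
        width=32,
        max_sequence_length=16,
        output_type="pt",
    ).images

    self.assertEqual(len(images), 2)
```
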
```python
for p in prompt:
    ocr_texts = (
        re.findall(r"'([^']*)'", p)
        + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
```

Copilot AI commented Jan 21, 2026

The escaped unicode characters in the regex pattern appear incorrect. The pattern \u201c([^\u201c\u201d]*)\u201d uses unicode escape sequences that won't match properly in a raw string context. These should either be actual unicode characters or the string should not be a raw string. The original code used “([^“”]*)” with actual curly quote characters, which was correct.

Suggested change:

```diff
- + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
+ + re.findall(r"“([^“”]*)”", p)
```

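For reference, a small runnable check of the curly-quote extraction the comment describes (sample string is made up):

```python
import re

p = "A sign that says “OPEN” and 'Welcome'"
ocr_texts = re.findall(r"'([^']*)'", p) + re.findall(r"“([^“”]*)”", p)
print(ocr_texts)  # ['Welcome', 'OPEN']
```
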
```python
image1 = pipe(**inputs1).images[0]
image2 = pipe(**inputs2).images[0]

self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
```

Copilot AI commented Jan 21, 2026

The torch.allclose call is missing the rtol parameter. For consistency with the other test at line 178 which uses both atol=1e-4 and rtol=1e-4, this test should also specify rtol=1e-4 to ensure consistent tolerance checking across all tests.

Suggested change:

```diff
- self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
+ self.assertTrue(torch.allclose(image1, image2, atol=1e-4, rtol=1e-4))
```

```python
width: Target image width
image: List of image lists, one per prompt. Each inner list contains condition images
    for that prompt. For batch_size=1, can also be a simple List[PIL.Image].
device: Target device
```

Copilot AI commented Jan 21, 2026

The docstring is missing documentation for the generator parameter which is now part of the function signature. This parameter should be documented to explain its purpose in setting the random seed for the AR model generation.

Suggested change:

```diff
- device: Target device
+ device: Target device
+ generator: Optional torch.Generator to control the random seed for AR model generation.
```

Comment on lines +204 to +210
```python
k_cache_list = []
v_cache_list = []
for i in range(num_cached_samples):
    k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
    v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
k_cache_expanded = torch.cat(k_cache_list, dim=0)
v_cache_expanded = torch.cat(v_cache_list, dim=0)
```

Copilot AI commented Jan 21, 2026

The KV cache expansion logic doesn't match the repeat_interleave semantics used for prior_token_ids at line 943 of the pipeline. When expanding caches for num_images_per_prompt > 1, the current code concatenates expanded caches sequentially (cache_0 repeated N times, then cache_1 repeated N times), but repeat_interleave produces an interleaved pattern (cache_0, cache_0, cache_1, cache_1). This mismatch will cause incorrect cache retrieval. The expansion should use repeat_interleave instead of expand + cat, or use a similar interleaving pattern.

Suggested change:

```diff
- k_cache_list = []
- v_cache_list = []
- for i in range(num_cached_samples):
-     k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
-     v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
- k_cache_expanded = torch.cat(k_cache_list, dim=0)
- v_cache_expanded = torch.cat(v_cache_list, dim=0)
+ # Use repeat_interleave semantics to align with prior_token_ids expansion
+ k_caches_stacked = torch.stack(self.k_caches, dim=0)
+ v_caches_stacked = torch.stack(self.v_caches, dim=0)
+ k_cache_expanded = k_caches_stacked.repeat_interleave(repeat_factor, dim=0)
+ v_cache_expanded = v_caches_stacked.repeat_interleave(repeat_factor, dim=0)
```

sayakpaul requested a review from yiyixuxu on January 21, 2026.