[GLM-Image] Add batch support for GlmImagePipeline #13007
base: main
Conversation
Signed-off-by: JaredforReal <w13431838023@gmail.com>
Pull request overview
This PR enables batch processing (batch_size > 1) for the GlmImagePipeline, supporting both text-to-image and image-to-image generation. The implementation adds support for processing multiple prompts simultaneously and allows each prompt to have multiple condition images in a homogeneous batch.
Changes:
- Refactored `generate_prior_tokens()` to support batch processing with left-padding and proper handling of multiple prompts/images (a toy left-padding sketch follows this list)
- Updated the KV cache implementation to support per-sample caching for batch processing
- Modified glyph text extraction and embedding generation to handle batched prompts
- Simplified `encode_prompt()` logic using `repeat_interleave()`
- Added comprehensive batch processing tests
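For intuition, here is a toy left-padding sketch (an editor's illustration, not the PR's code) of how variable-length prompt token sequences can be batched for AR generation so that newly generated tokens stay right-aligned across the batch:

```python
import torch

def left_pad(sequences, pad_token_id=0):
    # Pad shorter sequences on the left so every row ends at the same position;
    # the AR model can then append new tokens to the right for the whole batch.
    max_len = max(len(seq) for seq in sequences)
    input_ids = torch.full((len(sequences), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)
    for i, seq in enumerate(sequences):
        input_ids[i, max_len - len(seq):] = torch.tensor(seq, dtype=torch.long)
        attention_mask[i, max_len - len(seq):] = 1
    return input_ids, attention_mask

ids, mask = left_pad([[5, 6, 7], [8, 9]])
# ids  -> tensor([[5, 6, 7], [0, 8, 9]])
# mask -> tensor([[1, 1, 1], [0, 1, 1]])
```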
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `src/diffusers/pipelines/glm_image/pipeline_glm_image.py` | Core pipeline changes enabling batch support with input normalization, AR token generation, and KV cache management for batched condition images |
| `src/diffusers/models/transformers/transformer_glm_image.py` | Refactored KV cache classes to support per-sample storage and retrieval with batch expansion logic |
| `tests/pipelines/glm_image/test_glm_image.py` | Added 4 new test methods covering batch processing scenarios and updated expected output values |
```python
                continue
```
Copilot AI · Jan 21, 2026
When `len(prompt_grid_thw) == 0` at line 877, the code continues to the next iteration but never calls `kv_caches.next_sample()`. This means the cache index won't be incremented for samples without condition images, potentially causing misalignment between the cache indices and batch samples. The `next_sample()` call should be made even when continuing early.
Suggested change:
```diff
-                continue
+                # Even if there are no condition images, advance to the next cache slot
+                kv_caches.next_sample()
+                continue
```
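To illustrate the alignment issue with toy classes (not the actual `GlmImageKVCache` API): if the slot index is not advanced for samples without condition images, later samples end up written into the wrong slots.

```python
class ToySampleCache:
    """Minimal stand-in for a per-sample cache with an explicit write cursor."""

    def __init__(self, batch_size):
        self.slots = [None] * batch_size
        self.idx = 0

    def write(self, value):
        self.slots[self.idx] = value

    def next_sample(self):
        self.idx += 1


cache = ToySampleCache(batch_size=3)
for sample_images in [["img_a"], [], ["img_c"]]:  # the middle sample has no condition images
    if sample_images:
        cache.write(sample_images)
    cache.next_sample()  # advance even when nothing was written for this sample

print(cache.slots)  # [['img_a'], None, ['img_c']] -- slot i still corresponds to batch sample i
```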
```python
        Returns:
            prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens
            prior_token_image_ids: Tensor with upsampled source image tokens (or None for t2i)
            source_image_grid_thw: Upsampled grid info for splitting prior_token_image_ids
```
Copilot AI · Jan 21, 2026
The docstring lists only 3 return values, but the function actually returns 4: `prior_token_ids`, `prior_token_image_ids`, `source_image_grid_thw`, and `num_source_images_per_sample`. The docstring should document all return values, including the fourth one.
Suggested change:
```diff
             source_image_grid_thw: Upsampled grid info for splitting prior_token_image_ids
+            num_source_images_per_sample: Number of source images used for each prompt/sample
```
```python
            prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
            all_prior_token_ids.append(prior_token_ids)
```
Copilot AI · Jan 21, 2026
The variable name `prior_token_ids` at line 417 shadows the function's return variable that is assigned at line 420. Inside the loop, this local variable holds upsampled tokens for a single sample, but at line 420 it is used to accumulate all samples. This shadowing reduces code clarity. Consider renaming the loop variable to something like `sample_prior_token_ids` to avoid confusion.
Suggested change:
```diff
-            prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
-            all_prior_token_ids.append(prior_token_ids)
+            sample_prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
+            all_prior_token_ids.append(sample_prior_token_ids)
```
```diff
     def test_inference_batch_single_identical(self):
-        # GLM-Image has batch_size=1 constraint due to AR model
-        pass
+        """Test that batch=1 produces consistent results with the same seed."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # Run twice with same seed
+        inputs1 = self.get_dummy_inputs(device, seed=42)
+        inputs2 = self.get_dummy_inputs(device, seed=42)
+
+        image1 = pipe(**inputs1).images[0]
+        image2 = pipe(**inputs2).images[0]
+
+        self.assertTrue(torch.allclose(image1, image2, atol=1e-4))

-    @unittest.skip("Not supported.")
-    def test_inference_batch_consistent(self):
-        # GLM-Image has batch_size=1 constraint due to AR model
-        pass
-
+    def test_inference_batch_multiple_prompts(self):
+        """Test batch processing with multiple prompts."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(42)
+        height, width = 32, 32
+
+        inputs = {
+            "prompt": ["A photo of a cat", "A photo of a dog"],
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 2 images
+        self.assertEqual(len(images), 2)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+        self.assertEqual(images[1].shape, (3, 32, 32))

-    @unittest.skip("Not supported.")
-    def test_num_images_per_prompt(self):
-        # GLM-Image has batch_size=1 constraint due to AR model
-        pass
+    def test_num_images_per_prompt(self):
+        """Test generating multiple images per prompt."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(42)
+        height, width = 32, 32
+
+        inputs = {
+            "prompt": "A photo of a cat",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+            "num_images_per_prompt": 2,
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 2 images for single prompt
+        self.assertEqual(len(images), 2)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+        self.assertEqual(images[1].shape, (3, 32, 32))

+    def test_batch_with_num_images_per_prompt(self):
+        """Test batch prompts with num_images_per_prompt > 1."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(42)
+        height, width = 32, 32
+
+        inputs = {
+            "prompt": ["A photo of a cat", "A photo of a dog"],
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+            "num_images_per_prompt": 2,
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 4 images (2 prompts × 2 images per prompt)
+        self.assertEqual(len(images), 4)
```
Copilot AI · Jan 21, 2026
The new tests only cover text-to-image batch processing scenarios. Given that this PR adds batch support for both text-to-image and image-to-image generation (as stated in the description), there should be tests covering image-to-image scenarios with batched condition images, such as multiple prompts each with one or more condition images. This would help validate the KV cache handling and prior token splitting logic for image-to-image generation.
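For reference, a rough sketch of what such an image-to-image batch test might look like, reusing this file's dummy-component helpers and the per-prompt `List[List[PIL.Image]]` image format. The input sizes and the method name are assumptions, and this sketch is not part of the PR:

```python
from PIL import Image

def test_inference_batch_multiple_prompts_with_images(self):
    """Sketch: batched image-to-image with one condition image per prompt."""
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    generator = torch.Generator(device=device).manual_seed(42)
    cond_a = Image.new("RGB", (32, 32), color=(255, 0, 0))
    cond_b = Image.new("RGB", (32, 32), color=(0, 0, 255))

    inputs = {
        "prompt": ["A photo of a cat", "A photo of a dog"],
        "image": [[cond_a], [cond_b]],  # one condition image per prompt
        "generator": generator,
        "num_inference_steps": 2,
        "guidance_scale": 1.5,
        "height": 32,
        "width": 32,
        "max_sequence_length": 16,
        "output_type": "pt",
    }

    images = pipe(**inputs).images
    self.assertEqual(len(images), 2)
```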
```python
        for p in prompt:
            ocr_texts = (
                re.findall(r"'([^']*)'", p)
                + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
```
Copilot AI · Jan 21, 2026
The escaped unicode characters in the regex pattern appear incorrect. The pattern \u201c([^\u201c\u201d]*)\u201d uses unicode escape sequences that won't match properly in a raw string context. These should either be actual unicode characters or the string should not be a raw string. The original code used "([^""]*)" with actual curly quote characters which was correct.
Suggested change:
```diff
-                + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
+                + re.findall(r"“([^“”]*)”", p)
```
```python
        image1 = pipe(**inputs1).images[0]
        image2 = pipe(**inputs2).images[0]

        self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
```
Copilot AI · Jan 21, 2026
The `torch.allclose` call is missing the `rtol` parameter. For consistency with the other test at line 178, which uses both `atol=1e-4` and `rtol=1e-4`, this test should also specify `rtol=1e-4` to ensure consistent tolerance checking across all tests.
Suggested change:
```diff
-        self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
+        self.assertTrue(torch.allclose(image1, image2, atol=1e-4, rtol=1e-4))
```
```python
            width: Target image width
            image: List of image lists, one per prompt. Each inner list contains condition images
                for that prompt. For batchsize=1, can also be a simple List[PIL.Image].
            device: Target device
```
Copilot AI · Jan 21, 2026
The docstring is missing documentation for the `generator` parameter, which is now part of the function signature. This parameter should be documented to explain its purpose in setting the random seed for the AR model generation.
Suggested change:
```diff
             device: Target device
+            generator: Optional torch.Generator to control the random seed for AR model generation.
```
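For context, a minimal example of the kind of generator this parameter documents (the variable name is illustrative):

```python
import torch

# A seeded generator makes the AR prior-token generation reproducible across runs;
# it is passed through the pipeline's `generator` argument.
generator = torch.Generator(device="cpu").manual_seed(42)
```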
```python
            k_cache_list = []
            v_cache_list = []
            for i in range(num_cached_samples):
                k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
                v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
            k_cache_expanded = torch.cat(k_cache_list, dim=0)
            v_cache_expanded = torch.cat(v_cache_list, dim=0)
```
Copilot AI · Jan 21, 2026
The KV cache expansion logic doesn't match the `repeat_interleave` semantics used for `prior_token_ids` at line 943 of the pipeline. When expanding caches for `num_images_per_prompt > 1`, the current code concatenates expanded caches sequentially (cache_0 repeated N times, then cache_1 repeated N times), but `repeat_interleave` produces an interleaved pattern (cache_0, cache_0, cache_1, cache_1). This mismatch will cause incorrect cache retrieval. The expansion should use `repeat_interleave` instead of `expand` + `cat`, or use a similar interleaving pattern.
Suggested change:
```diff
-            k_cache_list = []
-            v_cache_list = []
-            for i in range(num_cached_samples):
-                k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
-                v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
-            k_cache_expanded = torch.cat(k_cache_list, dim=0)
-            v_cache_expanded = torch.cat(v_cache_list, dim=0)
+            # Use repeat_interleave semantics to align with prior_token_ids expansion
+            k_caches_stacked = torch.stack(self.k_caches, dim=0)
+            v_caches_stacked = torch.stack(self.v_caches, dim=0)
+            k_cache_expanded = k_caches_stacked.repeat_interleave(repeat_factor, dim=0)
+            v_cache_expanded = v_caches_stacked.repeat_interleave(repeat_factor, dim=0)
```
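As a reference for the ordering the reviewer is describing, here is a toy snippet (not from the PR) showing how `repeat_interleave` lays out samples along the batch dimension:

```python
import torch

# Two toy per-sample caches of shape [1, heads, seq, dim]: sample 0 is all zeros, sample 1 all ones.
k_caches = [torch.zeros(1, 2, 3, 4), torch.ones(1, 2, 3, 4)]
repeat_factor = 2  # e.g. num_images_per_prompt

expanded = torch.cat(k_caches, dim=0).repeat_interleave(repeat_factor, dim=0)
print(expanded.shape)        # torch.Size([4, 2, 3, 4])
print(expanded[:, 0, 0, 0])  # tensor([0., 0., 1., 1.])  -> sample 0, sample 0, sample 1, sample 1
```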
This PR enables batch processing (batch_size > 1) for both text-to-image and image-to-image generation in `GlmImagePipeline`. It needs to be used together with the Transformers GLM-Image batch support.
Changes

`pipeline_glm_image.py`
- `generate_prior_tokens()`: Refactored to support batch processing with multiple prompts/images
  - `List[str]` for prompts and `List[List[PIL.Image]]` for per-prompt condition images
  - `images_per_sample` and `num_source_images_per_sample` from the processor
  - `torch.manual_seed()` for reproducibility
- `get_glyph_texts()` / `_get_glyph_embeds()`: Updated for batch processing
- `encode_prompt()`: Simplified repeat logic using `repeat_interleave()`
- `__call__()`:
  - Removed the `batch_size == 1` restriction
  - Accepts the `List[List[PIL.Image]]` image format for per-prompt condition images (see the usage sketch under Supported Scenarios)

`transformer_glm_image.py`
- `GlmImageLayerKVCache`: Refactored to support per-sample caching
  - `next_sample()` method to advance the sample index during writing
  - `get()` method handles batch expansion for `num_images_per_prompt > 1`
- `GlmImageKVCache`: Added `next_sample()` method

`test_glm_image.py`
- Added batch processing tests
Supported Scenarios
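As a rough illustration of the supported scenarios (batched text-to-image, per-prompt condition images, and `num_images_per_prompt > 1`), here is a hedged usage sketch. It assumes the pipeline is exported from diffusers as `GlmImagePipeline`; the checkpoint id and image files are placeholders, and the snippet is an editor's example rather than part of the PR:

```python
import torch
from PIL import Image
from diffusers import GlmImagePipeline

# Placeholder checkpoint id; substitute the actual GLM-Image checkpoint.
pipe = GlmImagePipeline.from_pretrained("<glm-image-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")

# Batched text-to-image: one prompt per sample, two images per prompt.
out = pipe(prompt=["A photo of a cat", "A photo of a dog"], num_images_per_prompt=2)

# Batched image-to-image: one list of condition images per prompt (List[List[PIL.Image]]).
cond_images = [[Image.open("cat_ref.png")], [Image.open("dog_ref.png")]]
out = pipe(
    prompt=["Restyle the cat", "Restyle the dog"],
    image=cond_images,
)
```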
Breaking Changes

None. The legacy image format `List[PIL.Image]` is still supported and automatically normalized.

Before submitting

- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.