From 3ddcad1699d0209d9279b9e8b4777514329a507c Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 08:30:26 +0000
Subject: [PATCH 1/7] initial batch size > 1 support

---
 .../qwenimage/pipeline_qwenimage_edit_plus.py | 176 +++++++++++++-----
 1 file changed, 128 insertions(+), 48 deletions(-)

diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index 257e2d846c7c..cd2811b465d9 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -315,7 +315,35 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+            # Check if image is a nested list (batch_size > 1)
+            if image is not None and isinstance(image, list) and image and isinstance(image[0], list):
+                # Process each batch item separately
+                all_prompt_embeds = []
+                all_prompt_embeds_mask = []
+                for i, (single_prompt, batch_images) in enumerate(zip(prompt, image)):
+                    embeds, mask = self._get_qwen_prompt_embeds([single_prompt], batch_images, device)
+                    all_prompt_embeds.append(embeds)
+                    all_prompt_embeds_mask.append(mask)
+
+                # Find max sequence length across all batch items
+                max_seq_len = max(e.shape[1] for e in all_prompt_embeds)
+
+                # Pad all embeddings to same length and concatenate
+                padded_embeds = []
+                padded_masks = []
+                for embeds, mask in zip(all_prompt_embeds, all_prompt_embeds_mask):
+                    if embeds.shape[1] < max_seq_len:
+                        pad_len = max_seq_len - embeds.shape[1]
+                        embeds = torch.cat([embeds, torch.zeros(embeds.shape[0], pad_len, embeds.shape[2], device=embeds.device, dtype=embeds.dtype)], dim=1)
+                        mask = torch.cat([mask, torch.zeros(mask.shape[0], pad_len, device=mask.device, dtype=mask.dtype)], dim=1)
+                    padded_embeds.append(embeds)
+                    padded_masks.append(mask)
+
+                prompt_embeds = torch.cat(padded_embeds, dim=0)
+                prompt_embeds_mask = torch.cat(padded_masks, dim=0)
+            else:
+                # Single batch or batch_size == 1
+                prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
 
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -407,6 +435,32 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
 
         return latents
 
+    def _preprocess_image_list(self, images):
+        """
+        Preprocess a list of PIL images for both condition encoder and VAE.
+
+        Args:
+            images: List of PIL images
+
+        Returns:
+            Tuple of (condition_sizes, condition_images, vae_sizes, vae_images)
+        """
+        condition_sizes = []
+        condition_images = []
+        vae_sizes = []
+        vae_images = []
+
+        for img in images:
+            image_width, image_height = img.size
+            condition_width, condition_height = calculate_dimensions(CONDITION_IMAGE_SIZE, image_width / image_height)
+            vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+            condition_sizes.append((condition_width, condition_height))
+            vae_sizes.append((vae_width, vae_height))
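+            # The condition image is only resized for the VL text encoder, while the
+            # VAE input is fully preprocessed; unsqueeze(2) adds the frame axis the
+            # video-shaped VAE expects, giving a (1, C, 1, H, W) tensor.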
+            condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+            vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+        return condition_sizes, condition_images, vae_sizes, vae_images
+
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
@@ -431,6 +485,18 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
 
         return image_latents
 
+    def _encode_and_pack_image(self, image, num_channels_latents, device, dtype, generator):
+        """Encode a single image and pack it. Returns packed latents."""
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            img_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            img_latents = image
+
+        image_latent_height, image_latent_width = img_latents.shape[3:]
+        img_latents = self._pack_latents(img_latents, 1, num_channels_latents, image_latent_height, image_latent_width)
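+        # Packing maps the (1, C, 1, H', W') latent to (1, (H'//2) * (W'//2), C * 4),
+        # i.e. one token per 2x2 latent patch, so every reference image becomes a
+        # sequence that can be concatenated with the denoised latents along dim=1.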
- ) - else: - image_latents = torch.cat([image_latents], dim=0) - - image_latent_height, image_latent_width = image_latents.shape[3:] - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width - ) - all_image_latents.append(image_latents) - image_latents = torch.cat(all_image_latents, dim=1) + + # Check if nested list (batch_size > 1): [[img1, img2], [img3, img4]] + is_nested = images and isinstance(images[0], list) + + if is_nested: + # batch_size > 1: Process each batch item separately + batch_image_latents = [] + for batch_images in images: + batch_item_latents = [ + self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator) + for img in batch_images + ] + # Concatenate all images for this batch item along sequence dimension + batch_image_latents.append(torch.cat(batch_item_latents, dim=1)) + # Stack all batch items to create final batch dimension + image_latents = torch.cat(batch_image_latents, dim=0) + else: + # batch_size == 1: Process flat list [img1, img2] + all_image_latents = [ + self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator) for img in images + ] + image_latents = torch.cat(all_image_latents, dim=1) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -627,7 +691,17 @@ def __call__( [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ - image_size = image[-1].size if isinstance(image, list) else image.size + # Handle both flat list [img1, img2] and nested list [[img1, img2], [img3, img4]] + if isinstance(image, list): + # Check if nested list (batch_size > 1) + if isinstance(image[0], list): + # Use last image from first batch item + image_size = image[0][-1].size + else: + # Flat list (batch_size == 1) + image_size = image[-1].size + else: + image_size = image.size calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) height = height or calculated_height width = width or calculated_width @@ -663,32 +737,34 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # QwenImageEditPlusPipeline does not currently support batch_size > 1 - if batch_size > 1: - raise ValueError( - f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " - "Please process prompts one at a time." - ) - device = self._execution_device # 3. 
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas

From 25312b96f559a0e6c176ce9e0ac46367e1d976ec Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 08:34:55 +0000
Subject: [PATCH 2/7] add docs

---
 docs/source/en/api/pipelines/qwenimage.md | 32 +++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md
index ee3dd3b28e4d..48d4fc9b4f6d 100644
--- a/docs/source/en/api/pipelines/qwenimage.md
+++ b/docs/source/en/api/pipelines/qwenimage.md
@@ -95,6 +95,8 @@ image.save("qwen_fewsteps.png")
 
 With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
 
+### Single prompt with multiple reference images
+
 ```py
 import torch
 from PIL import Image
 from diffusers import QwenImageEditPlusPipeline
 from diffusers.utils import load_image
 
 pipe = QwenImageEditPlusPipeline.from_pretrained(
     "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
 ).to("cuda")
 image = pipe(
     ...
 ).images[0]
 ```
 
+### Batch processing with multiple prompts
+
+The pipeline also supports batch processing, so you can edit multiple images with different prompts simultaneously. Use a nested list format `[[img1], [img2]]` to provide input images for each prompt:
+
+```py
+import torch
+from diffusers import QwenImageEditPlusPipeline
+from diffusers.utils import load_image
+
+pipe = QwenImageEditPlusPipeline.from_pretrained(
+    "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Load input images
+mountain_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mountain.jpg")
+
+# Process two different edits in a single batch
+images = pipe(
+    image=[[mountain_image], [mountain_image]],  # Nested list for batch_size=2
+    prompt=[
+        "Transform into a sunset scene with warm orange and pink sky",
+        "Add snow and make it a winter scene"
+    ],
+    num_inference_steps=50
+).images
+
+# images[0] contains the sunset version
+# images[1] contains the winter version
+```
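+
+Each sublist can also contain several reference images for its prompt. A minimal sketch
+(reusing `pipe` and `mountain_image` from above; keep every sublist the same length so
+the packed image latents can be batched together):
+
+```py
+images = pipe(
+    image=[[mountain_image, mountain_image], [mountain_image, mountain_image]],
+    prompt=[
+        "Blend the two views into a single panorama",
+        "Turn the scene into a starry night"
+    ],
+    num_inference_steps=50
+).images
+```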
+
 ## Performance
 
 ### torch.compile

From 0c841c7ec5e7776d9130b337fa53b22277950a93 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 08:41:41 +0000
Subject: [PATCH 3/7] add tests

---
 .../qwenimage/pipeline_qwenimage_edit_plus.py |  5 +-
 .../qwenimage/test_qwenimage_edit_plus.py     | 48 +++++++++++++++----
 2 files changed, 43 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index cd2811b465d9..cf5137f67e45 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -607,12 +607,15 @@ def __call__(
         Function invoked when calling the pipeline for generation.
 
         Args:
-            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, or `List[List[PIL.Image.Image]]`):
                 `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                 numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                 or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                 latents as `image`, but if passing latents directly it is not encoded again.
+                For batch processing with multiple prompts (batch_size > 1), provide a nested list where each sublist
+                contains the input images for that prompt: `[[img1_for_prompt1], [img2_for_prompt2]]`. For a single
+                prompt with multiple reference images (batch_size == 1), use a flat list: `[img1, img2]`.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
index 6faf34728286..25f497aee8af 100644
--- a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -240,14 +240,44 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
         super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
 
-    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
-    def test_num_images_per_prompt():
-        super().test_num_images_per_prompt()
+    def test_inference_batch_single_identical(self):
+        # Test that batch_size=1 gives identical results to non-batched inference
+        self._test_inference_batch_single_identical(expected_max_diff=1e-3)
 
-    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
-    def test_inference_batch_consistent():
-        super().test_inference_batch_consistent()
+    def test_inference_batch_consistent(self):
+        # Test that batched inference gives consistent results
+        self._test_inference_batch_consistent()
 
-    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
-    def test_inference_batch_single_identical():
-        super().test_inference_batch_single_identical()
+    def test_batch_processing_multiple_prompts(self):
+        # Test batch processing with multiple prompts (batch_size > 1)
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=device).manual_seed(0)
+
+        image = Image.new("RGB", (32, 32))
+
+        # Test with nested list format for batch_size=2
+        inputs = {
+            "prompt": ["dance monkey", "jump around"],
+            "image": [[image], [image]],  # Nested list for batch_size=2
+            "generator": generator,
+            "num_inference_steps": 2,
+            "height": 32,
+            "width": 32,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 2 images (batch_size=2)
+        self.assertEqual(len(images), 2)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+        self.assertEqual(images[1].shape, (3, 32, 32))

From 31f315fd3c7171f863f466dfcffb84c0ecde225e Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 08:46:45 +0000
Subject: [PATCH 4/7] add back xfail for num_images_per_prompt

---
 tests/pipelines/qwenimage/test_qwenimage_edit_plus.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
index 25f497aee8af..a6a16a4d194b 100644
--- a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -240,6 +240,14 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
     def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
         super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
 
+    @pytest.mark.xfail(
+        condition=True,
+        reason="num_images_per_prompt > 1 is not yet supported for EditPlus pipeline",
+        strict=True,
+    )
+    def test_num_images_per_prompt(self):
+        super().test_num_images_per_prompt()
+
     def test_inference_batch_single_identical(self):
         # Test that batch_size=1 gives identical results to non-batched inference
         self._test_inference_batch_single_identical(expected_max_diff=1e-3)

From 1cfab88cd1937543ec7bb05e3fa0c72926c41570 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 08:47:48 +0000
Subject: [PATCH 5/7] use helper

---
 .../qwenimage/pipeline_qwenimage_edit_plus.py | 38 +++++-------------
 1 file changed, 5 insertions(+), 33 deletions(-)

diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index cf5137f67e45..99983092df74 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -315,35 +315,7 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            # Check if image is a nested list (batch_size > 1)
-            if image is not None and isinstance(image, list) and image and isinstance(image[0], list):
-                # Process each batch item separately
-                all_prompt_embeds = []
-                all_prompt_embeds_mask = []
-                for i, (single_prompt, batch_images) in enumerate(zip(prompt, image)):
-                    embeds, mask = self._get_qwen_prompt_embeds([single_prompt], batch_images, device)
-                    all_prompt_embeds.append(embeds)
-                    all_prompt_embeds_mask.append(mask)
-
-                # Find max sequence length across all batch items
-                max_seq_len = max(e.shape[1] for e in all_prompt_embeds)
-
-                # Pad all embeddings to same length and concatenate
-                padded_embeds = []
-                padded_masks = []
-                for embeds, mask in zip(all_prompt_embeds, all_prompt_embeds_mask):
-                    if embeds.shape[1] < max_seq_len:
-                        pad_len = max_seq_len - embeds.shape[1]
-                        embeds = torch.cat([embeds, torch.zeros(embeds.shape[0], pad_len, embeds.shape[2], device=embeds.device, dtype=embeds.dtype)], dim=1)
-                        mask = torch.cat([mask, torch.zeros(mask.shape[0], pad_len, device=mask.device, dtype=mask.dtype)], dim=1)
-                    padded_embeds.append(embeds)
-                    padded_masks.append(mask)
-
-                prompt_embeds = torch.cat(padded_embeds, dim=0)
-                prompt_embeds_mask = torch.cat(padded_masks, dim=0)
-            else:
-                # Single batch or batch_size == 1
-                prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
 
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -612,10 +584,10 @@ def __call__(
             numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
             or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
             list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
-                latents as `image`, but if passing latents directly it is not encoded again.
-                For batch processing with multiple prompts (batch_size > 1), provide a nested list where each sublist
-                contains the input images for that prompt: `[[img1_for_prompt1], [img2_for_prompt2]]`. For a single
-                prompt with multiple reference images (batch_size == 1), use a flat list: `[img1, img2]`.
+                latents as `image`, but if passing latents directly it is not encoded again. For batch processing with
+                multiple prompts (batch_size > 1), provide a nested list where each sublist contains the input images
+                for that prompt: `[[img1_for_prompt1], [img2_for_prompt2]]`. For a single prompt with multiple
+                reference images (batch_size == 1), use a flat list: `[img1, img2]`.
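+                For example, `image=[[img_a], [img_b]]` with `prompt=["first edit", "second edit"]` runs two
+                independent edits in one call.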
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.

From c595c6973bd49170b88c3eddbe414fac26853dbd Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 13:26:30 +0100
Subject: [PATCH 6/7] Update src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py

Co-authored-by: YiYi Xu
---
 .../pipelines/qwenimage/pipeline_qwenimage_edit_plus.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index 99983092df74..aaf67ece2777 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -722,6 +722,8 @@ def __call__(
             is_nested = isinstance(image[0], list)
 
             if is_nested:
+                if batch_size > 1 and len(image) != batch_size:
+                    raise ValueError( f"Image batch_size ({len(image)}) must match batch_size for prompts ({batch_size}) for batch inference."")
                 # batch_size > 1: image = [[img1, img2], [img3, img4]]
                 # Process each batch item separately
                 condition_image_sizes = []

From 8faef28c42bae54bd4a32082eb7cdd2770deefc3 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 12 Jan 2026 12:42:06 +0000
Subject: [PATCH 7/7] fix tests

---
 .../pipelines/qwenimage/pipeline_qwenimage_edit_plus.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index aaf67ece2777..f3c8910ded78 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -723,7 +723,9 @@ def __call__(
 
             if is_nested:
                 if batch_size > 1 and len(image) != batch_size:
-                    raise ValueError( f"Image batch_size ({len(image)}) must match batch_size for prompts ({batch_size}) for batch inference."")
+                    raise ValueError(
+                        f"Image batch_size ({len(image)}) must match batch_size for prompts ({batch_size}) for batch inference."
+                    )
                 # batch_size > 1: image = [[img1, img2], [img3, img4]]
                 # Process each batch item separately
                 condition_image_sizes = []