diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py
index b7b3aa391ce4..e863b9714fa0 100644
--- a/src/diffusers/models/transformers/transformer_glm_image.py
+++ b/src/diffusers/models/transformers/transformer_glm_image.py
@@ -143,41 +143,98 @@ def forward(
 class GlmImageLayerKVCache:
-    """KV cache for GlmImage model."""
+    """KV cache for GlmImage model.
+
+    Supports per-sample caching for batch processing where each sample
+    may have different condition images.
+    """
 
     def __init__(self):
-        self.k_cache = None
-        self.v_cache = None
+        # Parallel lists of per-sample key/value caches, one entry per batch sample
+        self.k_caches: List[torch.Tensor] = []
+        self.v_caches: List[torch.Tensor] = []
         self.mode: Optional[str] = None  # "write", "read", "skip"
+        self.current_sample_idx: int = 0  # Current sample index for writing
 
     def store(self, k: torch.Tensor, v: torch.Tensor):
-        if self.k_cache is None:
-            self.k_cache = k
-            self.v_cache = v
+        """Store KV cache for the current sample."""
+        # k, v shape: (1, seq_len, num_heads, head_dim)
+        if len(self.k_caches) <= self.current_sample_idx:
+            # First time storing for this sample
+            self.k_caches.append(k)
+            self.v_caches.append(v)
         else:
-            self.k_cache = torch.cat([self.k_cache, k], dim=1)
-            self.v_cache = torch.cat([self.v_cache, v], dim=1)
+            # Append to the existing cache for this sample (multiple condition images)
+            self.k_caches[self.current_sample_idx] = torch.cat(
+                [self.k_caches[self.current_sample_idx], k], dim=1
+            )
+            self.v_caches[self.current_sample_idx] = torch.cat(
+                [self.v_caches[self.current_sample_idx], v], dim=1
+            )
 
     def get(self, k: torch.Tensor, v: torch.Tensor):
-        if self.k_cache.shape[0] != k.shape[0]:
-            k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
-            v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
+        """Get the combined KV cache for all samples in the batch.
+
+        Args:
+            k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
+            v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)
+
+        Returns:
+            Combined key and value tensors with cached values prepended.
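+
+        Example (illustrative sketch only; shapes depend on the attention config):
+
+            cache = GlmImageLayerKVCache()
+            cache.store(k_cond, v_cond)  # k_cond: (1, 16, num_heads, head_dim)
+            k, v = cache.get(k_cur, v_cur)  # k_cur: (2, 64, num_heads, head_dim)
+            # k, v: (2, 80, num_heads, head_dim); the cached entries are expanded
+            # across the batch and prepended along the sequence dimension (dim=1)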
+        """
+        batch_size = k.shape[0]
+        num_cached_samples = len(self.k_caches)
+
+        if num_cached_samples == 0:
+            return k, v
+
+        if num_cached_samples == 1:
+            # Single cache, expand for all batch samples (shared condition images)
+            k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
+            v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
+        elif num_cached_samples == batch_size:
+            # Per-sample cache, concatenate along batch dimension
+            k_cache_expanded = torch.cat(self.k_caches, dim=0)
+            v_cache_expanded = torch.cat(self.v_caches, dim=0)
         else:
-            k_cache_expanded = self.k_cache
-            v_cache_expanded = self.v_cache
+            # Mismatch: try to handle by repeating the caches
+            # This handles cases like num_images_per_prompt > 1
+            repeat_factor = batch_size // num_cached_samples
+            if batch_size % num_cached_samples == 0:
+                k_cache_list = []
+                v_cache_list = []
+                for i in range(num_cached_samples):
+                    k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1))
+                    v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1))
+                k_cache_expanded = torch.cat(k_cache_list, dim=0)
+                v_cache_expanded = torch.cat(v_cache_list, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
+                    f"Batch size must be a multiple of the number of cached samples."
+                )
 
-        k_cache = torch.cat([k_cache_expanded, k], dim=1)
-        v_cache = torch.cat([v_cache_expanded, v], dim=1)
-        return k_cache, v_cache
+        k_combined = torch.cat([k_cache_expanded, k], dim=1)
+        v_combined = torch.cat([v_cache_expanded, v], dim=1)
+        return k_combined, v_combined
 
     def clear(self):
-        self.k_cache = None
-        self.v_cache = None
+        self.k_caches = []
+        self.v_caches = []
         self.mode = None
+        self.current_sample_idx = 0
+
+    def next_sample(self):
+        """Move to the next sample for writing."""
+        self.current_sample_idx += 1
 
 
 class GlmImageKVCache:
-    """Container for all layers' KV caches."""
+    """Container for all layers' KV caches.
+
+    Supports per-sample caching for batch processing where each sample
+    may have different condition images.
+    """
 
     def __init__(self, num_layers: int):
         self.num_layers = num_layers
@@ -192,6 +249,12 @@ def set_mode(self, mode: Optional[str]):
         for cache in self.caches:
             cache.mode = mode
 
+    def next_sample(self):
+        """Move to the next sample for writing. Call this after processing
+        all condition images for one batch sample."""
+        for cache in self.caches:
+            cache.next_sample()
+
     def clear(self):
         for cache in self.caches:
             cache.clear()
diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
index 5499b8769fa6..997299fbb488 100644
--- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
+++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py
@@ -260,25 +260,118 @@ def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) ->
         token_ids = token_ids.view(1, -1)
         return token_ids
 
+    @staticmethod
+    def _validate_and_normalize_images(
+        image: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]],
+        batch_size: int,
+    ) -> Optional[List[List[PIL.Image.Image]]]:
+        """
+        Validate and normalize image inputs to List[List[PIL.Image]].
+
+        Rules:
+        - batch_size > 1: only accepts List[List[PIL.Image]]; each sublist must have equal length
+        - batch_size == 1: accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]])
+        - Other formats raise ValueError
+
+        Args:
+            image: Input images in various formats
+            batch_size: Number of prompts in the batch
+
+        Returns:
+            Normalized images as List[List[PIL.Image]], or None if no images provided.
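+
+        Example (illustrative):
+            # batch_size == 1, a legacy flat list is wrapped:
+            #   [img_a, img_b]     -> [[img_a, img_b]]
+            # batch_size == 2, nested lists are required:
+            #   [[img_a], [img_b]] -> [[img_a], [img_b]]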
" + f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images." + ) + + return [list(imgs) for imgs in image] + def generate_prior_tokens( self, - prompt: str, + prompt: Union[str, List[str]], height: int, width: int, - image: Optional[List[PIL.Image.Image]] = None, + image: Optional[List[List[PIL.Image.Image]]] = None, device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, ): + """ + Generate prior tokens for the DiT model using the AR model. + + Args: + prompt: Single prompt or list of prompts + height: Target image height + width: Target image width + image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated + using _validate_and_normalize_images() before calling this method. + device: Target device + generator: Random generator for reproducibility + + Returns: + Tuple of: + - prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens + - prior_token_image_ids: Tensor with upsampled source image tokens (or None for t2i) + - source_image_grid_thw: Upsampled grid info, shape (batch_size * num_condition_images, 3) + - num_condition_images: Number of condition images per sample (0 for t2i) + """ device = device or self._execution_device - is_text_to_image = image is None or len(image) == 0 - content = [] - if image is not None: - for img in image: - content.append({"type": "image", "image": img}) - content.append({"type": "text", "text": prompt}) - messages = [{"role": "user", "content": content}] + + # Normalize prompt to list format + prompt_list = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt_list) + + # Image should already be normalized by caller, but handle None case + is_text_to_image = image is None or len(image) == 0 or all(len(imgs) == 0 for imgs in image) + + # Build messages for each sample in the batch + all_messages = [] + for idx, p in enumerate(prompt_list): + content = [] + if not is_text_to_image and image is not None and idx < len(image): + for img in image[idx]: + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": p}) + all_messages.append([{"role": "user", "content": content}]) + + # Process with the processor (supports batch with left padding) inputs = self.processor.apply_chat_template( - messages, + all_messages, tokenize=True, + padding=True if batch_size > 1 else False, target_h=height, target_w=width, return_dict=True, @@ -286,44 +379,106 @@ def generate_prior_tokens( ).to(device) image_grid_thw = inputs.get("image_grid_thw") + images_per_sample = inputs.get("images_per_sample") + + # Determine number of grids per sample + # For homogeneous batch, all samples have the same structure + num_condition_images = len(image[0]) if image is not None and len(image) > 0 else 0 + if images_per_sample is not None: + num_grids_per_sample = images_per_sample[0].item() + else: + # Fallback for batch_size=1: total grids is for single sample + num_grids_per_sample = image_grid_thw.shape[0] + + # Compute generation params (same for all samples in homogeneous batch) + first_sample_grids = image_grid_thw[:num_grids_per_sample] max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( - image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image + image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image ) + # Generate source image tokens (prior_token_image_ids) for i2i mode prior_token_image_ids = None - if image is not None: - prior_token_image_embed = 
+        """
         device = device or self._execution_device
-        is_text_to_image = image is None or len(image) == 0
 
-        content = []
-        if image is not None:
-            for img in image:
-                content.append({"type": "image", "image": img})
-        content.append({"type": "text", "text": prompt})
-        messages = [{"role": "user", "content": content}]
+        # Normalize prompt to list format
+        prompt_list = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt_list)
+
+        # Image should already be normalized by the caller, but handle the None case
+        is_text_to_image = image is None or len(image) == 0 or all(len(imgs) == 0 for imgs in image)
+
+        # Build messages for each sample in the batch
+        all_messages = []
+        for idx, p in enumerate(prompt_list):
+            content = []
+            if not is_text_to_image and image is not None and idx < len(image):
+                for img in image[idx]:
+                    content.append({"type": "image", "image": img})
+            content.append({"type": "text", "text": p})
+            all_messages.append([{"role": "user", "content": content}])
+
+        # Process with the processor (supports batching with left padding)
         inputs = self.processor.apply_chat_template(
-            messages,
+            all_messages,
             tokenize=True,
+            padding=batch_size > 1,
             target_h=height,
             target_w=width,
             return_dict=True,
@@ -286,44 +379,106 @@ def generate_prior_tokens(
         ).to(device)
 
         image_grid_thw = inputs.get("image_grid_thw")
+        images_per_sample = inputs.get("images_per_sample")
+
+        # Determine the number of grids per sample.
+        # For a homogeneous batch, all samples have the same structure.
+        num_condition_images = len(image[0]) if image is not None and len(image) > 0 else 0
+        if images_per_sample is not None:
+            num_grids_per_sample = images_per_sample[0].item()
+        else:
+            # Fallback for batch_size=1: all grids belong to the single sample
+            num_grids_per_sample = image_grid_thw.shape[0]
+
+        # Compute generation params (identical for all samples in a homogeneous batch)
+        first_sample_grids = image_grid_thw[:num_grids_per_sample]
         max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
-            image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
+            image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image
         )
 
+        # Generate source image tokens (prior_token_image_ids) for i2i mode
        prior_token_image_ids = None
-        if image is not None:
-            prior_token_image_embed = self.vision_language_encoder.get_image_features(
-                inputs["pixel_values"], image_grid_thw[:-1]
-            )
-            prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
-            prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
-                prior_token_image_embed, image_grid_thw[:-1]
-            )
-
-        # For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
-        # max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
+        source_image_grid_thw = None
+        if not is_text_to_image and "pixel_values" in inputs and num_condition_images > 0:
+            # Extract source grids by selecting condition image indices (skip target grids).
+            # Grid order from the processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...]
+            # We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...]
+            source_indices = []
+            for sample_idx in range(batch_size):
+                base = sample_idx * num_grids_per_sample
+                source_indices.extend(range(base, base + num_condition_images))
+            source_grids = image_grid_thw[source_indices]
+
+            if len(source_grids) > 0:
+                prior_token_image_embed = self.vision_language_encoder.get_image_features(
+                    inputs["pixel_values"], source_grids
+                )
+                prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
+                prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens(
+                    prior_token_image_embed, source_grids
+                )
+
+                # Upsample each source image's prior tokens to match the VAE/DiT resolution
+                split_sizes = source_grids.prod(dim=-1).tolist()
+                prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes)
+                upsampled_prior_ids = []
+                for i, prior_ids in enumerate(prior_ids_per_source):
+                    t, h, w = source_grids[i].tolist()
+                    upsampled = self._upsample_token_ids(prior_ids, int(h), int(w))
+                    upsampled_prior_ids.append(upsampled.squeeze(0))
+                prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0)
+
+                # Upsample the grid dimensions for later splitting
+                upsampled_grids = source_grids.clone()
+                upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2
+                upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2
+                source_image_grid_thw = upsampled_grids
+
+        # Generate with the AR model.
+        # Seed the torch RNG from the generator for reproducibility
+        # (transformers' generate() doesn't accept a generator parameter).
+        if generator is not None:
+            seed = generator.initial_seed()
+            torch.manual_seed(seed)
+            if device is not None and device.type == "cuda":
+                torch.cuda.manual_seed(seed)
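+        # Note: torch.manual_seed() above reseeds the global RNG, which also affects
+        # later global draws; wrapping the seeding and the generate() call in
+        # torch.random.fork_rng() would isolate it if that ever becomes a problem.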
         outputs = self.vision_language_encoder.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
             do_sample=True,
         )
 
-        prior_token_ids_d32 = self._extract_large_image_tokens(
-            outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
-        )
-        prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
+        # Extract and upsample prior tokens for each sample.
+        # For left-padded inputs, generated tokens start after the padded input sequence.
+        all_prior_token_ids = []
+        max_input_length = inputs["input_ids"].shape[-1]
+        for idx in range(batch_size):
+            # For left-padded sequences, generated tokens start at max_input_length
+            # (padding is on the left, so all sequences end at the same position)
+            prior_token_ids_d32 = self._extract_large_image_tokens(
+                outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w
+            )
+            prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
+            all_prior_token_ids.append(prior_token_ids)
+
+        prior_token_ids = torch.cat(all_prior_token_ids, dim=0)
 
-        return prior_token_ids, prior_token_image_ids
+        return prior_token_ids, prior_token_image_ids, source_image_grid_thw, num_condition_images
 
     def get_glyph_texts(self, prompt):
-        prompt = prompt[0] if isinstance(prompt, list) else prompt
-        ocr_texts = (
-            re.findall(r"'([^']*)'", prompt)
-            + re.findall(r"“([^“”]*)”", prompt)
-            + re.findall(r'"([^"]*)"', prompt)
-            + re.findall(r"「([^「」]*)」", prompt)
-        )
-        return ocr_texts
+        """Extract glyph texts from the prompt(s). Returns a list of lists for batch processing."""
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        all_ocr_texts = []
+        for p in prompt:
+            ocr_texts = (
+                re.findall(r"'([^']*)'", p)
+                + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p)
+                + re.findall(r'"([^"]*)"', p)
+                + re.findall(r"「([^「」]*)」", p)
+            )
+            all_ocr_texts.append(ocr_texts)
+        return all_ocr_texts
 
     def _get_glyph_embeds(
         self,
@@ -332,29 +487,50 @@ def _get_glyph_embeds(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
+        """Get glyph embeddings for each prompt in the batch."""
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
 
-        glyph_texts = self.get_glyph_texts(prompt)
-        input_ids = self.tokenizer(
-            glyph_texts if len(glyph_texts) > 0 else [""],
-            max_length=max_sequence_length,
-            truncation=True,
-        ).input_ids
-        input_ids = [
-            [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
-        ]
-        max_length = max(len(input_ids_) for input_ids_ in input_ids)
-        attention_mask = torch.tensor(
-            [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
-        )
-        input_ids = torch.tensor(
-            [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
-            device=device,
-        )
-        outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
-        glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
+        # get_glyph_texts returns a list of lists (one per prompt)
+        all_glyph_texts = self.get_glyph_texts(prompt)
+
+        all_glyph_embeds = []
+        for glyph_texts in all_glyph_texts:
+            if len(glyph_texts) == 0:
+                glyph_texts = [""]
+            input_ids = self.tokenizer(
+                glyph_texts,
+                max_length=max_sequence_length,
+                truncation=True,
+            ).input_ids
+            input_ids = [
+                [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
+            ]
+            max_length = max(len(input_ids_) for input_ids_ in input_ids)
+            attention_mask = torch.tensor(
+                [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
+                device=device,
+            )
+            input_ids = torch.tensor(
+                [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
+                device=device,
+            )
+            outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
+            glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
+            all_glyph_embeds.append(glyph_embeds)
+
+        # Pad to the same sequence length and concatenate (left padding, to match transformers)
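+        # e.g. per-prompt glyph embeds of lengths [5, 9] are left-padded with zeros
+        # to length 9, so torch.cat along dim=0 yields a (batch, 9, hidden) tensor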
+        max_seq_len = max(emb.size(1) for emb in all_glyph_embeds)
+        padded_embeds = []
+        for emb in all_glyph_embeds:
+            if emb.size(1) < max_seq_len:
+                pad = torch.zeros(
+                    emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype
+                )
+                emb = torch.cat([pad, emb], dim=1)  # left padding
+            padded_embeds.append(emb)
+        glyph_embeds = torch.cat(padded_embeds, dim=0)
 
         return glyph_embeds.to(device=device, dtype=dtype)
 
     def encode_prompt(
@@ -399,9 +575,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
 
-        seq_len = prompt_embeds.size(1)
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        # Repeat embeddings for num_images_per_prompt
+        if num_images_per_prompt > 1:
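+            # repeat_interleave keeps prompts grouped, e.g. [p0, p0, p1, p1] for
+            # num_images_per_prompt=2, matching the latent batch layout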
+            prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
 
         # For GLM-Image, negative_prompt must be "" instead of None
         if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -409,9 +585,8 @@ def encode_prompt(
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
             negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
 
-            seq_len = negative_prompt_embeds.size(1)
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            if num_images_per_prompt > 1:
+                negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
 
         return prompt_embeds, negative_prompt_embeds
 
@@ -611,34 +786,46 @@ def __call__(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
-        if batch_size != 1:
-            raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
 
         device = self._execution_device
 
-        # 2. Preprocess image tokens and prompt tokens
+        # 2. Validate and normalize the image format
+        normalized_image = self._validate_and_normalize_images(image, batch_size)
+
+        # 3. Generate prior tokens (batch mode)
+        num_condition_images = 0
+        # Get a single generator for the AR model (use the first if a list was provided)
+        ar_generator = generator[0] if isinstance(generator, list) else generator
         if prior_token_ids is None:
-            prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
-                prompt=prompt[0] if isinstance(prompt, list) else prompt,
-                image=image,
+            prior_token_ids, prior_token_image_ids, source_image_grid_thw, num_condition_images = self.generate_prior_tokens(
+                prompt=prompt,
+                image=normalized_image,
                 height=height,
                 width=width,
                 device=device,
+                generator=ar_generator,
             )
-
-        # 3. Preprocess image
-        if image is not None:
-            preprocessed_condition_images = []
-            for img in image:
-                image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
-                multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
-                image_height = (image_height // multiple_of) * multiple_of
-                image_width = (image_width // multiple_of) * multiple_of
-                img = self.image_processor.preprocess(img, height=image_height, width=image_width)
-                preprocessed_condition_images.append(img)
-            height = height or image_height
-            width = width or image_width
-            image = preprocessed_condition_images
+        else:
+            # The user provided prior_token_ids directly
+            prior_token_image_ids = None
+            source_image_grid_thw = None
+
+        # 4. Preprocess images for VAE encoding
+        preprocessed_images = None
+        if normalized_image is not None:
+            preprocessed_images = []
+            for prompt_images in normalized_image:
+                prompt_preprocessed = []
+                for img in prompt_images:
+                    image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
+                    multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
+                    image_height = (image_height // multiple_of) * multiple_of
+                    image_width = (image_width // multiple_of) * multiple_of
+                    img = self.image_processor.preprocess(img, height=image_height, width=image_width)
+                    prompt_preprocessed.append(img)
+                height = height or image_height
+                width = width or image_width
+                preprocessed_images.append(prompt_preprocessed)
 
         # 5. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -652,7 +839,7 @@ def __call__(
             dtype=self.dtype,
         )
 
-        # 4. Prepare latents and (optional) image kv cache
+        # 6. Prepare latents and (optional) image kv cache
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size=batch_size * num_images_per_prompt,
@@ -666,7 +853,7 @@ def __call__(
         )
 
         kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
-        if image is not None:
+        if preprocessed_images is not None and prior_token_image_ids is not None:
             kv_caches.set_mode("write")
             latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
             latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
@@ -674,29 +861,54 @@ def __call__(
             latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
             latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
 
-            for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
-                condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
-                condition_latent = retrieve_latents(
-                    self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
-                )
-                condition_latent = (condition_latent - latents_mean) / latents_std
-
-                # Do not remove.
-                # It would be use to run the reference image through a
-                # forward pass at timestep 0 and keep the KV cache.
-                _ = self.transformer(
-                    hidden_states=condition_latent,
-                    encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
-                    prior_token_id=condition_image_prior_token_id,
-                    prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
-                    timestep=torch.zeros((1,), device=device),
-                    target_size=torch.tensor([condition_image.shape[-2:]], device=device),
-                    crop_coords=torch.zeros((1, 2), device=device),
-                    attention_kwargs=attention_kwargs,
-                    kv_caches=kv_caches,
-                )
+            # For a homogeneous batch: split the grids and prior_token_image_ids by sample.
+            # source_image_grid_thw order: [s0_c1, s0_c2, ..., s1_c1, s1_c2, ...]
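+            # e.g. batch_size=2 with num_condition_images=2 and 16-token upsampled
+            # grids gives tokens_per_sample = [32, 32], re-split per image below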
+            # Split into per-sample chunks of num_condition_images each
+            grids_per_sample = list(torch.split(source_image_grid_thw, num_condition_images))
+
+            # Calculate tokens per sample (may vary if condition images have different sizes)
+            tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist()
+            tokens_per_sample = []
+            for i in range(batch_size):
+                start_idx = i * num_condition_images
+                end_idx = start_idx + num_condition_images
+                tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx]))
+            prior_ids_per_sample = torch.split(prior_token_image_ids, tokens_per_sample)
+
+            # Process each sample's condition images
+            for prompt_idx in range(batch_size):
+                prompt_images = preprocessed_images[prompt_idx]
+                prompt_prior_ids = prior_ids_per_sample[prompt_idx]
+                prompt_grid_thw = grids_per_sample[prompt_idx]
+
+                # Split this sample's prior_token_image_ids by each image's token count
+                split_sizes = prompt_grid_thw.prod(dim=-1).tolist()
+                prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes)
+
+                # Process each condition image for this sample
+                for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image):
+                    condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
+                    condition_latent = retrieve_latents(
+                        self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
+                    )
+                    condition_latent = (condition_latent - latents_mean) / latents_std
+
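+                    # Do not remove: this runs the reference image through a
+                    # forward pass at timestep 0 so every layer stores its KV cache.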
+                    _ = self.transformer(
+                        hidden_states=condition_latent,
+                        encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
+                        prior_token_id=condition_image_prior_token_id,
+                        prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
+                        timestep=torch.zeros((1,), device=device),
+                        target_size=torch.tensor([condition_image.shape[-2:]], device=device),
+                        crop_coords=torch.zeros((1, 2), device=device),
+                        attention_kwargs=attention_kwargs,
+                        kv_caches=kv_caches,
+                    )
+
+                # Move to the next sample's cache slot
+                kv_caches.next_sample()
 
-        # 6. Prepare additional timestep conditions
+        # 7. Prepare additional timestep conditions
         target_size = (height, width)
         target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
         crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
@@ -726,10 +938,13 @@ def __call__(
         )
         self._num_timesteps = len(timesteps)
 
-        # 7. Denoising loop
+        # 8. Denoising loop
         transformer_dtype = self.transformer.dtype
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
+        # Repeat prior_token_ids for num_images_per_prompt
+        if num_images_per_prompt > 1:
+            prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0)
         prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
         prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -742,7 +957,7 @@ def __call__(
 
                 timestep = t.expand(latents.shape[0]) - 1
 
-                if image is not None:
+                if normalized_image is not None:
                     kv_caches.set_mode("read")
 
                 noise_pred_cond = self.transformer(
@@ -760,7 +975,7 @@ def __call__(
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
-                    if image is not None:
+                    if normalized_image is not None:
                         kv_caches.set_mode("skip")
                     noise_pred_uncond = self.transformer(
                         hidden_states=latent_model_input,
diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py
index 7a380b99b0fb..d907d082d275 100644
--- a/tests/pipelines/glm_image/test_glm_image.py
+++ b/tests/pipelines/glm_image/test_glm_image.py
@@ -169,7 +169,7 @@ def test_inference(self):
         # fmt: off
         expected_slice = np.array(
             [
-                0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
+                0.5849247, 0.50278825, 0.45747858, 0.45895284, 0.43804976, 0.47044256, 0.5239665, 0.47904694, 0.3323419, 0.38725388, 0.28505728, 0.3161863, 0.35026982, 0.37546024, 0.4090118, 0.46629113
             ]
         )
         # fmt: on
@@ -177,20 +177,109 @@ def test_inference(self):
         self.assertEqual(image.shape, (3, 32, 32))
         self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))
 
-    @unittest.skip("Not supported.")
     def test_inference_batch_single_identical(self):
-        # GLM-Image has batch_size=1 constraint due to AR model
-        pass
+        """Test that batch=1 produces consistent results with the same seed."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
 
-    @unittest.skip("Not supported.")
-    def test_inference_batch_consistent(self):
-        # GLM-Image has batch_size=1 constraint due to AR model
-        pass
+        # Run twice with the same seed
+        inputs1 = self.get_dummy_inputs(device, seed=42)
+        inputs2 = self.get_dummy_inputs(device, seed=42)
+
+        image1 = pipe(**inputs1).images[0]
+        image2 = pipe(**inputs2).images[0]
+
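+        # Identical seeds must yield identical AR prior tokens and latents,
+        # so the decoded images should match elementwise.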
+        self.assertTrue(torch.allclose(image1, image2, atol=1e-4))
+
+    def test_inference_batch_multiple_prompts(self):
+        """Test batch processing with multiple prompts."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(42)
+        height, width = 32, 32
+
+        inputs = {
+            "prompt": ["A photo of a cat", "A photo of a dog"],
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 2 images
+        self.assertEqual(len(images), 2)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+        self.assertEqual(images[1].shape, (3, 32, 32))
 
-    @unittest.skip("Not supported.")
     def test_num_images_per_prompt(self):
-        # GLM-Image has batch_size=1 constraint due to AR model
-        pass
+        """Test generating multiple images per prompt."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(42)
+        height, width = 32, 32
+
+        inputs = {
+            "prompt": "A photo of a cat",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+            "num_images_per_prompt": 2,
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 2 images for a single prompt
+        self.assertEqual(len(images), 2)
+        self.assertEqual(images[0].shape, (3, 32, 32))
+        self.assertEqual(images[1].shape, (3, 32, 32))
+
+    def test_batch_with_num_images_per_prompt(self):
+        """Test batch prompts with num_images_per_prompt > 1."""
+        device = "cpu"
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(42)
+        height, width = 32, 32
+
+        inputs = {
+            "prompt": ["A photo of a cat", "A photo of a dog"],
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 1.5,
+            "height": height,
+            "width": width,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+            "num_images_per_prompt": 2,
+        }
+
+        images = pipe(**inputs).images
+
+        # Should return 4 images (2 prompts × 2 images per prompt)
+        self.assertEqual(len(images), 4)
 
     @unittest.skip("Needs to be revisited.")
     def test_encode_prompt_works_in_isolation(self):