@CalamitousFelicitousness CalamitousFelicitousness commented Jan 21, 2026

What does this PR do?

This PR adds an inpainting pipeline for Z-Image. The changes are summarized below:

  • Implemented the ZImageInpaintPipeline class for mask-based inpainting
  • Updated the pipeline structure to include ZImageInpaintPipeline alongside ZImagePipeline and ZImageImg2ImgPipeline
  • Mapped the new ZImageInpaintPipeline in AUTO_INPAINT_PIPELINES_MAPPING
  • Added unit tests for ZImageInpaintPipeline with torch.empty fix for test stability
  • Updated dummy objects to include ZImageInpaintPipeline
  • Added documentation with usage example

Closes issue #12752

Tested using a simple script:

Testing script
  #!/usr/bin/env python
  """Test script for ZImage inpaint support."""

  import sys
  sys.path.insert(0, '/home/ohiom/diffusers/src')

  import torch
  import numpy as np
  from PIL import Image
  from diffusers import ZImageInpaintPipeline

  # Paths
  MODEL_PATH = "database/models/huggingface/models--Tongyi-MAI--Z-Image-Turbo/snapshots/78771b7e11b922c868dd766476bda1f4fc6bfc96"
  INPUT_IMAGE_PATH = "death_remix_1024.png"

  print("Loading ZImageInpaintPipeline...")
  pipe = ZImageInpaintPipeline.from_pretrained(
      MODEL_PATH,
      torch_dtype=torch.bfloat16,
      local_files_only=True,
  )
  pipe.to("cuda")
  print("Pipeline loaded.")

  # Load input image
  print(f"\nLoading input image from {INPUT_IMAGE_PATH}...")
  input_image = Image.open(INPUT_IMAGE_PATH).convert("RGB")
  print(f"Input image size: {input_image.size}")

  # Create a mask (white = inpaint, black = preserve)
  width, height = input_image.size
  mask = np.zeros((height, width), dtype=np.uint8)
  h_start, h_end = height // 4, 3 * height // 4
  w_start, w_end = width // 4, 3 * width // 4
  mask[h_start:h_end, w_start:w_end] = 255
  mask_image = Image.fromarray(mask)

  # Generate an inpainted image
  prompt = "a woman with pale skin in a black shirt, oil painting style"
  strength = 0.75

  print(f"\nGenerating inpainted image with prompt: {prompt}")
  print(f"Strength: {strength}")

  image = pipe(
      prompt=prompt,
      image=input_image,
      mask_image=mask_image,
      strength=strength,
      num_inference_steps=8,
      guidance_scale=1.0,
      generator=torch.Generator(device="cuda").manual_seed(42),
  ).images[0]

  output_path = "test_zimage_inpaint_output.png"
  image.save(output_path)
  print(f"\nImage saved to {output_path}")
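
For reference, the mask convention used in the script (white = inpaint, black = preserve) can be illustrated with a small pixel-space blend. This is only a sketch of the convention: `blend_with_mask` is a hypothetical helper, and the pipeline itself is expected to do its blending on latents rather than on pixels.

```python
import numpy as np


def blend_with_mask(original, generated, mask):
    """Keep `original` where the mask is 0 (black), take `generated` where it is 255 (white)."""
    # Scale the mask to [0, 1] and add a channel axis so it broadcasts over RGB.
    weight = (mask.astype(np.float32) / 255.0)[..., None]
    blended = weight * generated + (1.0 - weight) * original
    return blended.astype(original.dtype)


# Tiny stand-ins for the input image and the model output.
height, width = 8, 8
original = np.full((height, width, 3), 10, dtype=np.uint8)
generated = np.full((height, width, 3), 200, dtype=np.uint8)

# Same centered-square mask as in the testing script.
mask = np.zeros((height, width), dtype=np.uint8)
mask[height // 4 : 3 * height // 4, width // 4 : 3 * width // 4] = 255

result = blend_with_mask(original, generated, mask)
```

Pixels outside the white square keep their original values, and pixels inside it come from the generated image.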

LoRA functionality is also supported (inherited from ZImageLoraLoaderMixin).

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu @asomoza @sayakpaul

- Updated the pipeline structure to include ZImageInpaintPipeline alongside ZImagePipeline and ZImageImg2ImgPipeline.
- Implemented the ZImageInpaintPipeline class for inpainting tasks, including necessary methods for encoding prompts, preparing masked latents, and denoising.
- Enhanced the auto_pipeline to map the new ZImageInpaintPipeline for inpainting generation tasks.
- Added unit tests for ZImageInpaintPipeline to ensure functionality and performance.
- Updated dummy objects to include ZImageInpaintPipeline for testing purposes.
- Add torch.empty fix for x_pad_token and cap_pad_token in test
- Add # Copied from annotations for encode_prompt methods
- Add documentation with usage example and autodoc directive

Copilot AI left a comment

Pull request overview

This PR adds a comprehensive inpainting pipeline for Z-Image, extending the existing Z-Image family of pipelines (text-to-image, img2img, controlnet) with mask-based inpainting capabilities. The implementation follows established patterns from other inpainting pipelines in the diffusers library while adapting to Z-Image's specific requirements (e.g., complex64 RoPE embeddings, flow matching scheduler).

Changes:

  • Implemented ZImageInpaintPipeline with full inpainting support including mask blending and strength-based denoising control
  • Added comprehensive test suite covering inference, batch processing, strength validation, mask functionality, VAE tiling, and device offloading
  • Integrated the pipeline into auto_pipeline infrastructure with proper model mapping for "z-image" model type
  • Updated all necessary __init__.py files and dummy objects for proper exports
  • Added documentation with usage examples
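
The strength-based denoising control mentioned above is usually implemented in diffusers img2img and inpaint pipelines by skipping the early part of the schedule. The following is a minimal sketch of that arithmetic, assuming ZImageInpaintPipeline follows the same convention (this is not confirmed by the summary here); `steps_after_strength` is a hypothetical helper name.

```python
from typing import Tuple


def steps_after_strength(num_inference_steps: int, strength: float) -> Tuple[int, int]:
    """Return (index of first step to run, number of steps actually run)."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength}")
    # Number of steps' worth of noise applied to the input image.
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    # Skip the steps that correspond to noise levels above the starting point.
    t_start = int(max(num_inference_steps - init_timestep, 0))
    return t_start, num_inference_steps - t_start


# Values taken from the PR's testing script: 8 steps at strength 0.75.
print(steps_after_strength(8, 0.75))  # (2, 6)
```

Under this convention, strength=1.0 runs the full schedule and ignores the input image content in the masked region, while lower values keep more of it.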

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

File | Description
--- | ---
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py | New inpainting pipeline implementation with prepare_mask_latents, prepare_latents, and main __call__ method for mask-based image inpainting
tests/pipelines/z_image/test_z_image_inpaint.py | Comprehensive test suite including inference tests, strength parameter validation, mask functionality tests, and compatibility tests
src/diffusers/pipelines/z_image/__init__.py | Added ZImageInpaintPipeline to module exports
src/diffusers/pipelines/__init__.py | Added ZImageInpaintPipeline to main pipelines module exports
src/diffusers/__init__.py | Added ZImageInpaintPipeline to top-level diffusers exports
src/diffusers/pipelines/auto_pipeline.py | Mapped ZImageInpaintPipeline to "z-image" in AUTO_INPAINT_PIPELINES_MAPPING
src/diffusers/utils/dummy_torch_and_transformers_objects.py | Added dummy ZImageInpaintPipeline class for when dependencies are not available
docs/source/en/api/pipelines/z_image.md | Added inpainting section with usage example and API reference

Comment on lines +356 to +358
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
    masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

Copilot AI Jan 21, 2026

The prepare_mask_latents method is missing validation checks for batch size divisibility. When expanding masks or masked_image_latents to match the batch size, the code directly performs division without checking if the batch size is divisible by the input tensor's batch dimension. This could lead to silent errors or incorrect tensor expansion.

Reference implementations like StableDiffusionInpaintPipeline include explicit checks:

if not batch_size % mask.shape[0] == 0:
    raise ValueError(...)

These checks should be added before lines 356 and 358 to ensure proper batch handling and provide clear error messages when batch sizes don't match.

Suggested change
- mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
- if masked_image_latents.shape[0] < batch_size:
-     masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+ if batch_size % mask.shape[0] != 0:
+     raise ValueError(
+         f"Cannot expand mask batch of size {mask.shape[0]} to {batch_size}. "
+         "The batch size must be divisible by mask.shape[0]."
+     )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+     if batch_size % masked_image_latents.shape[0] != 0:
+         raise ValueError(
+             "Cannot expand masked_image_latents batch of size "
+             f"{masked_image_latents.shape[0]} to {batch_size}. "
+             "The batch size must be divisible by masked_image_latents.shape[0]."
+         )
+     masked_image_latents = masked_image_latents.repeat(
+         batch_size // masked_image_latents.shape[0], 1, 1, 1
+     )
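
The same check can be exercised in isolation. Below is a minimal sketch that folds the validation and the repeat into one function; `expand_to_batch` is a hypothetical helper and not part of the PR or of diffusers.

```python
import torch


def expand_to_batch(tensor: torch.Tensor, batch_size: int, name: str) -> torch.Tensor:
    """Repeat `tensor` along dim 0 to reach `batch_size`, validating divisibility first."""
    current = tensor.shape[0]
    if current < batch_size:
        if batch_size % current != 0:
            raise ValueError(
                f"Cannot expand {name} batch of size {current} to {batch_size}. "
                f"The batch size must be divisible by {name}.shape[0]."
            )
        # Repeat only along the batch dimension; leave the other dims untouched.
        repeats = [batch_size // current] + [1] * (tensor.dim() - 1)
        tensor = tensor.repeat(*repeats)
    return tensor


mask = torch.ones(1, 1, 4, 4)
expanded = expand_to_batch(mask, 4, "mask")
print(tuple(expanded.shape))  # (4, 1, 4, 4)
```

Without the check, a mask batch of 2 and a requested batch size of 3 would yield 3 // 2 = 1 repeat and leave the mask at batch size 2, which only fails later as a shape mismatch.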


latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

Copilot AI Jan 21, 2026

The callback handling doesn't update mask and masked_image_latents from the callback outputs, even though they're listed in _callback_tensor_inputs at line 177. If these tensors are intended to be modifiable through callbacks, the code should handle their updates similar to how latents and prompt_embeds are handled. If they're not intended to be modifiable, they should be removed from _callback_tensor_inputs to avoid confusion.

Suggested change
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
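
The effect of the missing read-back can be shown with a toy loop that has no diffusers dependency. All names here (`CALLBACK_TENSOR_INPUTS`, `denoise_step`, `halve_mask`) are hypothetical stand-ins, and plain lists stand in for tensors; the sketch only mirrors the pop-with-default pattern quoted above.

```python
# Names the callback is allowed to see, analogous to _callback_tensor_inputs.
CALLBACK_TENSOR_INPUTS = ["latents", "mask"]


def denoise_step(state, callback, read_back):
    """Call `callback` with the exposed values, then read back only the names in `read_back`."""
    callback_kwargs = {name: state[name] for name in CALLBACK_TENSOR_INPUTS}
    outputs = callback(callback_kwargs)
    for name in read_back:
        # Same pattern as the pipeline: fall back to the current value if absent.
        state[name] = outputs.pop(name, state[name])
    return state


def halve_mask(kwargs):
    """Example callback that tries to modify the mask."""
    return {"mask": [value * 0.5 for value in kwargs["mask"]]}


initial = {"latents": [1.0, 2.0], "mask": [1.0, 1.0]}

# Mask is exposed to the callback but never read back: the change is dropped.
ignored = denoise_step(dict(initial), halve_mask, read_back=["latents"])

# Mask is also read back: the change takes effect.
applied = denoise_step(dict(initial), halve_mask, read_back=["latents", "mask"])
```

In the first call the callback's modified mask is silently discarded, which is the confusion the comment describes.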
