Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/diffusers/models/transformers/transformer_prx.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
// patch_size)` is the number of patches.
"""
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
b, c, h, w = img.shape
p = patch_size

# Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
img = img.reshape(b, c, h // p, p, w // p, p)

# Permute to (B, H//p, W//p, C, p, p) using einsum
# n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
img = torch.einsum("nchpwq->nhwcpq", img)

# Flatten to (B, L, C * p * p)
img = img.reshape(b, -1, c * p * p)
return img


def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
Expand All @@ -554,12 +566,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
Reconstructed image tensor of shape `(B, C, H, W)`.
"""
if isinstance(shape, tuple):
shape = shape[-2:]
h, w = shape[-2:]
elif isinstance(shape, torch.Tensor):
shape = (int(shape[0]), int(shape[1]))
h, w = (int(shape[0]), int(shape[1]))
else:
raise NotImplementedError(f"shape type {type(shape)} not supported")
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)

b, l, d = seq.shape
p = patch_size
c = d // (p * p)

# Reshape back to grid structure: (B, H//p, W//p, C, p, p)
seq = seq.reshape(b, h // p, w // p, c, p, p)

# Permute back to image layout: (B, C, H//p, p, W//p, p)
# n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
seq = torch.einsum("nhwcpq->nchpwq", seq)

# Final reshape to (B, C, H, W)
seq = seq.reshape(b, c, h, w)
return seq


class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
Expand Down
Loading