Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion src/diffusers/models/transformers/transformer_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..modeling_outputs import Transformer2DModelOutput
from ...utils import is_torch_npu_available


ADALN_EMBED_DIM = 256
Expand Down Expand Up @@ -311,6 +312,62 @@ def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
return x


class RopeEmbedderNPU:
    """Rotary positional-embedding lookup for Ascend NPU devices.

    Precomputes per-axis cos/sin tables on the CPU (presumably because the
    NPU backend lacks the complex-exp kernel used by the regular
    ``RopeEmbedder`` — TODO confirm) and assembles complex rotary
    frequencies on demand by indexing those tables with integer position ids.

    Args:
        theta: Base of the inverse-frequency geometric progression.
        axes_dims: Embedding dim allocated to each positional axis
            (``dim // 2`` frequencies are kept per axis).
        axes_lens: Table length (maximum position + 1) per axis.
    """

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (64, 128, 128),
    ):
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if len(axes_dims) != len(axes_lens):
            raise ValueError("axes_dims and axes_lens must have the same length")
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        # Lazily-built per-axis tables: lists of [axes_lens[i], axes_dims[i] // 2]
        # float32 tensors, cached on the device of the first `ids` seen.
        self.freqs_real = None
        self.freqs_imag = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        """Build the per-axis rotary cos/sin tables on CPU.

        Args:
            dim: Embedding dim per axis (frequencies are taken at even offsets,
                so each table has ``dim[i] // 2`` columns).
            end: Number of positions (rows) per axis.
            theta: Inverse-frequency base.

        Returns:
            Tuple ``(cos_tables, sin_tables)`` — one float32 tensor of shape
            ``[end[i], dim[i] // 2]`` per axis. Angles are accumulated in
            float64 and downcast once, matching the original precision flow.
        """
        freqs_real_list = []
        freqs_imag_list = []
        for d, e in zip(dim, end):
            inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
            timestep = torch.arange(e, device="cpu", dtype=torch.float64)
            # Downcast the angles before cos/sin, as the original did via .float().
            angles = torch.outer(timestep, inv_freq).float()
            freqs_real_list.append(torch.cos(angles))
            freqs_imag_list.append(torch.sin(angles))
        return freqs_real_list, freqs_imag_list

    def __call__(self, ids: torch.Tensor):
        """Gather complex rotary frequencies for integer position ids.

        Args:
            ids: Integer tensor of shape ``[seq, num_axes]`` holding the
                per-axis position of each token.

        Returns:
            Complex64 tensor of shape ``[seq, sum(d // 2 for d in axes_dims)]``.
        """
        if ids.ndim != 2:
            raise ValueError(f"ids must be 2-D, got {ids.ndim}-D")
        if ids.shape[-1] != len(self.axes_dims):
            raise ValueError("last dim of ids must equal the number of axes")
        device = ids.device

        if self.freqs_real is None or self.freqs_imag is None:
            freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_real = [fr.to(device) for fr in freqs_real]
            self.freqs_imag = [fi.to(device) for fi in freqs_imag]
        elif self.freqs_real[0].device != device:
            # Move the cached tables when the caller switches device.
            # BUG FIX: the original referenced the undefined locals
            # `freqs_real`/`freqs_imag` here, raising NameError on any
            # device change; the cached self.* tables are what must move.
            self.freqs_real = [fr.to(device) for fr in self.freqs_real]
            self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]

        result = []
        for axis in range(len(self.axes_dims)):
            index = ids[:, axis]
            real_part = self.freqs_real[axis][index]
            imag_part = self.freqs_imag[axis][index]
            result.append(torch.complex(real_part, imag_part))
        return torch.cat(result, dim=-1)


class RopeEmbedder:
def __init__(
self,
Expand Down Expand Up @@ -478,7 +535,10 @@ def __init__(
self.axes_dims = axes_dims
self.axes_lens = axes_lens

self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
if is_torch_npu_available:
self.rope_embedder = RopeEmbedderNPU(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
else:
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

def unpatchify(
self,
Expand Down