diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 5983c34ab640..5afa6bb5b49f 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -28,6 +28,7 @@
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_dispatch import dispatch_attention_fn
 from ..modeling_outputs import Transformer2DModelOutput
+from ...utils import is_torch_npu_available

 ADALN_EMBED_DIM = 256

@@ -311,6 +312,62 @@ def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
         return x


+class RopeEmbedderNPU:
+    def __init__(
+        self,
+        theta: float = 256.0,
+        axes_dims: List[int] = (16, 56, 56),
+        axes_lens: List[int] = (64, 128, 128),
+    ):
+        self.theta = theta
+        self.axes_dims = axes_dims
+        self.axes_lens = axes_lens
+        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
+        self.freqs_cis = None
+        self.freqs_real = None
+        self.freqs_imag = None
+
+    @staticmethod
+    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
+        with torch.device("cpu"):
+            freqs_real_list = []
+            freqs_imag_list = []
+            for d, e in zip(dim, end):
+                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
+                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
+                freqs = torch.outer(timestep, freqs).float()
+                freqs_real = torch.cos(freqs)
+                freqs_imag = torch.sin(freqs)
+                freqs_real_list.append(freqs_real)
+                freqs_imag_list.append(freqs_imag)
+
+        return freqs_real_list, freqs_imag_list
+
+    def __call__(self, ids: torch.Tensor):
+        assert ids.ndim == 2
+        assert ids.shape[-1] == len(self.axes_dims)
+        device = ids.device
+
+        if self.freqs_real is None or self.freqs_imag is None:
+            freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+            self.freqs_real = [fr.to(device) for fr in freqs_real]
+            self.freqs_imag = [fi.to(device) for fi in freqs_imag]
+        else:
+            # Ensure the cached cos/sin tables are on the same device as ids
+            if self.freqs_real[0].device != device:
+                self.freqs_real = [fr.to(device) for fr in self.freqs_real]
+                self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]
+
+        result = []
+        for i in range(len(self.axes_dims)):
+            index = ids[:, i]
+            real_part = self.freqs_real[i][index]
+            imag_part = self.freqs_imag[i][index]
+            complex_part = torch.complex(real_part, imag_part)
+            result.append(complex_part)
+        return torch.cat(result, dim=-1)
+
+
 class RopeEmbedder:
     def __init__(
         self,
@@ -478,7 +535,10 @@ def __init__(
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens

-        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
+        if is_torch_npu_available():
+            self.rope_embedder = RopeEmbedderNPU(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
+        else:
+            self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

     def unpatchify(
         self,
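
Usage sketch (not part of the patch): a minimal CPU-side check of the RopeEmbedderNPU class added above. The import path simply points at the patched module; the three-column ids layout follows the assertions in __call__, and the concrete index values are illustrative only.

import torch

from diffusers.models.transformers.transformer_z_image import RopeEmbedderNPU

# Defaults mirror the class signature above: each axis contributes
# axes_dims[i] // 2 complex frequencies per position.
rope = RopeEmbedderNPU(theta=256.0, axes_dims=[16, 56, 56], axes_lens=[64, 128, 128])

# One row per token, one integer column per rotary axis; each index must be
# smaller than the corresponding entry in axes_lens.
ids = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.long)

freqs_cis = rope(ids)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([3, 64]) torch.complex64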