diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 73862230..03ab7bba 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -93,7 +93,7 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs + freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index 21dc3b33..a7bf24f2 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank + freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype)