Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions diffsynth/utils/xfuser/xdit_context_parallel.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import torch
from typing import Optional
from einops import rearrange
from yunchang.kernels import AttnType
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

from ... import IS_NPU_AVAILABLE
from ...core.device import parse_nccl_backend, parse_device_type


Expand All @@ -30,13 +33,16 @@ def sinusoidal_embedding_1d(dim, position):
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
original_tensor_device = original_tensor.device
if original_tensor.device == "npu":
original_tensor = original_tensor.cpu()
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
return padded_tensor

def rope_apply(x, freqs, num_heads):
Expand Down Expand Up @@ -133,7 +139,12 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

x = xFuserLongContextAttention()(
attn_type = AttnType.FA
ring_impl_type = "basic"
if IS_NPU_AVAILABLE:
attn_type = AttnType.NPU
ring_impl_type = "basic_npu"
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
None,
query=q,
key=k,
Expand Down
8 changes: 8 additions & 0 deletions docs/en/Pipeline_Usage/GPU_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ video = pipe(
save_video(video, "video.mp4", fps=15, quality=5)
```

#### USP(Unified Sequence Parallel)
If you want to use this feature on NPU, please install additional third-party libraries as follows:
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```


### Training
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.

Expand Down
7 changes: 7 additions & 0 deletions docs/zh/Pipeline_Usage/GPU_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ video = pipe(
save_video(video, "video.mp4", fps=15, quality=5)
```

#### USP(Unified Sequence Parallel)
如果想要在NPU上使用该特性,请通过如下方式安装额外的第三方库:
```shell
pip install git+https://github.com/feifeibear/long-context-attention.git
pip install git+https://github.com/xdit-project/xDiT.git
```

### 训练
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。

Expand Down