
Commit ae0f69b

Add SpecDec support to selective_state_update (#29488)
Signed-off-by: Roi Koren <roik@nvidia.com>
1 parent: 799804d

File tree: 2 files changed (+507, −74 lines)


tests/kernels/mamba/test_mamba_ssm.py

Lines changed: 325 additions & 0 deletions
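The diff below adds three tests. All of them exercise selective_state_update in varlen mode: the tokens of every sequence are flattened into one token batch and delimited by cumulative boundaries in cu_seqlens. As a quick orientation, a minimal sketch (not part of the diff, with illustrative values) mirroring how the tests build this layout:

import torch

# Three sequences with 2, 1, and 3 draft tokens flattened into a single
# (total_tokens, dim) batch; cu_seqlens records the cumulative boundaries.
token_counts = torch.tensor([2, 1, 3])
cu_seqlens = torch.tensor(
    [0] + torch.cumsum(token_counts, dim=0).tolist(), dtype=torch.int32
)
# cu_seqlens == tensor([0, 2, 3, 6], dtype=torch.int32); the tokens of
# sequence i occupy rows cu_seqlens[i] : cu_seqlens[i + 1] of the flat batch.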
@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
 
 
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
+def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+    # set seed
+    current_platform.seed_everything(0)
+    batch_size = 4
+    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(token_counts.sum().item())
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(token_counts, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state.detach().clone()
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+    )
+
+    out_ref_list = []
+    for seq_idx in range(batch_size):
+        start_idx = cu_seqlens[seq_idx].item()
+        end_idx = cu_seqlens[seq_idx + 1].item()
+        num_tokens = end_idx - start_idx
+        for token_idx in range(num_tokens):
+            idx = start_idx + token_idx
+            out_ref_list.append(
+                selective_state_update_ref(
+                    state_ref[seq_idx : seq_idx + 1],
+                    x[idx : idx + 1],
+                    dt[idx : idx + 1],
+                    A,
+                    B[idx : idx + 1],
+                    C[idx : idx + 1],
+                    D=D,
+                    z=z[idx : idx + 1] if has_z else None,
+                    dt_bias=dt_bias,
+                    dt_softplus=True,
+                )
+            )
+    out_ref = torch.cat(out_ref_list, dim=0)
+    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
 @pytest.mark.parametrize("wtype", [torch.float32])
 @pytest.mark.parametrize("itype", [torch.float32])
 @pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_with_num_accepted_tokens(
+    dim, dstate, has_z, itype, max_seq_len
+):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+    out_ref_list = []
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            out_token = selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+            out_ref_list.append(out_token)
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    out_ref = torch.cat(out_ref_list, dim=0)
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_varlen_with_num_accepted(
+    dim, dstate, has_z, itype, max_seq_len
+):
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
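The two num_accepted_tokens tests encode the SpecDec contract this commit adds: each sequence resumes from the state slot of its last accepted token (falling back to position 0 when nothing was accepted), and the post-update state of every draft token is written to its own destination slot so a later verification step can roll back to whichever prefix is accepted. A minimal pure-PyTorch sketch of the resume-slot selection, mirroring the tests' (batch, max_seq_len) index layout; the helper name is illustrative:

import torch

PAD_SLOT_ID = -1  # assumed sentinel, mirroring the PAD_SLOT_ID the tests import

def resume_slot(state_batch_indices: torch.Tensor,
                num_accepted_tokens: torch.Tensor, seq_idx: int) -> int:
    """Slot holding the state this sequence resumes from: the slot of its
    last accepted token, or position 0 if no tokens were accepted."""
    token_pos = max(int(num_accepted_tokens[seq_idx]) - 1, 0)
    return int(state_batch_indices[seq_idx, token_pos])

# Example with the tests' layout: sequence 0 accepted 2 tokens, so it resumes
# from the slot stored at position 1 of its index row; sequence 1 accepted
# none, so it falls back to position 0.
state_batch_indices = torch.tensor([[7, 3], [5, PAD_SLOT_ID]], dtype=torch.int32)
num_accepted_tokens = torch.tensor([2, 0])
assert resume_slot(state_batch_indices, num_accepted_tokens, 0) == 3
assert resume_slot(state_batch_indices, num_accepted_tokens, 1) == 5

Allocating one destination slot per draft token in dst_state_batch_indices is what lets the tests compare state[dst_slot] against the per-token reference snapshots collected in state_ref_intermediate.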
