@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
+@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
+def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
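+    # Varlen path: all sequences' tokens are packed into one flat stream
+    # delimited by cu_seqlens, and checked against the per-token reference.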
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+    # set seed
+    current_platform.seed_everything(0)
+    batch_size = 4
+    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(token_counts.sum().item())
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(token_counts, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
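+    # cu_seqlens[i]:cu_seqlens[i + 1] delimits sequence i in the flat stream,
+    # e.g. token_counts [2, 1, 3, 1] -> cu_seqlens [0, 2, 3, 6, 7].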
+    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+    state_ref = state.detach().clone()
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+    )
+
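+    # Reference: replay each sequence token by token through the single-step
+    # reference implementation, advancing that sequence's state in place.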
+    out_ref_list = []
+    for seq_idx in range(batch_size):
+        start_idx = cu_seqlens[seq_idx].item()
+        end_idx = cu_seqlens[seq_idx + 1].item()
+        num_tokens = end_idx - start_idx
+        for token_idx in range(num_tokens):
+            idx = start_idx + token_idx
+            out_ref_list.append(
+                selective_state_update_ref(
+                    state_ref[seq_idx : seq_idx + 1],
+                    x[idx : idx + 1],
+                    dt[idx : idx + 1],
+                    A,
+                    B[idx : idx + 1],
+                    C[idx : idx + 1],
+                    D=D,
+                    z=z[idx : idx + 1] if has_z else None,
+                    dt_bias=dt_bias,
+                    dt_softplus=True,
+                )
+            )
+    out_ref = torch.cat(out_ref_list, dim=0)
+    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
 @pytest.mark.parametrize("wtype", [torch.float32])
 @pytest.mark.parametrize("itype", [torch.float32])
 @pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_with_num_accepted_tokens(
+    dim, dstate, has_z, itype, max_seq_len
+):
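+    # Models a speculative-decoding-style update: each sequence's initial
+    # state is read from the slot of its last accepted token, and the state
+    # after every new token is written to its own destination slot.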
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
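+    # Place each sequence's source state at the position of its last accepted
+    # token (position 0 when nothing was accepted); other entries stay padded.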
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
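+    # Destination slots start at 15, above the source-slot range [0, 15), so
+    # per-token writes never overwrite a sequence's initial state.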
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+    out_ref_list = []
+
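+    # Reference: start from each sequence's initial slot and apply the
+    # single-step update once per token, snapshotting the state after each one.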
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            out_token = selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+            out_ref_list.append(out_token)
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    out_ref = torch.cat(out_ref_list, dim=0)
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
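+    # Each per-token destination slot must match the reference snapshot taken
+    # right after that token's update.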
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
+@pytest.mark.parametrize("has_z", [False, True])
+@pytest.mark.parametrize("dstate", [16, 64])
+@pytest.mark.parametrize("dim", [2048, 4096])
+@pytest.mark.parametrize("max_seq_len", [2, 4])
+def test_selective_state_update_varlen_with_num_accepted(
+    dim, dstate, has_z, itype, max_seq_len
+):
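+    # Same setup as the previous test, but only the per-token destination
+    # states are verified; the kernel output is written to `out` but not
+    # compared here.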
+    device = "cuda"
+    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
+    if itype == torch.bfloat16:
+        rtol, atol = 5e-2, 1.5e-1
+        if torch.version.hip:
+            atol *= 2
+
+    current_platform.seed_everything(0)
+    batch_size = 4
+
+    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
+    total_tokens = int(tokens_per_seq.sum().item())
+
+    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
+    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
+    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
+
+    cu_seqlens = torch.tensor(
+        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    total_state_slots = 50
+    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)
+
+    state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    initial_state_slots = torch.randint(
+        0, 15, (batch_size,), device=device, dtype=torch.int32
+    )
+    for seq_idx in range(batch_size):
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]
+
+    dst_state_batch_indices = torch.full(
+        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
+    )
+
+    slot_offset = 15
+    dst_slots_map = {}
+    for seq_idx in range(batch_size):
+        for token_idx in range(tokens_per_seq[seq_idx].item()):
+            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
+            dst_slots_map[(seq_idx, token_idx)] = slot_offset
+            slot_offset += 1
+
+    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    out = torch.empty_like(x)
+    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
+    dt_bias = torch.rand(dim, device=device) - 4.0
+    A = -torch.rand(dim, dstate, device=device) - 1.0
+    B = torch.randn(total_tokens, dstate, device=device)
+    C = torch.randn(total_tokens, dstate, device=device)
+    D = torch.randn(dim, device=device)
+    z = torch.randn_like(x) if has_z else None
+
+    state_ref_intermediate = {}
+
+    for seq_idx in range(batch_size):
+        seq_start = cu_seqlens[seq_idx].item()
+        seq_end = cu_seqlens[seq_idx + 1].item()
+        num_tokens = seq_end - seq_start
+
+        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
+        initial_slot = state_batch_indices[seq_idx, token_pos].item()
+        state_seq = state[initial_slot : initial_slot + 1].clone()
+
+        for token_idx in range(num_tokens):
+            global_idx = seq_start + token_idx
+
+            selective_state_update_ref(
+                state_seq,
+                x[global_idx : global_idx + 1],
+                dt[global_idx : global_idx + 1],
+                A,
+                B[global_idx : global_idx + 1],
+                C[global_idx : global_idx + 1],
+                D=D,
+                z=z[global_idx : global_idx + 1] if has_z else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+
+            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()
+
+    selective_state_update(
+        state,
+        x,
+        dt,
+        A,
+        B,
+        C,
+        D=D,
+        z=z,
+        dt_bias=dt_bias,
+        dt_softplus=True,
+        out=out,
+        cu_seqlens=cu_seqlens,
+        state_batch_indices=state_batch_indices,
+        dst_state_batch_indices=dst_state_batch_indices,
+        num_accepted_tokens=num_accepted_tokens,
+        pad_slot_id=PAD_SLOT_ID,
+    )
+
+    for seq_idx in range(batch_size):
+        num_tokens = tokens_per_seq[seq_idx].item()
+
+        for token_idx in range(num_tokens):
+            dst_slot = dst_slots_map[(seq_idx, token_idx)]
+            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
+
+            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)