 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-aten = torch.ops.aten
-DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
+def _scaled_mm_cpu_out(
     mat1: Tensor,
     mat2: Tensor,
     scale1: Tensor,
@@ -43,15 +38,50 @@ def _scaled_mm_cpu(
     scale_result: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
     use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
 ) -> Tensor:
     if out_dtype is None:
         out_dtype = torch.float32
     mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
     mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
 
     if bias is not None:
-        return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    return torch.mm(mat1, mat2).to(dtype=out_dtype)
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
+torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
+
+
+@torch.library.register_kernel("aten::_scaled_mm", "cpu")
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )
 
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
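
For orientation, a minimal sketch of exercising the CPU fallback registered above; the shapes, scale values, and float8 casts are illustrative assumptions rather than part of this change, and it assumes the module containing these registrations has already been imported.

import torch

# Hypothetical inputs; real callers choose their own shapes and scales.
mat1 = torch.randn(16, 32).to(torch.float8_e4m3fn)
mat2 = torch.randn(32, 8).to(torch.float8_e4m3fn)
scale1 = torch.tensor(1.0)
scale2 = torch.tensor(1.0)

# Functional overload: dispatches to _scaled_mm_cpu for CPU tensors.
res = torch.ops.aten._scaled_mm(mat1, mat2, scale1, scale2, out_dtype=torch.float32)

# Out overload: dispatches to _scaled_mm_cpu_out and writes into `out`.
out = torch.empty(16, 8, dtype=torch.float32)
torch.ops.aten._scaled_mm.out(mat1, mat2, scale1, scale2, out_dtype=torch.float32, out=out)
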
@@ -127,7 +157,6 @@ def scaled_paged_attn_store(
     Scales key and value tensors, and stores them to the paged KV cache
     using the same schema as vLLM.
     """
-    print("Should never hit")
     result_key_cache = key_cache.clone()
     result_value_cache = value_cache.clone()
     for seq_i, slot_mapping_seq in enumerate(slot_mapping):
@@ -136,10 +165,10 @@ def scaled_paged_attn_store(
             position = slot.item() % 64
 
             result_key_cache[block_number, position, :, :] = (
-                key[seq_i, tok_i, :, :] / key_scale
+                key[seq_i, tok_i, :, :] / key_scale[seq_i]
             ).to(dtype=torch.float8_e4m3fn)
             result_value_cache[block_number, position, :, :] = (
-                value[seq_i, tok_i, :, :] / value_scale
+                value[seq_i, tok_i, :, :] / value_scale[seq_i]
             ).to(dtype=torch.float8_e4m3fn)
     return result_key_cache, result_value_cache
 
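
The store-side change switches from a single scalar scale to one scale per sequence. A small sketch of the shape convention that implies follows; the BLOCK_SIZE value and the block_number decomposition are assumptions inferred from the `% 64` arithmetic above, not shown in this diff.

import torch

BLOCK_SIZE = 64  # assumed from the `% 64` above

# One quantization scale per sequence: key_scale[seq_i] / value_scale[seq_i]
# is a scalar applied to every token of that sequence before the fp8 cast.
key_scale = torch.tensor([0.5, 2.0])     # shape: [num_seqs]
value_scale = torch.tensor([1.0, 0.25])  # shape: [num_seqs]

# slot_mapping[seq_i, tok_i] is a flat slot index into the paged cache;
# splitting it recovers a (block_number, position) pair like the one used above.
slot_mapping = torch.tensor([[64, 65, 66], [128, 129, 130]])
slot = slot_mapping[1, 0].item()
block_number, position = slot // BLOCK_SIZE, slot % BLOCK_SIZE  # -> (2, 0)
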
@@ -179,7 +208,6 @@ def scaled_paged_attn_compute(
     Implements a CPU fallback to run the kernel that has been confirmed
     to match the vLLM fused kernel.
     """
-    print("Should never hit")
     # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
     output = torch.zeros_like(query)
     num_query_heads = query.shape[2]
@@ -220,10 +248,10 @@ def scaled_paged_attn_compute(
 
         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
-            (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale).to(
+            (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
-            (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale).to(
+            (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
             is_causal=False,  # decode assumes no causal mask