Commit c920911

Merge pull request #162 from ani300/fp8_sequence_fixes
feat: Per-sequence scaling in FP8 attention, FP8 fixes
2 parents 67a5e55 + dacf1d2 · commit c920911

File tree

3 files changed: +90 −27 lines changed

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 47 additions & 13 deletions

@@ -70,12 +70,24 @@ def _math_fp8_store_op(
         k_scale = key_cache._scale
         v_scale = value_cache._scale
     else:
-        k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
-        v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)
+        k_scale = (
+            (torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE)
+            .clamp(min=1e-5)
+            .to(dtype=torch.float32)
+        )
+        v_scale = (
+            (torch.abs(values).amax(dim=(1, 2, 3)) / V_RANGE)
+            .clamp(min=1e-5)
+            .to(dtype=torch.float32)
+        )

     # Scale kv tensors for storage
-    keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1)
-    values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)
+    keys = (
+        (keys / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn).transpose(2, 1)
+    )
+    values = (
+        (values / v_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn).transpose(2, 1)
+    )

     if (
         isinstance(key_cache, ScaledTensor)

@@ -134,10 +146,20 @@ def _math_fp8_compute_op(
         value_cache = value_cache._data
     else:
         # Store op wasn't run (e.g. encoders, use_cache=False)
-        k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
-        v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)
-        key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn)
-        value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn)
+        k_scale = (
+            (torch.abs(key_cache).amax(dim=(1, 2, 3)) / K_RANGE)
+            .clamp(min=1e-5)
+            .to(dtype=torch.float32)
+        )
+        v_scale = (
+            (torch.abs(value_cache).amax(dim=(1, 2, 3)) / V_RANGE)
+            .clamp(min=1e-5)
+            .to(dtype=torch.float32)
+        )
+        key_cache = (key_cache / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn)
+        value_cache = (value_cache / v_scale.view(-1, 1, 1, 1)).to(
+            torch.float8_e4m3fn
+        )

     # If store wasn't run, we need to transpose the tensors here
     # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here

@@ -192,14 +214,20 @@ def _math_fp8_compute_op(
             * scale_factor
         )
     else:
-        key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1)
+        key_t = (
+            (key_cache.to(dtype=orig_dtype) * k_scale.view(-1, 1, 1, 1))
+            .to(dtype=orig_dtype)
+            .transpose(-2, -1)
+        )
         attn_weight = query @ key_t
         attn_weight *= scale_factor
         attn_weight += attn_bias
         attn_weight = torch.softmax(attn_weight, dim=-1)
         attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
         # Do matmul in orig_dtype
-        attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale)
+        attn = attn_weight @ (
+            value_cache.to(dtype=orig_dtype) * v_scale.view(-1, 1, 1, 1)
+        ).to(dtype=orig_dtype)

     attn = attn.to(orig_dtype).transpose(2, 1).contiguous()
     return attn

@@ -226,9 +254,15 @@ def _spyre_scaled_paged_store_op(
         value_cache, ScaledTensor
     ), "kv cache must be preallocated"
     if not key_cache._scaled:
-        key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32)
-        value_cache._scale = (torch.abs(values).max() / 100.0).to(
-            dtype=torch.float32
+        key_cache._scale = (
+            (torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE)
+            .clamp(min=1e-5)
+            .to(dtype=torch.float32)
+        )
+        value_cache._scale = (
+            (torch.abs(values).amax(dim=(1, 2, 3)) / V_RANGE)
+            .clamp(min=1e-5)
+            .to(dtype=torch.float32)
        )

     result_key_cache_data, result_value_cache_data = (
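
The change above replaces the single per-batch FP8 scale with one scale per sequence: the amax is taken over every axis except the batch axis, clamped away from zero, and broadcast back with view(-1, 1, 1, 1) both when quantizing to float8_e4m3fn and when dequantizing for the matmul. A minimal standalone sketch of that round trip, with made-up tensor shapes and torch.finfo(torch.float8_e4m3fn).max standing in for the K_RANGE constant that fp8_attn.py defines:

import torch

# Illustrative shapes: batch of 2 sequences, 16 tokens, 4 KV heads, head_dim 64.
# K_RANGE here is a stand-in; fp8_attn.py defines its own K_RANGE / V_RANGE constants.
K_RANGE = torch.finfo(torch.float8_e4m3fn).max
keys = torch.randn(2, 16, 4, 64, dtype=torch.bfloat16)

# One scale per sequence: reduce over all non-batch axes, then clamp so an
# all-zero (e.g. fully padded) sequence cannot produce a zero scale.
k_scale = (
    (torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE)
    .clamp(min=1e-5)
    .to(dtype=torch.float32)
)  # shape: (2,)

# Broadcast the per-sequence scale over the remaining axes when quantizing
# to FP8 for storage and when dequantizing back for the attention matmul.
keys_fp8 = (keys / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn)
keys_deq = keys_fp8.to(torch.bfloat16) * k_scale.view(-1, 1, 1, 1)

Compared with a single global max, a sequence with small activations no longer inherits the scale of the largest sequence in the batch, which is the point of the per-sequence amax.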

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 import torch

 # Local
+from fms_mo.aiu_addons.fp8 import fp8_spyre_op  # pylint: disable=unused-import
 from fms_mo.prep import available_packages

 # pylint: disable=not-callable
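
The one-line change to fp8_linear.py exists for its side effect: importing fp8_spyre_op executes its torch.library registration calls, so the CPU fallbacks below are in place before any FP8 linear layer dispatches to aten::_scaled_mm. A tiny sketch of that register-on-import pattern, with purely illustrative module and op names (not part of fms_mo):

# fake_ops.py -- defining the op at module level means registration happens on import
import torch
from torch import Tensor


@torch.library.custom_op("demo_ns::double", mutates_args=())
def double(x: Tensor) -> Tensor:
    # Simple reference implementation; registered the moment this module is imported.
    return x * 2


# elsewhere.py -- the import looks "unused", but it is what makes the op resolvable
# import fake_ops  # noqa: F401  (same role as the fp8_spyre_op import above)
# y = torch.ops.demo_ns.double(torch.ones(3))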

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 42 additions & 14 deletions

@@ -29,12 +29,7 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


-aten = torch.ops.aten
-DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
+def _scaled_mm_cpu_out(
     mat1: Tensor,
     mat2: Tensor,
     scale1: Tensor,

@@ -43,15 +38,50 @@ def _scaled_mm_cpu(
     scale_result: Optional[Tensor] = None,
     out_dtype: Optional[torch.dtype] = None,
     use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
 ) -> Tensor:
     if out_dtype is None:
         out_dtype = torch.float32
     mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
     mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)

     if bias is not None:
-        return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    return torch.mm(mat1, mat2).to(dtype=out_dtype)
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
+torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
+
+
+@torch.library.register_kernel("aten::_scaled_mm", "cpu")
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )


 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
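
The rework above splits the CPU fallback into _scaled_mm_cpu_out, registered for the out= overload torch.ops.aten._scaled_mm.out, plus a thin functional wrapper, so both ways of calling the op land on the same Python kernel. A hedged sketch of how the fallback might be exercised once fp8_spyre_op is imported; shapes and scales are illustrative, and the _scaled_mm signature has shifted between PyTorch releases, so treat this as the 2.4+-style calling convention rather than a guaranteed API:

import torch

# Assumes fp8_spyre_op has been imported so the CPU kernels above are registered.
a = torch.randn(16, 32).to(torch.float8_e4m3fn)       # FP8 activations
b = torch.randn(64, 32).to(torch.float8_e4m3fn).t()   # (32, 64) FP8 weight, transposed
scale_a = torch.tensor(1.0, dtype=torch.float32)
scale_b = torch.tensor(1.0, dtype=torch.float32)

# Functional overload -> _scaled_mm_cpu, which forwards to _scaled_mm_cpu_out(out=None).
y = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

# out= overload -> _scaled_mm_cpu_out, which computes the result and copies it into `out`.
out = torch.empty(16, 64, dtype=torch.bfloat16)
torch.ops.aten._scaled_mm.out(a, b, scale_a, scale_b, out_dtype=torch.bfloat16, out=out)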
@@ -127,7 +157,6 @@ def scaled_paged_attn_store(
     Scales key and value tensors, and stores them to the paged KV cache
     using the same schema as vLLM.
     """
-    print("Should never hit")
     result_key_cache = key_cache.clone()
     result_value_cache = value_cache.clone()
     for seq_i, slot_mapping_seq in enumerate(slot_mapping):

@@ -136,10 +165,10 @@ def scaled_paged_attn_store(
             position = slot.item() % 64

             result_key_cache[block_number, position, :, :] = (
-                key[seq_i, tok_i, :, :] / key_scale
+                key[seq_i, tok_i, :, :] / key_scale[seq_i]
             ).to(dtype=torch.float8_e4m3fn)
             result_value_cache[block_number, position, :, :] = (
-                value[seq_i, tok_i, :, :] / value_scale
+                value[seq_i, tok_i, :, :] / value_scale[seq_i]
             ).to(dtype=torch.float8_e4m3fn)
     return result_key_cache, result_value_cache

@@ -179,7 +208,6 @@ def scaled_paged_attn_compute(
     Implements a CPU fallback to run the kernel that has been confirmed
     to match the vLLM fused kernel.
     """
-    print("Should never hit")
     # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
     output = torch.zeros_like(query)
     num_query_heads = query.shape[2]

@@ -220,10 +248,10 @@ def scaled_paged_attn_compute(

         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
-            (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale).to(
+            (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
-            (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale).to(
+            (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
             ),  # format for sdpa
             is_causal=False,  # decode assumes no causal mask
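
With key_scale[seq_i] / value_scale[seq_i] in the store loop and key_scale[i] / value_scale[i] in the compute loop, the CPU reference ops now expect one scale per batch entry instead of a single scalar. A sketch of the store-side indexing that convention implies; the shapes, the slot layout, and the block_number derivation (slot // 64, matching position = slot % 64 above) are illustrative assumptions, not the op's actual signature:

import torch

# Illustrative: 2 sequences of 3 tokens, 4 KV heads, head_dim 64,
# and a paged cache of 8 blocks x 64 slots (block size 64, as in the loop above).
key = torch.randn(2, 3, 4, 64, dtype=torch.bfloat16)
key_cache = torch.zeros(8, 64, 4, 64, dtype=torch.float8_e4m3fn)
slot_mapping = torch.tensor([[0, 1, 2], [64, 65, 66]])  # one row of slots per sequence

# One FP8 scale per sequence, matching the new per-sequence convention.
key_scale = (
    (key.abs().amax(dim=(1, 2, 3)) / torch.finfo(torch.float8_e4m3fn).max)
    .clamp(min=1e-5)
    .to(torch.float32)
)  # shape: (2,)

result_key_cache = key_cache.clone()
for seq_i, slots in enumerate(slot_mapping):
    for tok_i, slot in enumerate(slots):
        block_number = slot.item() // 64  # assumed block/slot layout
        position = slot.item() % 64
        # Each sequence is quantized with its own scale, as in the fixed store op.
        result_key_cache[block_number, position, :, :] = (
            key[seq_i, tok_i, :, :] / key_scale[seq_i]
        ).to(dtype=torch.float8_e4m3fn)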
