60 changes: 47 additions & 13 deletions fms_mo/aiu_addons/fp8/fp8_attn.py
@@ -70,12 +70,24 @@ def _math_fp8_store_op(
k_scale = key_cache._scale
v_scale = value_cache._scale
else:
k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)
k_scale = (
(torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE)
.clamp(min=1e-5)
.to(dtype=torch.float32)
)
v_scale = (
(torch.abs(values).amax(dim=(1, 2, 3)) / V_RANGE)
.clamp(min=1e-5)
.to(dtype=torch.float32)
)

# Scale kv tensors for storage
keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1)
values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)
keys = (
(keys / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn).transpose(2, 1)
)
values = (
(values / v_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn).transpose(2, 1)
)

if (
isinstance(key_cache, ScaledTensor)
@@ -134,10 +146,20 @@ def _math_fp8_compute_op(
value_cache = value_cache._data
else:
# Store op wasn't run (e.g. encoders, use_cache=False)
k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)
key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn)
value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn)
k_scale = (
(torch.abs(key_cache).amax(dim=(1, 2, 3)) / K_RANGE)
.clamp(min=1e-5)
.to(dtype=torch.float32)
)
v_scale = (
(torch.abs(value_cache).amax(dim=(1, 2, 3)) / V_RANGE)
.clamp(min=1e-5)
.to(dtype=torch.float32)
)
key_cache = (key_cache / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn)
value_cache = (value_cache / v_scale.view(-1, 1, 1, 1)).to(
torch.float8_e4m3fn
)

# If store wasn't run, we need to transpose the tensors here
# TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here
@@ -192,14 +214,20 @@ def _math_fp8_compute_op(
* scale_factor
)
else:
key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1)
key_t = (
(key_cache.to(dtype=orig_dtype) * k_scale.view(-1, 1, 1, 1))
.to(dtype=orig_dtype)
.transpose(-2, -1)
)
attn_weight = query @ key_t
attn_weight *= scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, p_dropout, train=True)
# Do matmul in orig_dtype
attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale)
attn = attn_weight @ (
value_cache.to(dtype=orig_dtype) * v_scale.view(-1, 1, 1, 1)
).to(dtype=orig_dtype)

attn = attn.to(orig_dtype).transpose(2, 1).contiguous()
return attn
@@ -226,9 +254,15 @@ def _spyre_scaled_paged_store_op(
value_cache, ScaledTensor
), "kv cache must be preallocated"
if not key_cache._scaled:
key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32)
value_cache._scale = (torch.abs(values).max() / 100.0).to(
dtype=torch.float32
key_cache._scale = (
(torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE)
.clamp(min=1e-5)
.to(dtype=torch.float32)
)
value_cache._scale = (
(torch.abs(values).amax(dim=(1, 2, 3)) / V_RANGE)
.clamp(min=1e-5)
.to(dtype=torch.float32)
)

result_key_cache_data, result_value_cache_data = (
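For context: the hunks above derive one FP8 scale per batch element (an amax over the sequence, head, and head-dim axes, clamped to avoid division by zero) instead of a single scale from the global max(). A minimal, self-contained sketch of that pattern follows; the FP8_RANGE value and the (batch, seq, kv_heads, head_dim) layout are assumptions for illustration, not values taken from the repository.

import torch

FP8_RANGE = 448.0  # placeholder; fp8_attn.py defines its own K_RANGE / V_RANGE constants


def quantize_per_batch(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """One scale per batch element (dim 0), quantized to float8_e4m3fn."""
    # amax over all non-batch dims gives shape (B,); clamp guards against all-zero inputs
    scale = (x.abs().amax(dim=(1, 2, 3)) / FP8_RANGE).clamp(min=1e-5).to(torch.float32)
    # the (B,) scale broadcasts over the remaining dims via view(-1, 1, 1, 1)
    return (x / scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn), scale


def dequantize_per_batch(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.float16):
    return x_fp8.to(dtype) * scale.view(-1, 1, 1, 1)


keys = torch.randn(2, 16, 8, 64, dtype=torch.float16)  # (batch, seq, kv_heads, head_dim)
k_fp8, k_scale = quantize_per_batch(keys)
print((keys - dequantize_per_batch(k_fp8, k_scale)).abs().max())  # small round-trip error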
1 change: 1 addition & 0 deletions fms_mo/aiu_addons/fp8/fp8_linear.py
@@ -20,6 +20,7 @@
import torch

# Local
from fms_mo.aiu_addons.fp8 import fp8_spyre_op # pylint: disable=unused-import
from fms_mo.prep import available_packages

# pylint: disable=not-callable
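The single added import exists for its side effects: loading fp8_spyre_op runs the torch.library registrations in that module, so the custom ops are available before fp8_linear dispatches to them. A minimal check of that pattern, assuming fms_mo is importable and using the spyre::scaled_bmm name registered later in this diff:

import torch

# Nothing from the module is referenced directly; importing it triggers registration.
from fms_mo.aiu_addons.fp8 import fp8_spyre_op  # noqa: F401

print(hasattr(torch.ops.spyre, "scaled_bmm"))  # expected: True once registration has run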
56 changes: 42 additions & 14 deletions fms_mo/aiu_addons/fp8/fp8_spyre_op.py
@@ -29,12 +29,7 @@
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


aten = torch.ops.aten
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]


@torch.library.register_kernel("aten::_scaled_mm", "cpu")
def _scaled_mm_cpu(
def _scaled_mm_cpu_out(
mat1: Tensor,
mat2: Tensor,
scale1: Tensor,
@@ -43,15 +38,50 @@ def _scaled_mm_cpu(
scale_result: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
*,
out: Optional[Tensor] = None,
) -> Tensor:
if out_dtype is None:
out_dtype = torch.float32
mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)

if bias is not None:
return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
return torch.mm(mat1, mat2).to(dtype=out_dtype)
ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
else:
ret = torch.mm(mat1, mat2).to(dtype=out_dtype)

if out is not None:
out.copy_(ret)
return out
return ret


torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)


@torch.library.register_kernel("aten::_scaled_mm", "cpu")
def _scaled_mm_cpu(
mat1: Tensor,
mat2: Tensor,
scale1: Tensor,
scale2: Tensor,
bias: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> Tensor:
return _scaled_mm_cpu_out(
mat1,
mat2,
scale1,
scale2,
bias,
scale_result,
out_dtype,
use_fast_accum,
out=None,
)
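
For reference, one possible way to exercise the CPU fallback registered above, assuming the targeted PyTorch build has no native CPU kernel for aten::_scaled_mm (which is what the registration implies). Shapes, values, and the 448.0 range below are invented for the example; the kernel dequantizes both operands and runs the matmul in out_dtype:

import torch

a, b = torch.randn(4, 8), torch.randn(8, 16)
a_scale = a.abs().max() / 448.0
b_scale = b.abs().max() / 448.0
a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn)

out = torch.ops.aten._scaled_mm(a_fp8, b_fp8, a_scale, b_scale, out_dtype=torch.float32)
print(out.shape)  # torch.Size([4, 16])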


@torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
@@ -127,7 +157,6 @@ def scaled_paged_attn_store(
Scales key and value tensors, and stores them to the paged KV cache
using the same schema as vLLM.
"""
print("Should never hit")
result_key_cache = key_cache.clone()
result_value_cache = value_cache.clone()
for seq_i, slot_mapping_seq in enumerate(slot_mapping):
@@ -136,10 +165,10 @@
position = slot.item() % 64

result_key_cache[block_number, position, :, :] = (
key[seq_i, tok_i, :, :] / key_scale
key[seq_i, tok_i, :, :] / key_scale[seq_i]
).to(dtype=torch.float8_e4m3fn)
result_value_cache[block_number, position, :, :] = (
value[seq_i, tok_i, :, :] / value_scale
value[seq_i, tok_i, :, :] / value_scale[seq_i]
).to(dtype=torch.float8_e4m3fn)
return result_key_cache, result_value_cache

@@ -179,7 +208,6 @@ def scaled_paged_attn_compute(
Implements a CPU fallback to run the kernel that has been confirmed
to match the vLLM fused kernel.
"""
print("Should never hit")
# torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype),
output = torch.zeros_like(query)
num_query_heads = query.shape[2]
@@ -220,10 +248,10 @@

out = F.scaled_dot_product_attention( # noqa: E1102
q.transpose(0, 1).unsqueeze(0), # format for sdpa
(keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale).to(
(keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
dtype=q.dtype
), # format for sdpa
(values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale).to(
(values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
dtype=q.dtype
), # format for sdpa
is_causal=False, # decode assumes no causal mask
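The store/compute fallbacks above now index key_scale[seq_i] and value_scale[i], i.e. one scale per sequence in the batch. A toy rendering of the per-sequence paged store loop follows: the % 64 position arithmetic appears in the diff, while the // 64 block index, all sizes, and the 448.0 range are assumptions for the example.

import torch

BLOCK_SIZE = 64                      # assumed to match the `% 64` in the diff
NUM_BLOCKS, KV_HEADS, HEAD_DIM = 8, 2, 16

key = torch.randn(2, 4, KV_HEADS, HEAD_DIM)        # (batch, tokens, kv_heads, head_dim)
key_scale = key.abs().amax(dim=(1, 2, 3)) / 448.0  # one scale per sequence
key_cache = torch.zeros(
    NUM_BLOCKS, BLOCK_SIZE, KV_HEADS, HEAD_DIM, dtype=torch.float8_e4m3fn
)
slot_mapping = torch.arange(8).reshape(2, 4)       # toy slot ids, one row per sequence

for seq_i, slots in enumerate(slot_mapping):
    for tok_i, slot in enumerate(slots):
        block, pos = slot.item() // BLOCK_SIZE, slot.item() % BLOCK_SIZE
        # each sequence is quantized with its own scale, matching key_scale[seq_i] above
        key_cache[block, pos] = (key[seq_i, tok_i] / key_scale[seq_i]).to(
            torch.float8_e4m3fn
        )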