mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 04:54:56 +08:00
VLLM_USE_TRITON_FLASH_ATTN V0 variable deprecation (#27611)
Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Andreas Karatzas <Andreas.Karatzas@amd.com>
This commit is contained in:
parent
7f829be7d3
commit
9f0247cfa4
@ -78,17 +78,13 @@ HF_MOUNT="/root/.cache/huggingface"
|
||||
commands=$@
|
||||
echo "Commands:$commands"
|
||||
|
||||
if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
|
||||
fi
|
||||
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"pytest -v -s basic_correctness/test_basic_correctness.py"}
|
||||
|
||||
if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
|
||||
fi
|
||||
|
||||
if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then
|
||||
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
|
||||
fi
|
||||
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"pytest -v -s compile/test_basic_correctness.py"}
|
||||
|
||||
if [[ $commands == *"pytest -v -s lora"* ]]; then
|
||||
commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"}
|
||||
|
||||
@ -1,516 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the triton_flash_attention kernel
|
||||
|
||||
Run `pytest tests/kernels/test_triton_flash_attention.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.triton_flash_attention import (
|
||||
SUPPORTED_LAYOUTS,
|
||||
MetaData,
|
||||
compute_alibi_tensor,
|
||||
scale_fp8,
|
||||
triton_attention_rocm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class ReferenceAttention:
|
||||
def __init__(
|
||||
self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata
|
||||
):
|
||||
self.Z = Z
|
||||
self.HQ = HQ
|
||||
self.HK = HK
|
||||
self.N_CTX_Q = N_CTX_Q
|
||||
self.N_CTX_K = N_CTX_K
|
||||
self.D_HEAD = D_HEAD
|
||||
self.use_alibi = use_alibi
|
||||
self.dtype = dtype
|
||||
self.input_metadata = input_metadata
|
||||
|
||||
def fwd(self, q, k, v):
|
||||
scores = (
|
||||
torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale
|
||||
)
|
||||
if self.input_metadata.causal:
|
||||
mask = torch.tril(
|
||||
torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"),
|
||||
diagonal=self.N_CTX_K - self.N_CTX_Q,
|
||||
)
|
||||
scores[:, :, mask == 0] = float("-inf")
|
||||
|
||||
if self.input_metadata.bias is not None:
|
||||
scores += self.input_metadata.bias
|
||||
|
||||
if self.use_alibi:
|
||||
scores += compute_alibi_tensor(
|
||||
self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K
|
||||
)
|
||||
|
||||
p = torch.softmax(scores, dim=-1)
|
||||
if self.input_metadata.causal:
|
||||
# If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going
|
||||
# into softmax. This creates a row of NaNs as -inf - -inf == NaN.
|
||||
# So we fix this by converting the NaNs to 0s, which is what they
|
||||
# should be out of the softmax.
|
||||
nan_mask = torch.isnan(p)
|
||||
p[nan_mask == 1] = 0
|
||||
ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v)
|
||||
# compare
|
||||
if self.input_metadata.layout == "bshd":
|
||||
ref_out = ref_out.transpose(1, 2).clone()
|
||||
return ref_out
|
||||
|
||||
def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
|
||||
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
|
||||
self.dtype
|
||||
)
|
||||
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
|
||||
self.dtype
|
||||
)
|
||||
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
|
||||
self.dtype
|
||||
)
|
||||
result = self.fwd(q, k, v)
|
||||
if self.input_metadata.o_scale is not None:
|
||||
result, _ = scale_fp8(result, self.input_metadata.o_scale)
|
||||
return result
|
||||
|
||||
def fwd_fp8_kv(self, q, k_quantized, v_quantized):
|
||||
k_descale, v_descale = (
|
||||
self.input_metadata.k_descale,
|
||||
self.input_metadata.v_descale,
|
||||
)
|
||||
k_dequantized = (
|
||||
k_quantized.to(torch.float32) * k_descale.to(torch.float32)
|
||||
).to(self.dtype)
|
||||
v_dequantized = (
|
||||
v_quantized.to(torch.float32) * v_descale.to(torch.float32)
|
||||
).to(self.dtype)
|
||||
return self.fwd(q, k_dequantized, v_dequantized)
|
||||
|
||||
def varlen_fwd(self, q, k, v, is_mqa=False):
|
||||
ref_out = torch.empty_like(q)
|
||||
if is_mqa:
|
||||
# Make KV look like HQ/HK "groups" of HK. Later, we will reshape so
|
||||
# the size aligns with Q.
|
||||
k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(
|
||||
-1, -1, self.HQ // self.HK, -1
|
||||
)
|
||||
v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(
|
||||
-1, -1, self.HQ // self.HK, -1
|
||||
)
|
||||
else:
|
||||
k_ref = k
|
||||
v_ref = v
|
||||
|
||||
for i in range(0, self.input_metadata.num_contexts):
|
||||
start_q, start_k = (
|
||||
self.input_metadata.cu_seqlens_q[i],
|
||||
self.input_metadata.cu_seqlens_k[i],
|
||||
)
|
||||
end_q, end_k = (
|
||||
self.input_metadata.cu_seqlens_q[i + 1],
|
||||
self.input_metadata.cu_seqlens_k[i + 1],
|
||||
)
|
||||
k_curr = k_ref[start_k:end_k]
|
||||
v_curr = v_ref[start_k:end_k]
|
||||
if is_mqa:
|
||||
k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3])
|
||||
v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3])
|
||||
scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float()
|
||||
p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half()
|
||||
ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr)
|
||||
return ref_out
|
||||
|
||||
|
||||
def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False):
|
||||
q_descale = None
|
||||
if not fp8_kv:
|
||||
q, q_descale = scale_fp8(q)
|
||||
k, k_descale = scale_fp8(k)
|
||||
v, v_descale = scale_fp8(v)
|
||||
|
||||
# In real world use case, the p scale would be a parameter trained by the
|
||||
# model.
|
||||
p_scale = None
|
||||
|
||||
o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None
|
||||
|
||||
return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale
|
||||
|
||||
|
||||
def input_helper(
|
||||
Z,
|
||||
HQ,
|
||||
HK,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
layout=None,
|
||||
use_alibi=None,
|
||||
causal=None,
|
||||
is_fp8=False,
|
||||
fp8_kv=False,
|
||||
use_o_scale=False,
|
||||
use_bias=False,
|
||||
):
|
||||
assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# Initialize q, k, v
|
||||
if layout == "bhsd":
|
||||
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
|
||||
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
|
||||
elif layout == "bshd":
|
||||
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
|
||||
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
|
||||
|
||||
if use_alibi:
|
||||
# for n heads the set of slopes is the geometric sequence that starts
|
||||
# 2^(-8/n)
|
||||
alibi_slopes = torch.tensor(
|
||||
[2 ** (-8 / HQ * i) for i in range(1, HQ + 1)],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
).repeat(Z, 1)
|
||||
else:
|
||||
alibi_slopes = None
|
||||
|
||||
if use_bias:
|
||||
bias = torch.randn(
|
||||
(1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False
|
||||
)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
|
||||
k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
|
||||
v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
|
||||
|
||||
if is_fp8:
|
||||
(q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale) = quantize_input(
|
||||
q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv
|
||||
)
|
||||
else:
|
||||
q_descale = k_descale = v_descale = p_scale = o_scale = None
|
||||
|
||||
input_metadata = MetaData(
|
||||
sm_scale=D_HEAD**-0.5,
|
||||
max_seqlens_q=N_CTX_Q,
|
||||
max_seqlens_k=N_CTX_K,
|
||||
layout=layout,
|
||||
alibi_slopes=alibi_slopes,
|
||||
alibi_batch=Z,
|
||||
alibi_nheads=HQ,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
p_scale=p_scale,
|
||||
o_scale=o_scale,
|
||||
bias=bias,
|
||||
seqlen_q=N_CTX_Q,
|
||||
seqlen_k=N_CTX_K,
|
||||
)
|
||||
return q, k, v, input_metadata
|
||||
|
||||
|
||||
def varlen_input_helper(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# Random sequence lengths. Using N_CTX as kind of max of sum of individual
|
||||
# seqs
|
||||
if not equal_seqlens:
|
||||
max_seqlens_q = N_CTX_Q // Z
|
||||
max_seqlens_k = N_CTX_K // Z
|
||||
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
|
||||
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
|
||||
else:
|
||||
seqlens_q = torch.full((Z,), N_CTX_Q // Z)
|
||||
seqlens_k = torch.full((Z,), N_CTX_K // Z)
|
||||
|
||||
# Calculate cumulative sequence lengths
|
||||
cu_seqlens_q = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
seqlens_q.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
cu_seqlens_k = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
seqlens_k.cumsum(dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
|
||||
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
|
||||
|
||||
# Initialize q, k, v with variable lengths
|
||||
total_q = cu_seqlens_q[-1].item()
|
||||
total_k = cu_seqlens_k[-1].item()
|
||||
q = (
|
||||
torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
k = (
|
||||
torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
v = (
|
||||
torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda")
|
||||
.normal_(mean=0.0, std=0.5)
|
||||
.requires_grad_()
|
||||
)
|
||||
sm_scale = D_HEAD**-0.5
|
||||
input_metadata = MetaData(sm_scale=sm_scale)
|
||||
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
|
||||
return q, k, v, input_metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(1, 48, 12, 1, 1, 64),
|
||||
(4, 4, 4, 128, 128, 65),
|
||||
(16, 48, 48, 1, 1, 128),
|
||||
(64, 48, 24, 3, 3, 128),
|
||||
(4, 4, 4, 113, 123, 1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("use_alibi", [True, False])
|
||||
@pytest.mark.parametrize("layout", ["bshd"])
|
||||
def test_op_fwd(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
q, k, v, input_metadata = input_helper(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal
|
||||
)
|
||||
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# triton implementation
|
||||
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
|
||||
|
||||
# Transpose here if layout is bshd so we have same reference code for all
|
||||
# layouts
|
||||
if layout == "bshd":
|
||||
q = q.transpose(1, 2).clone()
|
||||
k = k.transpose(1, 2).clone()
|
||||
v = v.transpose(1, 2).clone()
|
||||
# Replicate K and V if using MQA/GQA
|
||||
if HQ != HK:
|
||||
k = (
|
||||
k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3])
|
||||
.expand(-1, -1, HQ // HK, -1, -1)
|
||||
.reshape(k.shape[0], -1, k.shape[2], k.shape[3])
|
||||
)
|
||||
v = (
|
||||
v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3])
|
||||
.expand(-1, -1, HQ // HK, -1, -1)
|
||||
.reshape(v.shape[0], -1, v.shape[2], v.shape[3])
|
||||
)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd(q, k, v)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(4, 48, 1, 1, 64),
|
||||
(4, 48, 1, 1, 128),
|
||||
(4, 48, 3, 3, 128),
|
||||
(4, 4, 128, 128, 65),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("layout", ["bhsd"])
|
||||
@pytest.mark.parametrize("use_o_scale", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="Triton FP8 requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_op_fwd_fp8(
|
||||
Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# Disable grad to save memory it won't run into OOM on CI machine.
|
||||
# q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD,
|
||||
# dtype, layout)
|
||||
|
||||
q_quantized, k_quantized, v_quantized, input_metadata = input_helper(
|
||||
Z,
|
||||
H,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
causal=causal,
|
||||
layout=layout,
|
||||
is_fp8=True,
|
||||
use_o_scale=use_o_scale,
|
||||
)
|
||||
|
||||
o = torch.empty_like(q_quantized) if use_o_scale else None
|
||||
|
||||
tri_out, _ = triton_attention_rocm(
|
||||
q_quantized, k_quantized, v_quantized, o, input_metadata
|
||||
)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized)
|
||||
|
||||
# compare
|
||||
torch.testing.assert_close(
|
||||
ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(4, 48, 1, 1, 64),
|
||||
(4, 48, 1, 1, 128),
|
||||
(4, 48, 3, 3, 128),
|
||||
(4, 4, 128, 128, 65),
|
||||
(4, 4, 113, 123, 1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("layout", ["bhsd"])
|
||||
def test_op_fwd_fp8_kv(
|
||||
Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
q, k_quantized, v_quantized, input_metadata = input_helper(
|
||||
Z,
|
||||
H,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
causal=causal,
|
||||
layout=layout,
|
||||
is_fp8=True,
|
||||
fp8_kv=True,
|
||||
)
|
||||
|
||||
o = torch.empty_like(q)
|
||||
|
||||
tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
|
||||
[
|
||||
(4, 48, 1, 1, 64),
|
||||
(4, 48, 1, 1, 128),
|
||||
(4, 48, 3, 3, 128),
|
||||
(4, 4, 128, 128, 65),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
|
||||
current_platform.seed_everything(0)
|
||||
q, k, v, input_metadata = input_helper(
|
||||
Z,
|
||||
H,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
D_HEAD,
|
||||
dtype,
|
||||
layout="bhsd",
|
||||
causal=causal,
|
||||
use_bias=use_bias,
|
||||
)
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# triton implementation
|
||||
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.fwd(q, k, v)
|
||||
|
||||
# compare
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
# NOTE: Uses thd layout, so also tests thd.
|
||||
@pytest.mark.parametrize(
|
||||
"Z, H, N_CTX, D_HEAD",
|
||||
[(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [True, False])
|
||||
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype)
|
||||
|
||||
tri_out = torch.empty_like(q)
|
||||
triton_attention_rocm(q, k, v, tri_out, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
# NOTE: Uses thd layout, so also tests thd.
|
||||
@pytest.mark.parametrize(
|
||||
"Z, HQ, HK, N_CTX, D_HEAD",
|
||||
[
|
||||
(2, 48, 24, 128, 64),
|
||||
(4, 48, 12, 256, 64),
|
||||
(4, 48, 4, 512, 64),
|
||||
(4, 64, 16, 128, 128),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("causal", [False])
|
||||
def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16):
|
||||
q, k, v, input_metadata = varlen_input_helper(
|
||||
Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype
|
||||
)
|
||||
|
||||
tri_out = torch.empty_like(q)
|
||||
triton_attention_rocm(q, k, v, tri_out, input_metadata)
|
||||
|
||||
ref_impl = ReferenceAttention(
|
||||
Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata
|
||||
)
|
||||
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True)
|
||||
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
|
||||
@ -27,13 +27,7 @@ def test_models(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.classify(example_prompts)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import check_embeddings_close
|
||||
|
||||
@ -51,13 +50,7 @@ def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
vllm_extra_kwargs = {}
|
||||
if model == "ssmits/Qwen2-7B-Instruct-embed-base":
|
||||
vllm_extra_kwargs["pooler_config"] = PoolerConfig(
|
||||
|
||||
@ -2,18 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def test_idefics_multimodal(
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -59,13 +52,7 @@ def update_config(config):
|
||||
|
||||
def test_gemma_multimodal(
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
|
||||
@ -76,7 +76,6 @@ def test_prm_models(
|
||||
math_step_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
check_transformers_version(
|
||||
"Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2"
|
||||
@ -85,11 +84,6 @@ def test_prm_models(
|
||||
if current_platform.is_cpu():
|
||||
pytest.skip("CPU only supports V1")
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||
|
||||
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.reward(math_step_prompts)
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ image, embedding, and video support for different VLMs in vLLM.
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import PosixPath
|
||||
|
||||
@ -38,13 +37,6 @@ from .vlm_utils.types import (
|
||||
VLMTestType,
|
||||
)
|
||||
|
||||
# This hack is needed for phi3v & paligemma models
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
COMMON_BROADCAST_SETTINGS = {
|
||||
"test_type": VLMTestType.IMAGE,
|
||||
"dtype": "half",
|
||||
|
||||
@ -11,7 +11,6 @@ from huggingface_hub import snapshot_download
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import (
|
||||
IMAGE_ASSETS,
|
||||
@ -46,12 +45,6 @@ models = [model_path]
|
||||
|
||||
target_dtype = "half"
|
||||
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
|
||||
@ -14,7 +14,6 @@ from vllm.assets.image import ImageAsset
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.image import convert_image_mode, rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import (
|
||||
IMAGE_ASSETS,
|
||||
@ -68,12 +67,6 @@ def vllm_to_hf_output(
|
||||
|
||||
target_dtype = "half"
|
||||
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
|
||||
@ -8,7 +8,6 @@ See also `tests/kernels/moe/test_ocp_mx_moe.py`.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
|
||||
@ -246,8 +245,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
|
||||
task = "gsm8k"
|
||||
rtol = 0.03
|
||||
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=config.get_model_args(tp_size=8, model_max_len=38768),
|
||||
@ -263,8 +260,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
|
||||
and measured_value + rtol > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
del os.environ["VLLM_USE_TRITON_FLASH_ATTN"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
|
||||
|
||||
@ -1,932 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||
|
||||
Features supported:
|
||||
|
||||
1) Fwd with causal masking
|
||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||
3) Support for different sequence lengths for q and k
|
||||
4) Nested tensor API currently does not support dropout or bias.
|
||||
|
||||
Not currently supported:
|
||||
|
||||
1) Non power of two head dims
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
# Avoid misleading ROCm warning.
|
||||
if current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx1x
|
||||
else:
|
||||
on_gfx1x = lambda *args, **kwargs: False
|
||||
|
||||
torch_dtype: tl.constexpr = torch.float16
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def max_fn(x, y):
|
||||
return tl.math.max(x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
ms = tl.arange(0, m)
|
||||
ns = tl.arange(0, n)
|
||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(
|
||||
philox_seed, philox_offset, dropout_p, m, n, stride
|
||||
).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
|
||||
@triton.jit
|
||||
def load_fn(block_ptr, first, second, pad):
|
||||
if first and second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||
elif first:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
|
||||
elif second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
|
||||
else:
|
||||
tensor = tl.load(block_ptr)
|
||||
return tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
actual_seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
OFFS_M: tl.constexpr,
|
||||
OFFS_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
MASK_STEPS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
PADDED_HEAD: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
):
|
||||
# loop over k, v, and update accumulator
|
||||
for start_n in range(block_min, block_max, BLOCK_N):
|
||||
# For padded blocks, we will overrun the tensor size if
|
||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||
k = load_fn(
|
||||
K_block_ptr,
|
||||
PADDED_HEAD,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
"zero",
|
||||
)
|
||||
if PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
# We start from end of seqlen_k so only the first iteration would need
|
||||
# to be checked for padding if it is not a multiple of block_n
|
||||
# TODO: This can be optimized to only be true for the padded block.
|
||||
if MASK_STEPS: # noqa: SIM102
|
||||
# If this is the last block / iteration, we want to
|
||||
# mask if the sequence length is not a multiple of block size
|
||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||
# check if this masking works for that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
causal_boundary = start_n + offs_n_causal
|
||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
if USE_FP8:
|
||||
qk *= qk_scale
|
||||
if bias_ptr is not None:
|
||||
bias = load_fn(
|
||||
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
|
||||
)
|
||||
# While bias is added after multiplying qk with sm_scale, our
|
||||
# optimization to use 2^x instead of e^x results in an additional
|
||||
# scale factor of log2(e) which we must also multiply the bias with.
|
||||
qk += bias * 1.44269504089
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk = qk - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = (
|
||||
batch_philox_offset
|
||||
+ start_m * BLOCK_M * actual_seqlen_k
|
||||
+ start_n
|
||||
- BLOCK_N
|
||||
)
|
||||
keep = dropout_mask(
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
dropout_p,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
actual_seqlen_k,
|
||||
)
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
# -- update output accumulator --
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
if not PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
# -- update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
# update m_i and l_i
|
||||
m_i = m_ij
|
||||
|
||||
if USE_FP8:
|
||||
p *= p_descale
|
||||
|
||||
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(
|
||||
encoded_softmax_block_ptr, (0, BLOCK_N)
|
||||
)
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
def get_cdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": True},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# "BLOCK_M": 16,
|
||||
# "BLOCK_N": 16,
|
||||
# "waves_per_eu": 1,
|
||||
# "PRE_LOAD_V": False,
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"]
|
||||
|
||||
|
||||
def get_rdna_autotune_configs():
|
||||
return [
|
||||
triton.Config(
|
||||
{"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=2,
|
||||
),
|
||||
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 4,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 2,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
# # Fall-back config.
|
||||
# triton.Config(
|
||||
# {
|
||||
# 'BLOCK_M': 16,
|
||||
# 'BLOCK_N': 16,
|
||||
# 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False
|
||||
# },
|
||||
# num_stages=1,
|
||||
# num_warps=2),
|
||||
], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"]
|
||||
|
||||
|
||||
def get_autotune_configs():
|
||||
if on_gfx1x():
|
||||
return get_rdna_autotune_configs()
|
||||
else:
|
||||
return get_cdna_autotune_configs()
|
||||
|
||||
|
||||
autotune_configs, autotune_keys = get_autotune_configs()
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=autotune_configs,
|
||||
key=autotune_keys,
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz: tl.int64,
|
||||
stride_qh: tl.int64,
|
||||
stride_qm: tl.int64,
|
||||
stride_qk: tl.int64,
|
||||
stride_kz: tl.int64,
|
||||
stride_kh: tl.int64,
|
||||
stride_kn: tl.int64,
|
||||
stride_kk: tl.int64,
|
||||
stride_vz: tl.int64,
|
||||
stride_vh: tl.int64,
|
||||
stride_vk: tl.int64,
|
||||
stride_vn: tl.int64,
|
||||
stride_oz: tl.int64,
|
||||
stride_oh: tl.int64,
|
||||
stride_om: tl.int64,
|
||||
stride_on: tl.int64,
|
||||
stride_bz: tl.int64,
|
||||
stride_bh: tl.int64,
|
||||
stride_bm: tl.int64,
|
||||
stride_bn: tl.int64,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
VARLEN: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
USE_FP8_OUT: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
BIAS_TYPE: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_h_q = tl.program_id(1)
|
||||
off_z = tl.program_id(2)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
if VARLEN:
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||
# small for all start_m so for those we return early.
|
||||
if start_m * BLOCK_M > seqlen_q:
|
||||
return
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
else:
|
||||
cu_seqlens_q_start = 0
|
||||
cu_seqlens_k_start = 0
|
||||
seqlen_q = MAX_SEQLENS_Q
|
||||
seqlen_k = MAX_SEQLENS_K
|
||||
|
||||
# Now we compute whether we need to exit early due to causal masking.
|
||||
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||
# are completely masked, resulting in 0s written to the output, and
|
||||
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||
# This block of code determines what N is, and if this WG is operating
|
||||
# on those M rows.
|
||||
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||
if IS_CAUSAL:
|
||||
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||
# the causal mask boundary is bottom right aligned, and ends at either
|
||||
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||
# matrix
|
||||
n_blocks_seqlen = cdiv_fn(
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
|
||||
)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||
# part of the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = (
|
||||
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||
)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||
# We still need to write 0s to the result
|
||||
# tl.store(O_block_ptr,
|
||||
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# + offs_m
|
||||
# We store inf to LSE, not -inf because in the bwd pass,
|
||||
# we subtract this
|
||||
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||
# for these masked blocks.
|
||||
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||
# tl.store(l_ptrs, l)
|
||||
# TODO: Should dropout and return encoded softmax be handled here?
|
||||
return
|
||||
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if BIAS_TYPE != 0:
|
||||
bias_ptr = tl.make_block_ptr(
|
||||
base=bias + off_h_q * stride_bh,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(stride_bm, stride_bn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = (
|
||||
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
|
||||
)
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# We can ask to return the dropout mask without actually doing any dropout.
|
||||
# In this case, we return an invalid pointer so indicate the mask is not i
|
||||
# valid.
|
||||
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(seqlen_k, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
encoded_softmax_block_ptr = 0
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||
# have native e^x support in HW.
|
||||
qk_scale = sm_scale * 1.44269504089
|
||||
# Q is loaded once at the beginning and shared by all N blocks.
|
||||
q = load_fn(Q_block_ptr, True, padded_head, "zero")
|
||||
if not USE_FP8:
|
||||
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||
acc_scale = 1.0
|
||||
else:
|
||||
qk_scale *= q_scale * k_scale
|
||||
acc_scale = p_scale * v_scale
|
||||
|
||||
# Here we compute how many full and masked blocks we have.
|
||||
padded_block_k = n_extra_tokens != 0
|
||||
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||
if IS_CAUSAL:
|
||||
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||
# Additionally there might be one more due to dissimilar seqlens.
|
||||
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||
else:
|
||||
# Padding on Q does not need to be masked in the FA loop.
|
||||
masked_blocks = padded_block_k
|
||||
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||
# block. In this case we might exceed n_blocks so pick the min.
|
||||
masked_blocks = min(masked_blocks, n_blocks)
|
||||
n_full_blocks = n_blocks - masked_blocks
|
||||
block_min = 0
|
||||
block_max = n_blocks * BLOCK_N
|
||||
# Compute for full blocks. Here we set causal to false regardless of its
|
||||
# value because there is no masking. Similarly we do not need padding.
|
||||
if n_full_blocks > 0:
|
||||
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||
block_min,
|
||||
block_max,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
bias_ptr,
|
||||
# IS_CAUSAL, ....
|
||||
False,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
False,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
USE_FP8,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
)
|
||||
block_min = block_max
|
||||
block_max = n_blocks * BLOCK_N
|
||||
|
||||
tl.debug_barrier()
|
||||
# Remaining blocks, if any, are full / not masked.
|
||||
if masked_blocks > 0:
|
||||
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(
|
||||
encoded_softmax_block_ptr, (0, n_full_blocks)
|
||||
)
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
True,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
padded_head,
|
||||
USE_FP8,
|
||||
qk_scale,
|
||||
p_descale,
|
||||
)
|
||||
# epilogue
|
||||
|
||||
if USE_FP8:
|
||||
acc *= acc_scale
|
||||
acc = acc / l_i[:, None]
|
||||
if ENABLE_DROPOUT:
|
||||
acc = acc / (1 - dropout_p)
|
||||
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||
# then we have one block with a row of all NaNs which come from computing
|
||||
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||
end_m_idx = (start_m + 1) * BLOCK_M
|
||||
start_m_idx = start_m * BLOCK_M
|
||||
causal_start_idx = seqlen_q - seqlen_k
|
||||
if USE_FP8_OUT:
|
||||
acc *= o_descale
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL: # noqa: SIM102
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full(
|
||||
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
|
||||
)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
|
||||
z = tl.zeros((1,), tl.float32)
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||
# few rows. This is only true for the last M block. For others,
|
||||
# overflow_size will be -ve
|
||||
# overflow_size = end_m_idx - seqlen_q
|
||||
# if overflow_size > 0:
|
||||
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||
# # This is a > check because mask being 0 blocks the store.
|
||||
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||
# else:
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
|
||||
# write back O
|
||||
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# Need boundary check on this to make sure the padding from the
|
||||
# Q and KV tensors in both dims are not part of what we store back.
|
||||
# TODO: Do the boundary check optionally.
|
||||
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
max_seqlens=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
):
|
||||
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||
if varlen:
|
||||
assert q.dim() == 3
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||
else:
|
||||
assert q.dim() == 4
|
||||
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||
_, nheads_k, seqlen_k, _ = k.shape
|
||||
assert max_seqlens > 0
|
||||
assert k.shape == v.shape
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
assert head_size <= 256
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k,
|
||||
causal=False,
|
||||
sm_scale=1.0,
|
||||
bias=None,
|
||||
fp8_scales=None,
|
||||
fp8_out_scale=None,
|
||||
):
|
||||
if fp8_scales is not None:
|
||||
use_fp8 = True
|
||||
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
|
||||
float8 = current_platform.fp8_dtype()
|
||||
|
||||
def check_and_convert(t, scale):
|
||||
if t.dtype != float8:
|
||||
descale = 1.0 / scale
|
||||
ts = (t * descale).clamp(min=float8_info.min, max=float8_info.max)
|
||||
return ts.to(float8)
|
||||
else:
|
||||
return t
|
||||
|
||||
q = check_and_convert(q, q_scale)
|
||||
k = check_and_convert(k, k_scale)
|
||||
v = check_and_convert(v, v_scale)
|
||||
else:
|
||||
use_fp8 = False
|
||||
q_scale = k_scale = v_scale = p_scale = 1.0
|
||||
|
||||
if o is None:
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
|
||||
check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
)
|
||||
if True: # varlen
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
batch = len(cu_seqlens_q) - 1
|
||||
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||
else:
|
||||
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
|
||||
# Get closest power of 2 over or equal to 32.
|
||||
unpadded_head_dims = {32, 64, 128, 256}
|
||||
if head_size not in unpadded_head_dims:
|
||||
padded_d_model = None
|
||||
for i in unpadded_head_dims:
|
||||
if i > head_size:
|
||||
padded_d_model = i
|
||||
break
|
||||
assert padded_d_model is not None
|
||||
else:
|
||||
padded_d_model = head_size
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
|
||||
nheads_q,
|
||||
batch,
|
||||
)
|
||||
|
||||
encoded_softmax = None
|
||||
|
||||
# Seed the RNG so we get reproducible results for testing.
|
||||
philox_seed = 0x1BF52
|
||||
philox_offset = 0x1D4B42
|
||||
|
||||
if bias is not None:
|
||||
bias_strides = (
|
||||
bias.stride(0),
|
||||
bias.stride(1),
|
||||
bias.stride(2),
|
||||
bias.stride(3),
|
||||
)
|
||||
else:
|
||||
bias_strides = (0, 0, 0, 0)
|
||||
|
||||
p_descale = 1.0 / p_scale
|
||||
o_descale = 1.0 / fp8_out_scale.item() if fp8_out_scale is not None else 1.0
|
||||
|
||||
arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
|
||||
arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k
|
||||
|
||||
attn_fwd[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
sm_scale,
|
||||
q_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
p_scale,
|
||||
p_descale,
|
||||
o_descale,
|
||||
None,
|
||||
o,
|
||||
*q_strides,
|
||||
*k_strides,
|
||||
*v_strides,
|
||||
*o_strides,
|
||||
*bias_strides,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p=0.0,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset_base=philox_offset,
|
||||
encoded_softmax=encoded_softmax,
|
||||
HQ=nheads_q,
|
||||
HK=nheads_k,
|
||||
ACTUAL_BLOCK_DMODEL=head_size,
|
||||
MAX_SEQLENS_Q=arg_max_seqlens_q,
|
||||
MAX_SEQLENS_K=arg_max_seqlens_k,
|
||||
IS_CAUSAL=causal,
|
||||
VARLEN=True,
|
||||
BLOCK_DMODEL=padded_d_model,
|
||||
BIAS_TYPE=0 if bias is None else 1,
|
||||
ENABLE_DROPOUT=False,
|
||||
RETURN_ENCODED_SOFTMAX=False,
|
||||
USE_FP8=use_fp8,
|
||||
USE_FP8_OUT=fp8_out_scale is not None,
|
||||
)
|
||||
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = head_size
|
||||
ctx.causal = causal
|
||||
ctx.dropout_p = 0.0
|
||||
ctx.philox_seed = philox_seed
|
||||
ctx.philox_offset = philox_offset
|
||||
ctx.encoded_softmax = encoded_softmax
|
||||
ctx.return_encoded_softmax = False
|
||||
return o, encoded_softmax
|
||||
|
||||
|
||||
triton_attention = _attention.apply
|
||||
@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
|
||||
VLLM_NCCL_SO_PATH: str | None = None
|
||||
LD_LIBRARY_PATH: str | None = None
|
||||
VLLM_USE_TRITON_FLASH_ATTN: bool = True
|
||||
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
|
||||
VLLM_FLASH_ATTN_VERSION: int | None = None
|
||||
LOCAL_RANK: int = 0
|
||||
@ -521,10 +520,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
|
||||
# library file in the locations specified by `LD_LIBRARY_PATH`
|
||||
"LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None),
|
||||
# flag to control if vllm should use triton flash attention
|
||||
"VLLM_USE_TRITON_FLASH_ATTN": lambda: (
|
||||
os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Use separate prefill and decode kernels for V1 attention instead of
|
||||
# the unified triton kernel.
|
||||
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: (
|
||||
@ -1554,7 +1549,6 @@ def compute_hash() -> str:
|
||||
"VLLM_PP_LAYER_PARTITION",
|
||||
"VLLM_MLA_DISABLE",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
|
||||
"VLLM_USE_TRITON_FLASH_ATTN",
|
||||
"VLLM_USE_TRITON_AWQ",
|
||||
"VLLM_DP_RANK",
|
||||
"VLLM_DP_SIZE",
|
||||
|
||||
@ -49,25 +49,8 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = []
|
||||
|
||||
# Models partially supported by ROCm.
|
||||
# Architecture -> Reason.
|
||||
_ROCM_SWA_REASON = (
|
||||
"Sliding window attention (SWA) is not yet supported in "
|
||||
"Triton flash attention. For half-precision SWA support, "
|
||||
"please use CK flash attention by setting "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
|
||||
)
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
|
||||
"Qwen2ForCausalLM": _ROCM_SWA_REASON,
|
||||
"MistralForCausalLM": _ROCM_SWA_REASON,
|
||||
"MixtralForCausalLM": _ROCM_SWA_REASON,
|
||||
"PaliGemmaForConditionalGeneration": (
|
||||
"ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
|
||||
),
|
||||
"Phi3VForCausalLM": (
|
||||
"ROCm Triton flash attention may run into compilation errors due to "
|
||||
"excessive use of shared memory. If this happens, disable Triton FA "
|
||||
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
|
||||
),
|
||||
}
|
||||
_ROCM_SWA_REASON = ()
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
|
||||
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
|
||||
"0x74a0": "AMD_Instinct_MI300A",
|
||||
"0x74a1": "AMD_Instinct_MI300X",
|
||||
|
||||
@ -37,7 +37,6 @@ _GLOBAL_RUNTIME_DATA = dict[str, str | int | bool]()
|
||||
|
||||
_USAGE_ENV_VARS_TO_COLLECT = [
|
||||
"VLLM_USE_MODELSCOPE",
|
||||
"VLLM_USE_TRITON_FLASH_ATTN",
|
||||
"VLLM_ATTENTION_BACKEND",
|
||||
"VLLM_USE_FLASHINFER_SAMPLER",
|
||||
"VLLM_PP_LAYER_PARTITION",
|
||||
|
||||
@ -5,22 +5,18 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
@ -99,54 +95,17 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||
self.triton_fa_func = triton_attention if HAS_TRITON else None
|
||||
|
||||
def _flash_attn_varlen_diff_headdims_rocm(
|
||||
self, q, k, v, softmax_scale=None, **kwargs
|
||||
):
|
||||
assert self.triton_fa_func is not None
|
||||
|
||||
# Triton Attention requires a padded V
|
||||
padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
|
||||
# The output of triton_attention is a tuple of
|
||||
# [output_tensor, encoded_softmax] where encoded_softmax is always None
|
||||
output_tensor, _ = self.triton_fa_func(
|
||||
q,
|
||||
k,
|
||||
padded_v,
|
||||
None, # output
|
||||
kwargs["cu_seqlens_q"],
|
||||
kwargs["cu_seqlens_k"],
|
||||
kwargs["max_seqlen_q"],
|
||||
kwargs["max_seqlen_k"],
|
||||
kwargs["causal"],
|
||||
softmax_scale,
|
||||
None, # bias
|
||||
)
|
||||
|
||||
return output_tensor
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and self.use_triton_flash_attn
|
||||
and not return_softmax_lse
|
||||
):
|
||||
return self._flash_attn_varlen_diff_headdims_rocm(
|
||||
q, k, v, softmax_scale=softmax_scale, **kwargs
|
||||
)
|
||||
else:
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user