mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-11 05:07:02 +08:00
[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
37e8182bfe
commit
9a161307f5
@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None
|
||||
@pytest.mark.parametrize(
|
||||
"model, quant_key",
|
||||
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
|
||||
@pytest.mark.parametrize(
|
||||
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
|
||||
@pytest.mark.parametrize("use_triton_fa", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test CUDA and ROCm")
|
||||
def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
quant_key: QuantKey, use_triton_fa: bool):
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="V0 attn quant fusion only on ROCm")
|
||||
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
|
||||
quant_key: QuantKey, use_triton_fa: bool):
|
||||
# Clean Dynamo cache to avoid reusing other test cases
|
||||
# (for some reason the reset at the end is not enough)
|
||||
torch._dynamo.reset()
|
||||
@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
backend="tests.compile.test_fusion_attn.backend_unfused",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
))
|
||||
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
|
||||
|
||||
llm = LLM(model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.9,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
backend="tests.compile.test_fusion_attn.backend",
|
||||
custom_ops=["+quant_fp8"],
|
||||
)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config)
|
||||
vllm_config = VllmConfig(compilation_config=compile_config,
|
||||
model_config=ModelConfig(
|
||||
model=model,
|
||||
dtype=torch.bfloat16,
|
||||
))
|
||||
|
||||
# AttnFusionPass needs attention layers to be registered in config upon init
|
||||
# so we initialize it during compilation.
|
||||
@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
llm2 = LLM(model,
|
||||
enforce_eager=True,
|
||||
compilation_config=compile_config,
|
||||
gpu_memory_utilization=0.9,
|
||||
gpu_memory_utilization=0.5,
|
||||
max_model_len=2048)
|
||||
|
||||
# check support
|
||||
@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
cache_config=vllm_config.cache_config,
|
||||
prefix="model.layers.0.self_attn.attn",
|
||||
)
|
||||
self.attn._k_scale = self.attn._k_scale.to(device)
|
||||
self.attn._v_scale = self.attn._v_scale.to(device)
|
||||
|
||||
self.block_size = 16
|
||||
|
||||
@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def build_attn_metadata(self, batch_size: int):
|
||||
def build_attn_metadata(self, batch_size: int, use_hnd: bool):
|
||||
"""Initialize attention metadata."""
|
||||
|
||||
# Create common attn metadata
|
||||
@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
num_blocks = batch_size * max_blocks
|
||||
|
||||
# Create dummy KV cache for FlashInfer TRTLLM
|
||||
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
|
||||
# Create kv_cache in HND layout and permute to NHD layout
|
||||
# (later will be permuted back to HND layout in forward pass)
|
||||
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
|
||||
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
2,
|
||||
self.num_kv_heads,
|
||||
@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device)
|
||||
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
|
||||
if current_platform.is_rocm():
|
||||
# k/v as 1st dimention
|
||||
if use_hnd:
|
||||
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
|
||||
else:
|
||||
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
|
||||
else:
|
||||
# k/v as 2nd dimention
|
||||
# Create kv_cache in HND layout and permute to NHD layout
|
||||
# (later will be permuted back to HND layout in forward pass)
|
||||
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
|
||||
self.attn.kv_cache = [kv_cache]
|
||||
|
||||
# Build attn metadata
|
||||
@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
out_dtype=attn_output.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
|
||||
if current_platform.is_cuda():
|
||||
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel),
|
||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel)]
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
elif current_platform.is_rocm():
|
||||
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
TestAttentionFp8StaticQuantPatternModel)]
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
else:
|
||||
MODELS = []
|
||||
HEADS = []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("batch_size", [7, 256, 533])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("model_name, model_class",
|
||||
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel),
|
||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel)])
|
||||
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||
@pytest.mark.parametrize("batch_size",
|
||||
[7, 256, 533] if current_platform.is_cuda() else [8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
|
||||
current_platform.is_cuda() else [_Backend.ROCM_FLASH])
|
||||
@pytest.mark.parametrize(
|
||||
"split_attention",
|
||||
[False, True] if current_platform.is_rocm() else [False])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test ROCm or CUDA")
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)),
|
||||
reason="Only test on SM100(Blackwell)")
|
||||
@pytest.mark.skipif(current_platform.is_cuda()
|
||||
and not current_platform.is_device_capability((10, 0)),
|
||||
reason="On CUDA only test on SM100(Blackwell)")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
|
||||
reason="Only test ROCm or CUDA")
|
||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
head_size: int, batch_size: int,
|
||||
dtype: torch.dtype, model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend, monkeypatch, dist_init):
|
||||
backend: _Backend, split_attention: bool,
|
||||
monkeypatch, dist_init):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
if split_attention:
|
||||
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
torch.manual_seed(42)
|
||||
@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
model_config=ModelConfig(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
dtype=dtype,
|
||||
),
|
||||
scheduler_config=SchedulerConfig(max_num_seqs=1024),
|
||||
compilation_config=CompilationConfig(
|
||||
@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
|
||||
batch_size)
|
||||
batch_size, use_hnd=split_attention)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
result_unfused = model_unfused(q, k, v)
|
||||
@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
model_fused = model_fused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
|
||||
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
|
||||
batch_size, use_hnd=split_attention)
|
||||
|
||||
# Create test backend with fusion passes enabled
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
assert model_compiled.attn._o_scale_float is None
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
# After the 1st round of the forward pass, output quant scale should be
|
||||
# loaded into the attn layer's _o_scale_float, the 2nd round should
|
||||
# reuse the loaded _o_scale_float
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
result_fused_2 = model_compiled(q, k, v)
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
if backend == _Backend.FLASHINFER:
|
||||
# With the Flashinfer backend after the 1st round of the forward
|
||||
# pass, output quant scale should be loaded into the attn layer's
|
||||
# _o_scale_float, the 2nd round should reuse the loaded
|
||||
# _o_scale_float
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
result_fused_2 = model_compiled(q, k, v)
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_2,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
# Check attn fusion support
|
||||
quant_key = model_class.quant_key
|
||||
@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
||||
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
||||
|
||||
# Check that results are closed
|
||||
# Check that results are close
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_1,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_2,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
@ -34,6 +36,7 @@ def kernel_paged_attention_2d(
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
out_scale_inv,
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
num_queries_per_kv_padded: tl.constexpr, # int
|
||||
@ -60,7 +63,9 @@ def kernel_paged_attention_2d(
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
):
|
||||
USE_FP8: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
|
||||
@ -204,6 +209,9 @@ def kernel_paged_attention_2d(
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
|
||||
query_head_idx * output_stride_1)
|
||||
@ -234,6 +242,7 @@ def chunked_prefill_paged_decode(
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
output_scale=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
):
|
||||
@ -266,6 +275,7 @@ def chunked_prefill_paged_decode(
|
||||
sliding_window=sliding_window,
|
||||
sm_scale=sm_scale,
|
||||
skip_decode=True,
|
||||
fp8_out_scale=output_scale,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
@ -316,7 +326,7 @@ def chunked_prefill_paged_decode(
|
||||
tmp_output = torch.empty(
|
||||
size=(total_num_seq, num_query_heads, max_num_partitions,
|
||||
head_size),
|
||||
dtype=output.dtype,
|
||||
dtype=query.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
@ -345,6 +355,7 @@ def chunked_prefill_paged_decode(
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
fp8_out_scale=output_scale,
|
||||
)
|
||||
else:
|
||||
kernel_paged_attention_2d[(
|
||||
@ -362,6 +373,8 @@ def chunked_prefill_paged_decode(
|
||||
scale=sm_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out_scale_inv=1.0 /
|
||||
output_scale if output_scale is not None else 1.0,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
num_queries_per_kv_padded=num_queries_per_kv_padded,
|
||||
@ -388,4 +401,5 @@ def chunked_prefill_paged_decode(
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
|
||||
@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
|
||||
|
||||
# To check compatibility
|
||||
IS_TURING = current_platform.get_device_capability() == (7, 5)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
# Here's an example autotuner config for this kernel. This config does provide
|
||||
@ -43,6 +44,7 @@ def _fwd_kernel(Q,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
out_scale_inv,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
x: tl.constexpr,
|
||||
@ -82,8 +84,11 @@ def _fwd_kernel(Q,
|
||||
num_unroll_request: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
USE_SINKS: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
MAX_Q_LEN: tl.constexpr = 0,
|
||||
MAX_CTX_LEN: tl.constexpr = 0):
|
||||
MAX_CTX_LEN: tl.constexpr = 0,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max):
|
||||
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
@ -284,6 +289,9 @@ def _fwd_kernel(Q,
|
||||
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
|
||||
@ -743,6 +751,7 @@ def context_attention_fwd(q,
|
||||
sliding_window=None,
|
||||
sm_scale=None,
|
||||
skip_decode=False,
|
||||
fp8_out_scale=None,
|
||||
sinks=None):
|
||||
|
||||
q_dtype_is_f32 = q.dtype is torch.float32
|
||||
@ -793,6 +802,7 @@ def context_attention_fwd(q,
|
||||
|
||||
if alibi_slopes is not None:
|
||||
assert sinks is None, "Sinks arg is not supported with alibi"
|
||||
assert fp8_out_scale is None, "FP8 output not supported with alibi"
|
||||
# need to reduce num. blocks when using fp32
|
||||
# due to increased use of GPU shared memory
|
||||
# if q.dtype is torch.float32:
|
||||
@ -870,6 +880,7 @@ def context_attention_fwd(q,
|
||||
sm_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
k_cache.shape[4],
|
||||
@ -905,6 +916,7 @@ def context_attention_fwd(q,
|
||||
BLOCK_DMODEL_PADDED=Lk_padded,
|
||||
SLIDING_WINDOW=sliding_window,
|
||||
SKIP_DECODE=skip_decode,
|
||||
USE_FP8=fp8_out_scale is not None,
|
||||
BLOCK_M=128,
|
||||
BLOCK_N=64,
|
||||
num_unroll_cache=4,
|
||||
|
||||
@ -10,9 +10,11 @@
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
|
||||
|
||||
@triton.jit
|
||||
def kernel_unified_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_QQ_BIAS: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
sink_ptr, # [num_query_heads]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
out_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_QQ_BIAS: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
USE_FP8: tl.constexpr, # bool
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@ -281,6 +287,9 @@ def kernel_unified_attention_2d(
|
||||
|
||||
# epilogue
|
||||
acc = acc / L[:, None]
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
output_offset = (query_offset_0[:, None] * output_stride_0 +
|
||||
query_offset_1[:, None] * output_stride_1 +
|
||||
@ -552,23 +561,27 @@ def kernel_unified_attention_3d(
|
||||
|
||||
@triton.jit
|
||||
def reduce_segments(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
segm_output_ptr,
|
||||
#[num_tokens, num_query_heads, max_num_segments, head_size]
|
||||
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
|
||||
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
num_seqs, # int
|
||||
num_query_heads: tl.constexpr, # int
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
block_table_stride: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int, must be power of 2
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
segm_output_ptr,
|
||||
#[num_tokens, num_query_heads, max_num_segments, head_size]
|
||||
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
|
||||
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
num_seqs, # int
|
||||
num_query_heads: tl.constexpr, # int
|
||||
out_scale_inv, # float32
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
block_table_stride: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int, must be power of 2
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||
USE_FP8: tl.constexpr, # bool
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
query_token_idx = tl.program_id(0)
|
||||
query_head_idx = tl.program_id(1)
|
||||
@ -624,6 +637,10 @@ def reduce_segments(
|
||||
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
|
||||
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
|
||||
|
||||
if USE_FP8:
|
||||
acc = acc * tl.load(out_scale_inv)
|
||||
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
|
||||
|
||||
# write result
|
||||
output_offset = (query_token_idx * output_stride_0 +
|
||||
query_head_idx * output_stride_1 +
|
||||
@ -649,6 +666,7 @@ def unified_attention(
|
||||
k_descale,
|
||||
v_descale,
|
||||
alibi_slopes=None,
|
||||
output_scale=None,
|
||||
qq_bias=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
@ -707,6 +725,7 @@ def unified_attention(
|
||||
scale=softmax_scale,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
out_scale=1 / output_scale if output_scale is not None else 1.0,
|
||||
softcap=softcap,
|
||||
num_query_heads=num_query_heads,
|
||||
num_queries_per_kv=num_queries_per_kv,
|
||||
@ -736,6 +755,7 @@ def unified_attention(
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
else:
|
||||
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
|
||||
@ -819,6 +839,8 @@ def unified_attention(
|
||||
seq_lens_ptr=seqused_k,
|
||||
num_seqs=num_seqs,
|
||||
num_query_heads=num_query_heads,
|
||||
out_scale_inv=1 /
|
||||
output_scale if output_scale is not None else 1.0,
|
||||
output_stride_0=out.stride(0),
|
||||
output_stride_1=out.stride(1),
|
||||
block_table_stride=block_table.stride(0),
|
||||
@ -828,4 +850,5 @@ def unified_attention(
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
|
||||
USE_FP8=output_scale is not None,
|
||||
)
|
||||
|
||||
@ -454,11 +454,12 @@ class VllmBackend:
|
||||
inductor_config = config.inductor_compile_config
|
||||
PASS_KEY = "post_grad_custom_post_pass"
|
||||
if PASS_KEY in inductor_config:
|
||||
# Config should automatically wrap all inductor passes
|
||||
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
|
||||
# PassManager already added to config, make sure it's correct
|
||||
assert (inductor_config[PASS_KEY].uuid() ==
|
||||
self.post_grad_pass_manager.uuid())
|
||||
else:
|
||||
# Config should automatically wrap all inductor passes
|
||||
assert isinstance(inductor_config[PASS_KEY], InductorPass)
|
||||
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
|
||||
inductor_config[PASS_KEY] = self.post_grad_pass_manager
|
||||
|
||||
@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC):
|
||||
self,
|
||||
layer: Attention,
|
||||
quant_key: QuantKey,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC):
|
||||
self.head_size = layer.head_size
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, \
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args, **kwargs):
|
||||
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
):
|
||||
quant_key = QuantKey(dtype=FP8_DTYPE,
|
||||
scale=kStaticTensorScale,
|
||||
symmetric=symmetric)
|
||||
super().__init__(layer, quant_key)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # attn_output
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # q
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # k
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # v
|
||||
self.empty(5, self.num_heads, self.head_size,
|
||||
dtype=self.dtype), # attn_output
|
||||
self.empty_quant(5,
|
||||
self.num_heads * self.head_size), # quant_output
|
||||
empty_fp32(1, 1) # scale
|
||||
@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention):
|
||||
super().__init__(layer, kNvfp4Quant)
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype):
|
||||
super().__init__(layer, kNvfp4Quant, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
|
||||
@ -255,12 +266,14 @@ class AttnFusionPass(VllmInductorPass):
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(
|
||||
layer, config.model_config.dtype)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C,
|
||||
"scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(
|
||||
layer, config.model_config.dtype)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
|
||||
@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
|
||||
bias=bias)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor) -> torch.Tensor:
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_fake(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor) -> torch.Tensor:
|
||||
def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor) -> torch.Tensor:
|
||||
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
|
||||
dtype=out_dtype)
|
||||
|
||||
@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
|
||||
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||
qinput, weight, out_dtype, scale_a, scale_b, bias)
|
||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list) -> torch.Tensor:
|
||||
output = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor, output_shape: list,
|
||||
output_shape: list,
|
||||
**kwargs) -> torch.Tensor:
|
||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||
# when using it.
|
||||
@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_b=scale_b.t(),
|
||||
bias=bias)
|
||||
|
||||
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
||||
output = torch.narrow(output, 0, 0, qinput.shape[0])
|
||||
output = output.view(*output_shape)
|
||||
return output
|
||||
|
||||
@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list,
|
||||
**kwargs) -> torch.Tensor:
|
||||
# Use unfused DQ due to limitations with scaled_mm
|
||||
@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
# Unpad (undo num_token_padding)
|
||||
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
||||
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
|
||||
output = torch.narrow(output, 0, 0, qinput.shape[0])
|
||||
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
|
||||
|
||||
# DQ
|
||||
# C = sw * sx * (X * W) + bias
|
||||
@ -430,7 +431,6 @@ class Fp8LinearOp:
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
input_2d=input_2d,
|
||||
output_shape=output_shape)
|
||||
|
||||
|
||||
|
||||
@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8StaticTensorSym)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool:
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
" for TritonAttentionImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale,
|
||||
output_scale=output_scale,
|
||||
sinks=self.sinks,
|
||||
)
|
||||
|
||||
@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
sinks=self.sinks,
|
||||
)
|
||||
output_scale=output_scale)
|
||||
|
||||
return output
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user