[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Gregory Shtrasberg 2025-09-10 16:59:55 -04:00 committed by GitHub
parent 37e8182bfe
commit 9a161307f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 249 additions and 135 deletions

View File

@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="V0 attn quant fusion only on ROCm")
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
gpu_memory_utilization=0.5,
max_model_len=2048)
sampling_params = SamplingParams(temperature=0.0,
@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
gpu_memory_utilization=0.5,
max_model_len=2048)
# check support
@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
cache_config=vllm_config.cache_config,
prefix="model.layers.0.self_attn.attn",
)
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
self.block_size = 16
@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int):
def build_attn_metadata(self, batch_size: int, use_hnd: bool):
"""Initialize attention metadata."""
# Create common attn metadata
@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
if current_platform.is_rocm():
# k/v as 1st dimention
if use_hnd:
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
else:
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
else:
# k/v as 2nd dimention
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
self.attn.kv_cache = [kv_cache]
# Build attn metadata
@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
out_dtype=attn_output.dtype)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
if current_platform.is_cuda():
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)]
HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
TestAttentionFp8StaticQuantPatternModel)]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
HEADS = []
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_name, model_class",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.parametrize("batch_size",
[7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
current_platform.is_cuda() else [_Backend.ROCM_FLASH])
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)),
reason="Only test on SM100(Blackwell)")
@pytest.mark.skipif(current_platform.is_cuda()
and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, monkeypatch, dist_init):
backend: _Backend, split_attention: bool,
monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
device = torch.device("cuda:0")
torch.manual_seed(42)
@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
model_config=ModelConfig(
model=model_name,
max_model_len=2048,
dtype=dtype,
),
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size)
batch_size, use_hnd=split_attention)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)
# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
if backend == _Backend.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded
# _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
# Check attn fusion support
quant_key = model_class.quant_key
@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
# Check that results are closed
# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)

View File

@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
def cdiv_fn(x, y):
@ -34,6 +36,7 @@ def kernel_paged_attention_2d(
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale_inv,
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
@ -60,7 +63,9 @@ def kernel_paged_attention_2d(
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
):
USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@ -204,6 +209,9 @@ def kernel_paged_attention_2d(
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
query_head_idx * output_stride_1)
@ -234,6 +242,7 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
output_scale=None,
# Optional tensor for sinks
sinks=None,
):
@ -266,6 +275,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks,
)
@ -316,7 +326,7 @@ def chunked_prefill_paged_decode(
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),
dtype=output.dtype,
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
@ -345,6 +355,7 @@ def chunked_prefill_paged_decode(
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
fp8_out_scale=output_scale,
)
else:
kernel_paged_attention_2d[(
@ -362,6 +373,8 @@ def chunked_prefill_paged_decode(
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
out_scale_inv=1.0 /
output_scale if output_scale is not None else 1.0,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
@ -388,4 +401,5 @@ def chunked_prefill_paged_decode(
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
USE_FP8=output_scale is not None,
)

View File

@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())
# Here's an example autotuner config for this kernel. This config does provide
@ -43,6 +44,7 @@ def _fwd_kernel(Q,
sm_scale,
k_scale,
v_scale,
out_scale_inv,
B_Start_Loc,
B_Seqlen,
x: tl.constexpr,
@ -82,8 +84,11 @@ def _fwd_kernel(Q,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0):
MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
@ -284,6 +289,9 @@ def _fwd_kernel(Q,
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
@ -743,6 +751,7 @@ def context_attention_fwd(q,
sliding_window=None,
sm_scale=None,
skip_decode=False,
fp8_out_scale=None,
sinks=None):
q_dtype_is_f32 = q.dtype is torch.float32
@ -793,6 +802,7 @@ def context_attention_fwd(q,
if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
@ -870,6 +880,7 @@ def context_attention_fwd(q,
sm_scale,
k_scale,
v_scale,
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
b_start_loc,
b_seq_len,
k_cache.shape[4],
@ -905,6 +916,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,

View File

@ -10,9 +10,11 @@
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@ -281,6 +287,9 @@ def kernel_unified_attention_2d(
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (query_offset_0[:, None] * output_stride_0 +
query_offset_1[:, None] * output_stride_1 +
@ -552,23 +561,27 @@ def kernel_unified_attention_3d(
@triton.jit
def reduce_segments(
output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr,
#[num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr,
#[num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
out_scale_inv, # float32
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
query_token_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)
@ -624,6 +637,10 @@ def reduce_segments(
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
# write result
output_offset = (query_token_idx * output_stride_0 +
query_head_idx * output_stride_1 +
@ -649,6 +666,7 @@ def unified_attention(
k_descale,
v_descale,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
# Optional tensor for sinks
sinks=None,
@ -707,6 +725,7 @@ def unified_attention(
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
@ -736,6 +755,7 @@ def unified_attention(
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
@ -819,6 +839,8 @@ def unified_attention(
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
out_scale_inv=1 /
output_scale if output_scale is not None else 1.0,
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0),
@ -828,4 +850,5 @@ def unified_attention(
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
USE_FP8=output_scale is not None,
)

View File

@ -454,11 +454,12 @@ class VllmBackend:
inductor_config = config.inductor_compile_config
PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config:
# Config should automatically wrap all inductor passes
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
# PassManager already added to config, make sure it's correct
assert (inductor_config[PASS_KEY].uuid() ==
self.post_grad_pass_manager.uuid())
else:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager

View File

@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC):
self,
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
):
self.layer = layer
self.layer_name = layer.layer_name
@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC):
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
self.dtype = dtype
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args, **kwargs):
kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args, **kwargs):
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
def __init__(
self,
layer: Attention,
dtype: torch.dtype,
symmetric: bool = True,
):
quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale,
symmetric=symmetric)
super().__init__(layer, quant_key)
super().__init__(layer, quant_key, dtype)
def _register(self, pm_pass: PatternMatcherPass):
@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # attn_output
self.empty(5, self.num_heads, self.head_size,
dtype=self.dtype), # q
self.empty(5, self.num_heads, self.head_size,
dtype=self.dtype), # k
self.empty(5, self.num_heads, self.head_size,
dtype=self.dtype), # v
self.empty(5, self.num_heads, self.head_size,
dtype=self.dtype), # attn_output
self.empty_quant(5,
self.num_heads * self.head_size), # quant_output
empty_fp32(1, 1) # scale
@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention):
super().__init__(layer, kNvfp4Quant)
def __init__(self, layer: Attention, dtype: torch.dtype):
super().__init__(layer, kNvfp4Quant, dtype)
def _register(self, pm_pass: PatternMatcherPass):
@ -255,12 +266,14 @@ class AttnFusionPass(VllmInductorPass):
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8 = AttentionFp8StaticQuantPattern(
layer, config.model_config.dtype)
pattern_fp8.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C,
"scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4 = AttentionNvfp4QuantPattern(
layer, config.model_config.dtype)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:

View File

@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
bias=bias)
def rocm_per_tensor_w8a8_scaled_mm_impl(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor) -> torch.Tensor:
def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
return output
def rocm_per_tensor_w8a8_scaled_mm_fake(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor) -> torch.Tensor:
def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
dtype=out_dtype)
@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
qinput, weight, out_dtype, scale_a, scale_b, bias)
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
direct_register_custom_op(
@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor, output_shape: list,
output_shape: list,
**kwargs) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b=scale_b.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = torch.narrow(output, 0, 0, qinput.shape[0])
output = output.view(*output_shape)
return output
@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
output = torch.narrow(output, 0, 0, qinput.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
@ -430,7 +431,6 @@ class Fp8LinearOp:
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
input_2d=input_2d,
output_shape=output_shape)

View File

@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool:
class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
if output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
"fused block_scale output quantization is not yet supported"
" for TritonAttentionImpl")
if attn_metadata is None:
@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
)
output_scale=output_scale)
return output