[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
elvischenv 2025-08-23 06:09:05 +08:00 committed by GitHub
parent cc7ae5e7ca
commit 24d0c9e6ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 596 additions and 200 deletions

View File

@ -9,8 +9,11 @@ from typing import Optional
import flashinfer import flashinfer
import torch import torch
from vllm.utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn): def to_float8(x, dtype=torch.float8_e4m3fn):
@ -61,13 +64,13 @@ def benchmark_decode(
else: else:
raise ValueError(f"Invalid kv_layout: {kv_layout}") raise ValueError(f"Invalid kv_layout: {kv_layout}")
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) # Always using 1.0 scale to reflect the real perf in benchmarking
q_scale = 1.0
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE: if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query) query, _ = to_float8(ref_query)
ref_query = query.to(dtype) * q_scale
else: else:
q_scale = 1.0 query = ref_query
ref_query = query
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_seq_len kv_lens[-1] = max_seq_len
@ -75,14 +78,13 @@ def benchmark_decode(
seq_lens = kv_lens seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item() max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype) # Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE: if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache) kv_cache, _ = to_float8(ref_kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else: else:
kv_scale = 1.0 kv_cache = ref_kv_cache
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint( block_tables = torch.randint(
@ -142,11 +144,31 @@ def benchmark_decode(
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
o_scale = 1.0 o_scale = 1.0
o_sf_scale = None
output_baseline = torch.empty(ref_query.shape, dtype=dtype) output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) if o_quant_dtype == FP4_DTYPE:
o_sf_scale = 500.0
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_decode(): def baseline_decode():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) return wrapper.run(
ref_query,
ref_kv_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output_baseline,
)
def trtllm_decode(): def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache( return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
@ -158,6 +180,7 @@ def benchmark_decode(
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale, bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm, out=output_trtllm,
) )
@ -237,6 +260,7 @@ if __name__ == "__main__":
(None, None, None), (None, None, None),
(None, FP8_DTYPE, None), (None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
] ]
for quant_dtype in quant_dtypes: for quant_dtype in quant_dtypes:

View File

@ -9,8 +9,11 @@ from typing import Optional
import flashinfer import flashinfer
import torch import torch
from vllm.utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn): def to_float8(x, dtype=torch.float8_e4m3fn):
@ -72,13 +75,15 @@ def benchmark_prefill(
] ]
) )
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) # Always using 1.0 scale to reflect the real perf in benchmarking
q_scale = 1.0
ref_query = torch.randn(
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
)
if q_quant_dtype == FP8_DTYPE: if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query) query, _ = to_float8(ref_query)
ref_query = query.to(dtype) * q_scale
else: else:
q_scale = 1.0 query = ref_query
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len kv_lens[-1] = max_kv_len
@ -86,14 +91,13 @@ def benchmark_prefill(
seq_lens = kv_lens + q_lens seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item() max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype) # Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE: if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache) kv_cache, _ = to_float8(ref_kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else: else:
kv_scale = 1.0 kv_cache = ref_kv_cache
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint( block_tables = torch.randint(
@ -152,11 +156,31 @@ def benchmark_prefill(
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
o_scale = 1.0 o_scale = 1.0
o_sf_scale = None
output_baseline = torch.empty(ref_query.shape, dtype=dtype) output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) if o_quant_dtype == FP4_DTYPE:
o_sf_scale = 500.0
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_prefill(): def baseline_prefill():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) return wrapper.run(
ref_query,
ref_kv_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output_baseline,
)
def trtllm_prefill(): def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache( return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
@ -172,6 +196,7 @@ def benchmark_prefill(
batch_size=batch_size, batch_size=batch_size,
cum_seq_lens_q=q_indptr, cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr, cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm, out=output_trtllm,
) )
@ -250,6 +275,7 @@ if __name__ == "__main__":
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype) # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None), (None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
] ]
for quant_dtype in quant_dtypes: for quant_dtype in quant_dtypes:

View File

@ -8,11 +8,12 @@ import vllm.envs as envs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, from vllm.compilation.fusion import FUSED_OPS, FusionPass
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
from .backend import TestBackend from .backend import TestBackend

View File

@ -7,11 +7,13 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.plugins import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, GroupShape, QuantKey) FusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig) VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -30,10 +32,8 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE, quant_scale = ScaleDesc(torch.float32, static, group_shape)
static=static, self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
group_shape=group_shape,
symmetric=True)
if static: if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else: else:

View File

@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata) create_common_attn_metadata)
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention from vllm.attention import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
set_current_vllm_config) set_current_vllm_config)
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp) Fp8LinearOp)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# globals needed for string-import custom Dynamo backend field # globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None backend: Optional[TestBackend] = None
@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# check support # check support
attn_fusion_supported = [ attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype, layer.impl.fused_output_quant_supported(quant_key)
quant_key.static,
quant_key.group_shape)
for key, layer in compile_config.static_forward_context.items() for key, layer in compile_config.static_forward_context.items()
] ]
@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend = None backend = None
class TestAttentionStaticQuantPatternModel(torch.nn.Module): class AttentionQuantPatternModel(torch.nn.Module):
"""Test model for AttentionStaticQuantPattern fusion.""" """Base model for AttentionQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device, kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig): vllm_config: VllmConfig, **kwargs):
super().__init__() super().__init__()
self.num_qo_heads = num_qo_heads self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
prefix="model.layers.0.self_attn.attn", prefix="model.layers.0.self_attn.attn",
) )
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
self.wscale = torch.tensor([1.0], dtype=torch.float32)
self.scale = torch.tensor([1.0], dtype=torch.float32)
self.block_size = 16 self.block_size = 16
# Initialize attn MetadataBuilder # Initialize attn MetadataBuilder
@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
return self.attn_metadata return self.attn_metadata
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor): class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionFp8StaticQuantPattern fusion."""
quant_key = kFp8StaticTensorSym
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randn(hidden_size, hidden_size).to(
dtype=FP8_DTYPE, device=self.device).t(),
"wscale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
})
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused.""" """Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output, return self.fp8_linear.apply(input=attn_output,
weight=w, weight=self.w["weight"],
weight_scale=self.wscale, weight_scale=self.w["wscale"],
input_scale=self.scale) input_scale=self.w["scale"])
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionNvfp4QuantPattern fusion."""
quant_key = kNvfp4Quant
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randint(256, (hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device),
"wscale_swizzled":
torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device),
"wscale":
torch.tensor([500], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([0.002], dtype=torch.float32, device=self.device),
})
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"])
return cutlass_scaled_fp4_mm(a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) @pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
@pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533]) @pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize( @pytest.mark.parametrize("model_name, model_class",
"model_name, quant_key", [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)]) TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) @pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int, head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str, dtype: torch.dtype, model_name: str,
quant_key: QuantKey, backend: _Backend, model_class: type[AttentionQuantPatternModel],
monkeypatch, dist_init): backend: _Backend, monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
cache_config=CacheConfig(cache_dtype="fp8")) cache_config=CacheConfig(cache_dtype="fp8"))
# Create test inputs # Create test inputs
hidden_size = num_qo_heads * head_size q = torch.randn(batch_size,
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) num_qo_heads * head_size,
dtype=dtype,
device=device)
k = torch.randn(batch_size, k = torch.randn(batch_size,
num_kv_heads * head_size, num_kv_heads * head_size,
dtype=dtype, dtype=dtype,
@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
num_kv_heads * head_size, num_kv_heads * head_size,
dtype=dtype, dtype=dtype,
device=device) device=device)
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
# Mark first dimension as dynamic for realistic testing # Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0) torch._dynamo.mark_dynamic(q, 0)
@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
with set_current_vllm_config(vllm_config_unfused), set_forward_context( with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend): ), global_force_attn_backend_context_manager(backend):
model_unfused = TestAttentionStaticQuantPatternModel( model_unfused = model_class(num_qo_heads=num_qo_heads,
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, num_kv_heads=num_kv_heads,
vllm_config_unfused) head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused)
model_unfused = model_unfused.to(device) model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
batch_size) batch_size)
# Run model directly without compilation and fusion # Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v, linear_w) result_unfused = model_unfused(q, k, v)
# Run model with attn fusion enabled # Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
with set_current_vllm_config(vllm_config), set_forward_context( with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend): ), global_force_attn_backend_context_manager(backend):
model_fused = TestAttentionStaticQuantPatternModel( model_fused = model_class(num_qo_heads=num_qo_heads,
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, num_kv_heads=num_kv_heads,
vllm_config) head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w)
model_fused = model_fused.to(device) model_fused = model_fused.to(device)
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
backend=test_backend, backend=test_backend,
fullgraph=True) fullgraph=True)
assert model_compiled.attn._o_scale_float is None assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v, linear_w) result_fused_1 = model_compiled(q, k, v)
# After the 1st round of the forward pass, output quant scale should be # After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should # loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float # reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v, linear_w) result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None assert model_compiled.attn._o_scale_float is not None
# Check attn fusion support # Check attn fusion support
quant_key = model_class.quant_key
attn_fusion_supported = [ attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype, layer.impl.fused_output_quant_supported(quant_key) for key, layer in
quant_key.static, vllm_config.compilation_config.static_forward_context.items()
quant_key.group_shape) for key,
layer in vllm_config.compilation_config.static_forward_context.items()
] ]
if any(attn_fusion_supported): if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion # Check quantization ops in the graph before and after fusion
@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
"Attention should have output_scale after fusion" "Attention should have output_scale after fusion"
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
"Attention should not have output_block_scale before fusion"
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
"Attention should not have output_block_scale after FP8 fusion"
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
# Check that results are close # Check that results are close
torch.testing.assert_close(result_unfused, torch.testing.assert_close(result_unfused,
result_fused_1, result_fused_1,

View File

@ -6,7 +6,11 @@ import flashinfer
import pytest import pytest
import torch import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
if not current_platform.is_device_capability(100): if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn): def to_float8(x, dtype=torch.float8_e4m3fn):
@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
QUANT_DTYPES = [ QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype) # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None), (None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
] ]
BATCH_SIZE = [4, 12] BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)] MAX_SEQ_LENS = [(1024, 4096)]
@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output) wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0 o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output) _, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Decode # TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache( flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query, query=query,
kv_cache=kv_cache, kv_cache=kv_cache,
@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale, bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2 rtol, atol = 5e-2, 7e-2
else: else:
rtol, atol = 1e-2, 1e-2 rtol, atol = 1e-2, 2e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}" f"{torch.max(torch.abs(output - output_trtllm))}"
@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_quant_dtype = kv_quant_dtype or dtype kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype o_quant_dtype = o_quant_dtype or dtype
if q_quant_dtype != kv_quant_dtype:
pytest.skip("Skipped mixed QKV dtypes for prefill")
max_q_len, max_kv_len = max_seq_lens max_q_len, max_kv_len = max_seq_lens
num_qo_heads, num_kv_heads = num_heads num_qo_heads, num_kv_heads = num_heads
@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output) wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0 o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output) _, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Prefill # TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache( flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query, query=query,
kv_cache=kv_cache, kv_cache=kv_cache,
@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size, batch_size=batch_size,
cum_seq_lens_q=q_indptr, cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr, cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2 rtol, atol = 5e-2, 7e-2
else: else:
rtol, atol = 1e-2, 1e-2 rtol, atol = 1e-2, 1e-2

View File

@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
GroupShape)
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING: if TYPE_CHECKING:
@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, def fused_output_quant_supported(self, quant_key: QuantKey):
group_shape: GroupShape):
""" """
Does this attention implementation support fused output quantization. Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it. onto implementations that support it.
TODO(luka) merge parameters into QuantDescriptor :param quant_key: QuantKey object that describes the quantization op
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape.
:return: is fusion supported for this type of quantization :return: is fusion supported for this type of quantization
""" """
return False return False
@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError

View File

@ -800,6 +800,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
attn_metadata: DifferentialFlashAttentionMetadata, attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
@ -817,6 +818,11 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads). {q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values We use torch's .expand() to avoid duplicating values
""" """
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl")
if self.lambda_full is None: if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[ self.lambda_init = self.differential_flash_attention_config[
"lambda_init"] "lambda_init"]

View File

@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
attn_metadata: DualChunkFlashAttentionMetadata, attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention. """Forward pass with DualChunkFlashAttention.
Args: Args:
@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
""" """
assert output is None, "Output tensor not supported for DualChunk" assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlashAttentionImpl") " for FlashAttentionImpl")

View File

@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
@ -615,7 +616,7 @@ class FlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlashAttentionImpl") " for FlashAttentionImpl")

View File

@ -1238,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata: T, attn_metadata: T,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if output is not None: if output is not None:
raise NotImplementedError( raise NotImplementedError(
"output is not yet supported for MLAImplBase") "output is not yet supported for MLAImplBase")
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for MLAImplBase") " for MLAImplBase")

View File

@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim)) head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, def fused_output_quant_supported(self, quant_key: QuantKey):
group_shape: GroupShape):
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype( return quant_key == kFp8StaticTensorSym
) and static and group_shape == GroupShape.PER_TENSOR
# Only supported in the Triton backend # Only supported in the Triton backend
return False return False
@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -606,6 +605,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"fused output quantization only supported for Triton" "fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now") " implementation in ROCMFlashAttentionImpl for now")
if output_block_scale is not None:
raise NotImplementedError(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl")
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
if key is not None: if key is not None:
assert value is not None assert value is not None

View File

@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
@ -484,7 +485,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for XFormersImpl") " for XFormersImpl")

View File

@ -495,6 +495,7 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None: ) -> None:
wait_for_kv_layer_from_connector(layer_name) wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
@ -510,7 +511,8 @@ def unified_attention_with_output(
kv_cache, kv_cache,
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale) output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
@ -522,6 +524,7 @@ def unified_attention_with_output_fake(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None: ) -> None:
return return
@ -529,7 +532,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op( direct_register_custom_op(
op_name="unified_attention_with_output", op_name="unified_attention_with_output",
op_func=unified_attention_with_output, op_func=unified_attention_with_output,
mutates_args=["output"], mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake, fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )

View File

@ -12,7 +12,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe from .fx_utils import find_getitem_maybe
@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs): def empty_bf16(*args, **kwargs):
@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
symmetric: bool = True
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = { QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501
} }
@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
symmetric=True): symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False, fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey( quant=QuantKey(dtype=quant_dtype,
dtype=quant_dtype, scale=kStaticTensorScale,
static=True, symmetric=symmetric))
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
super().__init__(epsilon, fused_key) super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
symmetric=True): symmetric=True):
key = FusedRMSQuantKey(fused_add=True, key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey( quant=QuantKey(dtype=quant_dtype,
dtype=quant_dtype, scale=kStaticTensorScale,
static=True, symmetric=symmetric))
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass, def register(self, pm_pass: PatternMatcherPass,
@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True): symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=False, key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype, quant=QuantKey(dtype=quant_dtype,
static=False, scale=scale,
group_shape=group_shape,
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)
@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN, group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True): symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=True, key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype, quant=QuantKey(dtype=quant_dtype,
static=False, scale=scale,
group_shape=group_shape,
symmetric=symmetric)) symmetric=symmetric))
super().__init__(epsilon, key) super().__init__(epsilon, key)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
@ -11,44 +13,41 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern: class AttentionQuantPattern(ABC):
""" """
Fusion for Attention+StaticQuant. The base class for Attn+Quant fusions.
Should not be used directly.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
op will be removed from the graph, and its scale will be passed into
Attention op as the `output_scale` argument.
""" """
def __init__( def __init__(
self, self,
layer: Attention, layer: Attention,
quant_dtype: torch.dtype, quant_key: QuantKey,
symmetric=True,
): ):
self.layer = layer self.layer = layer
self.layer_name = layer.layer_name self.layer_name = layer.layer_name
self.num_heads = layer.num_heads self.num_heads = layer.num_heads
self.head_size = layer.head_size self.head_size = layer.head_size
self.quant_dtype = quant_dtype self.quant_key = quant_key
self.quant_key = QuantKey(dtype=quant_dtype, self.quant_dtype = quant_key.dtype
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)
assert self.quant_key in QUANT_OPS, \ assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}" f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key] self.QUANT_OP = QUANT_OPS[self.quant_key]
@ -57,12 +56,49 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
    """Run Inductor's post-grad ``view_to_reshape`` pass on ``gm``.

    The helper's return value is ignored and the same ``gm`` object is
    returned, so the pass is expected to rewrite the graph in place.

    :param gm: traced graph module to post-process
    :return: the ``gm`` that was passed in
    """
    # Imported lazily, at call time, rather than at module import.
    from torch._inductor.fx_passes.post_grad import view_to_reshape
    view_to_reshape(gm)
    return gm
def register_if_supported(self, pm_pass: PatternMatcherPass): def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported( if self.layer.impl.fused_output_quant_supported(self.quant_key):
self.quant_dtype, self.quant_key.static,
self.quant_key.group_shape):
self._register(pm_pass) self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
    """Register this fusion's pattern and replacement with ``pm_pass``.

    Subclasses must override this. ``register_if_supported`` calls it only
    after the layer's attention implementation has reported support for
    ``self.quant_key``.
    """
    raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
    self,
    layer: Attention,
    symmetric: bool = True,
):
    """Build the pattern for a static, per-tensor-scaled FP8 quant key.

    :param layer: attention layer whose output quantization may be fused
    :param symmetric: forwarded to the ``QuantKey`` as its symmetry flag
    """
    super().__init__(
        layer,
        QuantKey(dtype=FP8_DTYPE,
                 scale=kStaticTensorScale,
                 symmetric=symmetric),
    )
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@ -74,9 +110,10 @@ class AttentionStaticQuantPattern:
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self.layer_name, layer_name=self.layer_name,
output_scale=None) output_scale=None,
attn_out_view = RESHAPE_OP(at1[1], output_block_scale=None)
[-1, self.num_heads * self.head_size]) attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP, at2 = auto_functionalized(self.QUANT_OP,
result=output_quant, result=output_quant,
input=attn_out_view, input=attn_out_view,
@ -98,7 +135,8 @@ class AttentionStaticQuantPattern:
value=v, value=v,
output=output_attn, output=output_attn,
layer_name=self.layer_name, layer_name=self.layer_name,
output_scale=scale) output_scale=scale,
output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors. # Need custom fake mode, otherwise tracing happens with real tensors.
@ -114,21 +152,94 @@ class AttentionStaticQuantPattern:
empty_fp32(1, 1) # scale empty_fp32(1, 1) # scale
] ]
def wrap_trace_fn(process_fx, trace_fn): pm.register_replacement(
pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
def fx_view_to_reshape(gm: torch.fx.GraphModule): Only triggers when the attention implementation returns True in
from torch._inductor.fx_passes.post_grad import view_to_reshape `fused_output_quant_supported()`. If the pattern is found, the
view_to_reshape(gm) Nvfp4Quant op will be removed from the graph, and its scale
return gm will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention):
    """Build the pattern for ``layer`` using the ``kNvfp4Quant`` key."""
    # The quant key is fixed for this pattern; the base class asserts it is
    # present in QUANT_OPS and stores the matching quant op.
    super().__init__(layer, kNvfp4Quant)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(
output_scale, FP8_DTYPE)
at2 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view)
output = RESHAPE_OP(at2[1],
[-1, self.num_heads * self.head_size // 2])
return output, at2[2]
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with unset_fake_temporarily(), FakeTensorMode():
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size //
2), # output_quant
empty_i32(128,
round_up(self.num_heads * self.head_size // 16,
4)), # output_scale
empty_fp32(1, 1), # input_scale
]
pm.register_replacement( pm.register_replacement(
pattern, replacement, inputs, pattern, replacement, inputs,
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
class AttnFusionPass(VllmInductorPass): class AttnFusionPass(VllmInductorPass):
@ -151,8 +262,12 @@ class AttnFusionPass(VllmInductorPass):
attn_layers = get_layers_from_vllm_config(config, Attention) attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items(): for layer_name, layer in attn_layers.items():
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE) pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern.register_if_supported(self.patterns) pattern_fp8.register_if_supported(self.patterns)
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0: if len(attn_layers) == 0:
logger.warning( logger.warning(
"Attention + quant fusion is enabled, but no attention layers " "Attention + quant fusion is enabled, but no attention layers "
@ -175,4 +290,6 @@ class AttnFusionPass(VllmInductorPass):
self.end_and_log() self.end_and_log()
def uuid(self): def uuid(self):
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern) return VllmInductorPass.hash_source(self, AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern)

View File

@ -2,16 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass
from types import MappingProxyType from types import MappingProxyType
from typing import ClassVar, NamedTuple, Optional from typing import ClassVar, NamedTuple, Optional
import numpy import numpy
import torch import torch
from torch import fx
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# Use proxy as NamedTuple direct subclasses cannot have static members # Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple): class _GroupShape(NamedTuple):
@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1)
@dataclass(frozen=True)
class ScaleDesc:
    """
    Class for describing a single quantization scaling factor.

    dtype: data type of the scale
    static: static scale if True, dynamic if False
    group_shape: group shape of the scale
    """
    dtype: torch.dtype
    static: bool
    group_shape: GroupShape

    def __str__(self):
        """Return a compact ``<dtype>,<static|dynamic>,<group shape>`` tag."""
        # dtype_abbrs only lists the dtypes torch.fx abbreviates. Fall back
        # to the full dtype name instead of raising KeyError, since this
        # string is typically built while formatting an error message.
        dtype_str = fx.graph.dtype_abbrs.get(self.dtype, str(self.dtype))

        if self.group_shape == GroupShape.PER_TENSOR:
            group_shape = 'per_tensor'
        elif self.group_shape == GroupShape.PER_TOKEN:
            group_shape = 'per_token'
        else:
            group_shape = str(self.group_shape)

        kind = 'static' if self.static else 'dynamic'
        return f"{dtype_str},{kind},{group_shape}"
@dataclass(frozen=True)
class QuantKey:
    """
    Class for identifying the type of quantization.

    dtype: quantized data type
    scale: scale descriptor
    scale2: second-level scale descriptor
    symmetric: symmetric if True, asymmetric if False
    """
    dtype: torch.dtype
    scale: ScaleDesc
    scale2: Optional[ScaleDesc] = None
    symmetric: bool = True

    def __str__(self):
        """Return ``QuantKey(<dtype>,scale(..),[scale2(..),]<symmetry>)``."""
        # dtype_abbrs only lists the dtypes torch.fx abbreviates. Fall back
        # to the full dtype name instead of raising KeyError: this string is
        # used in "unsupported quantization scheme" assertion messages, where
        # a formatting failure would hide the real error.
        dtype_str = fx.graph.dtype_abbrs.get(self.dtype, str(self.dtype))
        # The second-level scale is optional; omit its section when unset.
        scale2_str = (f"scale2({self.scale2}),"
                      if self.scale2 is not None else "")
        symmetry = "symmetric" if self.symmetric else "asymmetric"
        return (f"QuantKey({dtype_str},"
                f"scale({self.scale}),{scale2_str}"
                f"{symmetry})")
# Scale descriptors are built as ScaleDesc(dtype, static, group_shape).

# float32 scale, static, one scale for the whole tensor.
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
# FP8 quantization with a static per-tensor scale, symmetric.
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
# float32 scale, dynamic, one scale for the whole tensor.
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
# FP8 quantization with a dynamic per-tensor scale, symmetric.
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
# float32 scale, dynamic, per-token group shape.
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
# FP8 quantization with a dynamic per-token scale, symmetric.
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
# FP8-typed dynamic scale with group shape (1, 16)
# (presumably one scale per 16 elements of a row -- TODO confirm).
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
# NVFP4 key: values stored as FP4_DTYPE, first-level group scale above plus
# a static per-tensor float32 second-level scale.
kNvfp4Quant = QuantKey(FP4_DTYPE,
                       scale=kNvfp4GroupScale,
                       scale2=kStaticTensorScale)
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent # -1 means full extent

View File

@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for TorchSDPABackendImpl") " for TorchSDPABackendImpl")

View File

@ -430,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
@ -447,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlashAttentionImpl") " for FlashAttentionImpl")

View File

@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper) MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, is_pin_memory_available from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention, from vllm.utils.flashinfer import (supports_trtllm_attention,
@ -40,6 +41,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
logger = init_logger(__name__) logger = init_logger(__name__)
@ -653,14 +655,12 @@ class FlashInferImpl(AttentionImpl):
and num_heads % num_kv_heads == 0) and num_heads % num_kv_heads == 0)
self.bmm1_scale: Optional[float] = None self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None self.bmm2_scale: Optional[float] = None
self.o_sf_scale: Optional[float] = None
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, def fused_output_quant_supported(self, quant_key: QuantKey):
group_shape: GroupShape):
supported_quant_type = (dtype == FP8_DTYPE and static
and group_shape == GroupShape.PER_TENSOR)
return (self.support_trtllm_attn return (self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8") and self.kv_cache_dtype.startswith("fp8")
and supported_quant_type) and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))
def forward( def forward(
self, self,
@ -672,6 +672,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashInfer. """Forward pass with FlashInfer.
@ -705,19 +706,32 @@ class FlashInferImpl(AttentionImpl):
if output_scale is None: if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \ assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened." "Query can only be FP8 if output fusion happened."
assert output_block_scale is None, "output_block_scale "\
"is not supported when fusion has not happened"
else: else:
assert attn_metadata.q_data_type == FP8_DTYPE, \ assert attn_metadata.q_data_type == FP8_DTYPE, \
"Query must be FP8 when attn+quant fusion happened." "Query must be FP8 when attn+quant fusion happened."
assert (attn_metadata.prefill_use_trtllm and assert (attn_metadata.prefill_use_trtllm and
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
assert output.dtype == FP8_DTYPE, \
"Output must be FP8 when attn+quant fusion happened."
# TRTLLM attn kernel requires o scale as a host scalar, store the if output.dtype == FP8_DTYPE:
# o scale to host scalar in warmup run with cuda graph not enabled assert output_block_scale is None, \
"output_block_scale should not be provided for fp8 output"
elif output.dtype == FP4_DTYPE:
assert output_block_scale is not None, \
"output_block_scale is required for nvfp4 output"
else:
raise ValueError(f"Unsupported output dtype: {output.dtype}")
# The TRTLLM attn kernel requires the o scale to be passed as a host
# scalar, so store the o scale as a host scalar during the warmup run,
# when CUDA graphs are not enabled.
if layer._o_scale_float is None: if layer._o_scale_float is None:
layer._o_scale_float = output_scale.cpu().item() layer._o_scale_float = output_scale.cpu().item()
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float if output.dtype == FP8_DTYPE:
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float
# Insert FP8 quant for query # Insert FP8 quant for query
num_tokens, num_heads, head_size = query.shape num_tokens, num_heads, head_size = query.shape
@ -818,6 +832,16 @@ class FlashInferImpl(AttentionImpl):
assert block_tables_prefill.is_contiguous() assert block_tables_prefill.is_contiguous()
assert seq_lens_prefill.is_contiguous() assert seq_lens_prefill.is_contiguous()
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
out = FP4Tensor(data=output[num_decode_tokens:],
scale=output_block_scale,
scale_start_index=num_decode_tokens,
original_shape=prefill_query.shape)
else:
assert self.o_sf_scale is None
out = output[num_decode_tokens:]
trtllm_batch_context_with_kv_cache( trtllm_batch_context_with_kv_cache(
query=prefill_query, query=prefill_query,
kv_cache=kv_cache_permute, kv_cache=kv_cache_permute,
@ -833,7 +857,8 @@ class FlashInferImpl(AttentionImpl):
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=self.window_left, window_left=self.window_left,
sinks=self.sinks, sinks=self.sinks,
out=output[num_decode_tokens:], o_sf_scale=self.o_sf_scale,
out=out,
) )
if num_decode_tokens > 0: if num_decode_tokens > 0:
@ -870,6 +895,16 @@ class FlashInferImpl(AttentionImpl):
assert block_tables_decode.is_contiguous() assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous() assert seq_lens_decode.is_contiguous()
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
out = FP4Tensor(data=output[:num_decode_tokens],
scale=output_block_scale,
scale_start_index=0,
original_shape=decode_query.shape)
else:
assert self.o_sf_scale is None
out = output[:num_decode_tokens]
trtllm_batch_decode_with_kv_cache( trtllm_batch_decode_with_kv_cache(
query=decode_query, query=decode_query,
kv_cache=kv_cache_permute, kv_cache=kv_cache_permute,
@ -881,7 +916,8 @@ class FlashInferImpl(AttentionImpl):
bmm2_scale=self.bmm2_scale, bmm2_scale=self.bmm2_scale,
window_left=self.window_left, window_left=self.window_left,
sinks=self.sinks, sinks=self.sinks,
out=output[:num_decode_tokens], o_sf_scale=self.o_sf_scale,
out=out,
) )
return output_padded return output_padded

View File

@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata: FlexAttentionMetadata, attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FLexAttention. """Forward pass with FLexAttention.
@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlexAttentionImpl") " for FlexAttentionImpl")

View File

@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M, attn_metadata: M,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for MLACommonImpl") " for MLACommonImpl")

View File

@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for PallasAttentionBackendImpl") " for PallasAttentionBackendImpl")

View File

@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
attn_metadata: AiterFlashAttentionMetadata, attn_metadata: AiterFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with AiterFlashAttention. """Forward pass with AiterFlashAttention.
@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlashAttentionImpl") " for FlashAttentionImpl")

View File

@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl):
attn_metadata: TreeAttentionMetadata, attn_metadata: TreeAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with TreeAttention. """Forward pass with TreeAttention.
@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for TreeAttentionImpl") " for TreeAttentionImpl")

View File

@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for TritonAttentionImpl") " for TritonAttentionImpl")

View File

@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl):
attn_metadata: XFormersAttentionMetadata, attn_metadata: XFormersAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with XFormers. """Forward pass with XFormers.
@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for XFormersAttentionImpl") " for XFormersAttentionImpl")