[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
elvischenv 2025-08-23 06:09:05 +08:00 committed by GitHub
parent cc7ae5e7ca
commit 24d0c9e6ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 596 additions and 200 deletions

View File

@ -9,8 +9,11 @@ from typing import Optional
import flashinfer
import torch
from vllm.utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -61,13 +64,13 @@ def benchmark_decode(
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
q_scale = 1.0
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
query, _ = to_float8(ref_query)
else:
q_scale = 1.0
ref_query = query
query = ref_query
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_seq_len
@ -75,14 +78,13 @@ def benchmark_decode(
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
kv_cache, _ = to_float8(ref_kv_cache)
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
kv_cache = ref_kv_cache
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
@ -142,11 +144,31 @@ def benchmark_decode(
return sum(times) / len(times), torch.std(torch.tensor(times))
o_scale = 1.0
o_sf_scale = None
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
o_sf_scale = 500.0
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_decode():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
return wrapper.run(
ref_query,
ref_kv_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output_baseline,
)
def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
@ -158,6 +180,7 @@ def benchmark_decode(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@ -237,6 +260,7 @@ if __name__ == "__main__":
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
for quant_dtype in quant_dtypes:

View File

@ -9,8 +9,11 @@ from typing import Optional
import flashinfer
import torch
from vllm.utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -72,13 +75,15 @@ def benchmark_prefill(
]
)
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
q_scale = 1.0
ref_query = torch.randn(
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
query, _ = to_float8(ref_query)
else:
q_scale = 1.0
ref_query = query
query = ref_query
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
@ -86,14 +91,13 @@ def benchmark_prefill(
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
kv_cache, _ = to_float8(ref_kv_cache)
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
kv_cache = ref_kv_cache
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
@ -152,11 +156,31 @@ def benchmark_prefill(
return sum(times) / len(times), torch.std(torch.tensor(times))
o_scale = 1.0
o_sf_scale = None
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
o_sf_scale = 500.0
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_prefill():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
return wrapper.run(
ref_query,
ref_kv_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output_baseline,
)
def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
@ -172,6 +196,7 @@ def benchmark_prefill(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@ -250,6 +275,7 @@ if __name__ == "__main__":
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
for quant_dtype in quant_dtypes:

View File

@ -8,11 +8,12 @@ import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fusion import FUSED_OPS, FusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
from .backend import TestBackend

View File

@ -7,11 +7,13 @@ import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, GroupShape, QuantKey)
FusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform
@ -30,10 +32,8 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
group_shape=group_shape,
symmetric=True)
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else:

View File

@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
set_current_vllm_config)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape)
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in compile_config.static_forward_context.items()
]
@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend = None
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
"""Test model for AttentionStaticQuantPattern fusion."""
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig):
vllm_config: VllmConfig, **kwargs):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
prefix="model.layers.0.self_attn.attn",
)
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
self.wscale = torch.tensor([1.0], dtype=torch.float32)
self.scale = torch.tensor([1.0], dtype=torch.float32)
self.block_size = 16
# Initialize attn MetadataBuilder
@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
return self.attn_metadata
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor):
class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionFp8StaticQuantPattern fusion."""
quant_key = kFp8StaticTensorSym
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randn(hidden_size, hidden_size).to(
dtype=FP8_DTYPE, device=self.device).t(),
"wscale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
})
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=w,
weight_scale=self.wscale,
input_scale=self.scale)
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"])
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionNvfp4QuantPattern fusion."""
quant_key = kNvfp4Quant
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randint(256, (hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device),
"wscale_swizzled":
torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device),
"wscale":
torch.tensor([500], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([0.002], dtype=torch.float32, device=self.device),
})
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"])
return cutlass_scaled_fp4_mm(a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"model_name, quant_key",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
@pytest.mark.parametrize("model_name, model_class",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
quant_key: QuantKey, backend: _Backend,
monkeypatch, dist_init):
model_class: type[AttentionQuantPatternModel],
backend: _Backend, monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1")
@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
cache_config=CacheConfig(cache_dtype="fp8"))
# Create test inputs
hidden_size = num_qo_heads * head_size
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
q = torch.randn(batch_size,
num_qo_heads * head_size,
dtype=dtype,
device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
num_kv_heads * head_size,
dtype=dtype,
device=device)
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config_unfused)
model_unfused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused)
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v, linear_w)
result_unfused = model_unfused(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config)
model_fused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v, linear_w)
result_fused_1 = model_compiled(q, k, v)
# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v, linear_w)
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
# Check attn fusion support
quant_key = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape) for key,
layer in vllm_config.compilation_config.static_forward_context.items()
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
"Attention should have output_scale after fusion"
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
"Attention should not have output_block_scale before fusion"
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
"Attention should not have output_block_scale after FP8 fusion"
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
# Check that results are closed
torch.testing.assert_close(result_unfused,
result_fused_1,

View File

@ -6,7 +6,11 @@ import flashinfer
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm.platforms import current_platform
from vllm.utils import round_up
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
rtol, atol = 1e-2, 2e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
if q_quant_dtype != kv_quant_dtype:
pytest.skip("Skipped mixed QKV dtypes for prefill")
max_q_len, max_kv_len = max_seq_lens
num_qo_heads, num_kv_heads = num_heads
@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query,
kv_cache=kv_cache,
@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2

View File

@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
"""
return False
@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -800,6 +800,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -817,6 +818,11 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl")
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]

View File

@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -615,7 +616,7 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@ -1238,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")

View File

@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == GroupShape.PER_TENSOR
return quant_key == kFp8StaticTensorSym
# Only supported in the Triton backend
return False
@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -606,6 +605,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
if output_block_scale is not None:
raise NotImplementedError(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None

View File

@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -484,7 +485,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")

View File

@ -495,6 +495,7 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
@ -510,7 +511,8 @@ def unified_attention_with_output(
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale)
output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
@ -522,6 +524,7 @@ def unified_attention_with_output_fake(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
@ -529,7 +532,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output"],
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -12,7 +12,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs):
@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
symmetric: bool = True
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501
}
@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)
@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
@ -11,44 +13,41 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern:
class AttentionQuantPattern(ABC):
"""
Fusion for Attention+StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
op will be removed from the graph, and its scale will be passed into
Attention op as the `output_scale` argument.
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer: Attention,
quant_dtype: torch.dtype,
symmetric=True,
quant_key: QuantKey,
):
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
@ -57,12 +56,49 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(
self.quant_dtype, self.quant_key.static,
self.quant_key.group_shape):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
symmetric: bool = True,
):
quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale,
symmetric=symmetric)
super().__init__(layer, quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@ -74,9 +110,10 @@ class AttentionStaticQuantPattern:
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
@ -98,7 +135,8 @@ class AttentionStaticQuantPattern:
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale)
output_scale=scale,
output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors.
@ -114,21 +152,94 @@ class AttentionStaticQuantPattern:
empty_fp32(1, 1) # scale
]
def wrap_trace_fn(process_fx, trace_fn):
pm.register_replacement(
pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention):
super().__init__(layer, kNvfp4Quant)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(
output_scale, FP8_DTYPE)
at2 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view)
output = RESHAPE_OP(at2[1],
[-1, self.num_heads * self.head_size // 2])
return output, at2[2]
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with unset_fake_temporarily(), FakeTensorMode():
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size //
2), # output_quant
empty_i32(128,
round_up(self.num_heads * self.head_size // 16,
4)), # output_scale
empty_fp32(1, 1), # input_scale
]
pm.register_replacement(
pattern, replacement, inputs,
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
class AttnFusionPass(VllmInductorPass):
@ -151,8 +262,12 @@ class AttnFusionPass(VllmInductorPass):
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
pattern.register_if_supported(self.patterns)
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8.register_if_supported(self.patterns)
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but no attention layers "
@ -175,4 +290,6 @@ class AttnFusionPass(VllmInductorPass):
self.end_and_log()
def uuid(self):
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern)
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern)

View File

@ -2,16 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
from torch import fx
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
@dataclass(frozen=True)
class ScaleDesc:
"""
Class for describing a single quantization scaling factor.
dtype: data type of the scale
static: static scale if True, dynamic if False
group_shape: group shape of the scale
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'static' if self.static else 'dynamic'},{group_shape}")
@dataclass(frozen=True)
class QuantKey:
"""
Class for identifying the type of quantization.
dtype: quantized data type
scale: scale descriptor
scale2: second-level scale descriptor
symmetric: symmetric if True, asymmetric if False
"""
dtype: torch.dtype
scale: ScaleDesc
scale2: Optional[ScaleDesc] = None
symmetric: bool = True
def __str__(self):
scale2_str = f"scale2({self.scale2})," if self.scale2 else ""
return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
f"scale({self.scale}),{scale2_str}"
f"{'a' if not self.symmetric else ''}symmetric)")
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE,
scale=kNvfp4GroupScale,
scale2=kStaticTensorScale)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent

View File

@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")

View File

@ -430,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -447,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention,
@ -40,6 +41,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
logger = init_logger(__name__)
@ -653,14 +655,12 @@ class FlashInferImpl(AttentionImpl):
and num_heads % num_kv_heads == 0)
self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None
self.o_sf_scale: Optional[float] = None
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
supported_quant_type = (dtype == FP8_DTYPE and static
and group_shape == GroupShape.PER_TENSOR)
def fused_output_quant_supported(self, quant_key: QuantKey):
return (self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
and supported_quant_type)
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))
def forward(
self,
@ -672,6 +672,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.
@ -705,19 +706,32 @@ class FlashInferImpl(AttentionImpl):
if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened."
assert output_block_scale is None, "output_block_scale "\
"is not supported when fusion has not happened"
else:
assert attn_metadata.q_data_type == FP8_DTYPE, \
"Query must be FP8 when attn+quant fusion happened."
assert (attn_metadata.prefill_use_trtllm and
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
assert output.dtype == FP8_DTYPE, \
"Output must be FP8 when attn+quant fusion happened."
# TRTLLM attn kernel requires o scale as a host scalar, store the
# o scale to host scalar in warmup run with cuda graph not enabled
if output.dtype == FP8_DTYPE:
assert output_block_scale is None, \
"output_block_scale should not be provided for fp8 output"
elif output.dtype == FP4_DTYPE:
assert output_block_scale is not None, \
"output_block_scale is required for nvfp4 output"
else:
raise ValueError(f"Unsupported output dtype: {output.dtype}")
# TRTLLM attn kernel requires o scale to pass as a host scalar,
# store the o scale as a host scalar in warmup run with cuda graph
# not enabled
if layer._o_scale_float is None:
layer._o_scale_float = output_scale.cpu().item()
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
if output.dtype == FP8_DTYPE:
self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float
# Insert FP8 quant for query
num_tokens, num_heads, head_size = query.shape
@ -818,6 +832,16 @@ class FlashInferImpl(AttentionImpl):
assert block_tables_prefill.is_contiguous()
assert seq_lens_prefill.is_contiguous()
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
out = FP4Tensor(data=output[num_decode_tokens:],
scale=output_block_scale,
scale_start_index=num_decode_tokens,
original_shape=prefill_query.shape)
else:
assert self.o_sf_scale is None
out = output[num_decode_tokens:]
trtllm_batch_context_with_kv_cache(
query=prefill_query,
kv_cache=kv_cache_permute,
@ -833,7 +857,8 @@ class FlashInferImpl(AttentionImpl):
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=self.window_left,
sinks=self.sinks,
out=output[num_decode_tokens:],
o_sf_scale=self.o_sf_scale,
out=out,
)
if num_decode_tokens > 0:
@ -870,6 +895,16 @@ class FlashInferImpl(AttentionImpl):
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
out = FP4Tensor(data=output[:num_decode_tokens],
scale=output_block_scale,
scale_start_index=0,
original_shape=decode_query.shape)
else:
assert self.o_sf_scale is None
out = output[:num_decode_tokens]
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
@ -881,7 +916,8 @@ class FlashInferImpl(AttentionImpl):
bmm2_scale=self.bmm2_scale,
window_left=self.window_left,
sinks=self.sinks,
out=output[:num_decode_tokens],
o_sf_scale=self.o_sf_scale,
out=out,
)
return output_padded

View File

@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FLexAttention.
@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlexAttentionImpl")

View File

@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl")

View File

@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl")

View File

@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
attn_metadata: AiterFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with AiterFlashAttention.
@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl):
attn_metadata: TreeAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with TreeAttention.
@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TreeAttentionImpl")

View File

@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TritonAttentionImpl")

View File

@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl):
attn_metadata: XFormersAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with XFormers.
@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersAttentionImpl")