diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 72b54b40a2d1e..603ce5ecf0d2c 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -9,8 +9,11 @@ from typing import Optional import flashinfer import torch +from vllm.utils import round_up + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -61,13 +64,13 @@ def benchmark_decode( else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) if q_quant_dtype == FP8_DTYPE: - query, q_scale = to_float8(query) - ref_query = query.to(dtype) * q_scale + query, _ = to_float8(ref_query) else: - q_scale = 1.0 - ref_query = query + query = ref_query kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_seq_len @@ -75,14 +78,13 @@ def benchmark_decode( seq_lens = kv_lens max_seq_len = torch.max(seq_lens).item() - kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) if kv_quant_dtype == FP8_DTYPE: - kv_cache, kv_scale = to_float8(kv_cache) - ref_kv_cache = kv_cache.to(dtype) * kv_scale + kv_cache, _ = to_float8(ref_kv_cache) else: - kv_scale = 1.0 - ref_kv_cache = kv_cache - k_scale = v_scale = kv_scale + kv_cache = ref_kv_cache max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( @@ -142,11 +144,31 @@ def benchmark_decode( return sum(times) / len(times), torch.std(torch.tensor(times)) o_scale = 1.0 + o_sf_scale = None output_baseline = torch.empty(ref_query.shape, dtype=dtype) - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_decode(): - return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) def trtllm_decode(): return flashinfer.decode.trtllm_batch_decode_with_kv_cache( @@ -158,6 +180,7 @@ def benchmark_decode( max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -237,6 +260,7 @@ if __name__ == "__main__": (None, None, None), (None, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] for quant_dtype in quant_dtypes: diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 49810e20c7d82..40903c6c3444f 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -9,8 +9,11 @@ from typing import Optional import flashinfer import torch +from vllm.utils import round_up + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -72,13 +75,15 @@ def benchmark_prefill( ] ) - query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn( + torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype + ) if q_quant_dtype == FP8_DTYPE: - query, q_scale = to_float8(query) - ref_query = query.to(dtype) * q_scale + query, _ = to_float8(ref_query) else: - q_scale = 1.0 - ref_query = query + query = ref_query kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len @@ -86,14 +91,13 @@ def benchmark_prefill( seq_lens = kv_lens + q_lens max_seq_len = torch.max(seq_lens).item() - kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) if kv_quant_dtype == FP8_DTYPE: - kv_cache, kv_scale = to_float8(kv_cache) - ref_kv_cache = kv_cache.to(dtype) * kv_scale + kv_cache, _ = to_float8(ref_kv_cache) else: - kv_scale = 1.0 - ref_kv_cache = kv_cache - k_scale = v_scale = kv_scale + kv_cache = ref_kv_cache max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( @@ -152,11 +156,31 @@ def benchmark_prefill( return sum(times) / len(times), torch.std(torch.tensor(times)) o_scale = 1.0 + o_sf_scale = None output_baseline = torch.empty(ref_query.shape, dtype=dtype) - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_prefill(): - return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( @@ -172,6 +196,7 @@ def benchmark_prefill( batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -250,6 +275,7 @@ if __name__ == "__main__": # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] for quant_dtype in quant_dtypes: diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index aade29b99de7e..0c7e6fbccf20c 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,11 +8,12 @@ import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, - kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fusion import FUSED_OPS, FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from .backend import TestBackend diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a3820e20fd89..5cfad935a0fb1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -7,11 +7,13 @@ import torch import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, GroupShape, QuantKey) + FusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) from vllm.platforms import current_platform @@ -30,10 +32,8 @@ class TestModel(torch.nn.Module): self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - self.key = QuantKey(dtype=FP8_DTYPE, - static=static, - group_shape=group_shape, - symmetric=True) + quant_scale = ScaleDesc(torch.float32, static, group_shape) + self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index bef0fdef985ec..dba668cfa16a6 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata) from vllm import LLM, SamplingParams +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, set_current_vllm_config) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 # globals needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # check support attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) + layer.impl.fused_output_quant_supported(quant_key) for key, layer in compile_config.static_forward_context.items() ] @@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend = None -class TestAttentionStaticQuantPatternModel(torch.nn.Module): - """Test model for AttentionStaticQuantPattern fusion.""" +class AttentionQuantPatternModel(torch.nn.Module): + """Base model for AttentionQuantPattern fusion.""" def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype, device: torch.device, - vllm_config: VllmConfig): + vllm_config: VllmConfig, **kwargs): super().__init__() self.num_qo_heads = num_qo_heads self.num_kv_heads = num_kv_heads @@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module): prefix="model.layers.0.self_attn.attn", ) - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) - self.wscale = torch.tensor([1.0], dtype=torch.float32) - self.scale = torch.tensor([1.0], dtype=torch.float32) - self.block_size = 16 # Initialize attn MetadataBuilder @@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module): return self.attn_metadata - def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - w: torch.Tensor): + +class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionFp8StaticQuantPattern fusion.""" + + quant_key = kFp8StaticTensorSym + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.quant_key.scale.static, + act_quant_group_shape=self.quant_key.scale.group_shape) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randn(hidden_size, hidden_size).to( + dtype=FP8_DTYPE, device=self.device).t(), + "wscale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) return self.fp8_linear.apply(input=attn_output, - weight=w, - weight_scale=self.wscale, - input_scale=self.scale) + weight=self.w["weight"], + weight_scale=self.w["wscale"], + input_scale=self.w["scale"]) + + +class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionNvfp4QuantPattern fusion.""" + + quant_key = kNvfp4Quant + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device), + "wscale_swizzled": + torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device), + "wscale": + torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([0.002], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + quant_output, output_block_scale = scaled_fp4_quant( + attn_output, 1 / self.w["scale"]) + return cutlass_scaled_fp4_mm(a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype) @pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("batch_size", [7, 256, 533]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize( - "model_name, quant_key", - [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)]) +@pytest.mark.parametrize("model_name, model_class", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)]) @pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module): def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, model_name: str, - quant_key: QuantKey, backend: _Backend, - monkeypatch, dist_init): + model_class: type[AttentionQuantPatternModel], + backend: _Backend, monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") @@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, cache_config=CacheConfig(cache_dtype="fp8")) # Create test inputs - hidden_size = num_qo_heads * head_size - q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + q = torch.randn(batch_size, + num_qo_heads * head_size, + dtype=dtype, + device=device) k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, @@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, num_kv_heads * head_size, dtype=dtype, device=device) - linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t() # Mark first dimension as dynamic for realistic testing torch._dynamo.mark_dynamic(q, 0) @@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, with set_current_vllm_config(vllm_config_unfused), set_forward_context( attn_metadata=None, vllm_config=vllm_config_unfused ), global_force_attn_backend_context_manager(backend): - model_unfused = TestAttentionStaticQuantPatternModel( - num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config_unfused) + model_unfused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused) model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() @@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, batch_size) # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v, linear_w) + result_unfused = model_unfused(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( @@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, with set_current_vllm_config(vllm_config), set_forward_context( attn_metadata=None, vllm_config=vllm_config ), global_force_attn_backend_context_manager(backend): - model_fused = TestAttentionStaticQuantPatternModel( - num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config) + model_fused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w) model_fused = model_fused.to(device) forward_ctx = get_forward_context() @@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, backend=test_backend, fullgraph=True) assert model_compiled.attn._o_scale_float is None - result_fused_1 = model_compiled(q, k, v, linear_w) + result_fused_1 = model_compiled(q, k, v) # After the 1st round of the forward pass, output quant scale should be # loaded into the attn layer's _o_scale_float, the 2nd round should # reuse the loaded _o_scale_float assert model_compiled.attn._o_scale_float is not None - result_fused_2 = model_compiled(q, k, v, linear_w) + result_fused_2 = model_compiled(q, k, v) assert model_compiled.attn._o_scale_float is not None # Check attn fusion support + quant_key = model_class.quant_key attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) for key, - layer in vllm_config.compilation_config.static_forward_context.items() + layer.impl.fused_output_quant_supported(quant_key) for key, layer in + vllm_config.compilation_config.static_forward_context.items() ] if any(attn_fusion_supported): # Check quantization ops in the graph before and after fusion @@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ "Attention should have output_scale after fusion" + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale before fusion" + if quant_key.dtype == FP8_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale after FP8 fusion" + elif quant_key.dtype == FP4_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ + "Attention should have output_block_scale after FP4 fusion" # noqa: E501 + # Check that results are closed torch.testing.assert_close(result_unfused, result_fused_1, diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 69e44264cd440..8d0a11d8eb8ab 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,7 +6,11 @@ import flashinfer import pytest import torch +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm.platforms import current_platform +from vllm.utils import round_up if not current_platform.is_device_capability(100): pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", @@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100): FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16] QUANT_DTYPES = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), + (None, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] BATCH_SIZE = [4, 12] MAX_SEQ_LENS = [(1024, 4096)] @@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline( output = torch.empty(ref_query.shape, dtype=dtype) wrapper.run(ref_query, ref_kv_cache, out=output) o_scale = 1.0 + o_sf_scale = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=query, kv_cache=kv_cache, @@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline( max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) - if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 3e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 5e-2, 7e-2 else: - rtol, atol = 1e-2, 1e-2 + rtol, atol = 1e-2, 2e-2 torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" @@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_quant_dtype = kv_quant_dtype or dtype o_quant_dtype = o_quant_dtype or dtype + if q_quant_dtype != kv_quant_dtype: + pytest.skip("Skipped mixed QKV dtypes for prefill") + max_q_len, max_kv_len = max_seq_lens num_qo_heads, num_kv_heads = num_heads @@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline( output = torch.empty(ref_query.shape, dtype=dtype) wrapper.run(ref_query, ref_kv_cache, out=output) o_scale = 1.0 + o_sf_scale = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) # TRTLLM Prefill - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + flashinfer.prefill.trtllm_batch_context_with_kv_cache( query=query, kv_cache=kv_cache, @@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline( batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: output_trtllm = output_trtllm.to(dtype) * o_scale + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) - if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 4e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 5e-2, 7e-2 else: rtol, atol = 1e-2, 1e-2 diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index d21f07756871a..0b9c625533cb7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. - TODO(luka) merge parameters into QuantDescriptor - :param dtype: quantized dtype - :param static: static or dynamic quantization - :param group_shape: quant group shape. + :param quant_key: QuantKey object that describes the quantization op :return: is fusion supported for this type of quantization """ return False @@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index fac3c318a87a0..ce9467efd23c7 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -800,6 +800,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): attn_metadata: DifferentialFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -817,6 +818,11 @@ class DifferentialFlashAttentionImpl(AttentionImpl): {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for DifferentialFlashAttentionImpl") + if self.lambda_full is None: self.lambda_init = self.differential_flash_attention_config[ "lambda_init"] diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index fa6f3f1b39cca..85957bea1e26d 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): attn_metadata: DualChunkFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with DualChunkFlashAttention. Args: @@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): """ assert output is None, "Output tensor not supported for DualChunk" - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e52480d5c5ce2..ba7a9afe86782 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -615,7 +616,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 9d6ab7e3217b0..c5ed4c6e40326 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1238,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLAImplBase") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 63e467f5a7a22..e4c27a0ef36e9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): if self.use_triton_flash_attn: - return dtype == current_platform.fp8_dtype( - ) and static and group_shape == GroupShape.PER_TENSOR + return quant_key == kFp8StaticTensorSym # Only supported in the Triton backend return False @@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): attn_metadata: ROCmFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -606,6 +605,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): "fused output quantization only supported for Triton" " implementation in ROCMFlashAttentionImpl for now") + if output_block_scale is not None: + raise NotImplementedError( + "fused nvfp4 output quantization is not supported" + " for ROCMFlashAttentionImpl") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0bc38b4142901..c1213f7620a7a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -484,7 +485,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersImpl") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 04ab100c8775d..9fbead31782a9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -495,6 +495,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -510,7 +511,8 @@ def unified_attention_with_output( kv_cache, attn_metadata, output=output, - output_scale=output_scale) + output_scale=output_scale, + output_block_scale=output_block_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -522,6 +524,7 @@ def unified_attention_with_output_fake( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: return @@ -529,7 +532,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, - mutates_args=["output"], + mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 3dec939c28351..413948799de35 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,7 +12,8 @@ from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe @@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def empty_bf16(*args, **kwargs): @@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") +def empty_i32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") + + RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - -class QuantKey(NamedTuple): - """ - Named tuple for identifying the type of quantization. - dtype: quantized data type - static: static quantization if True, dynamic if False - group_shape: quantization group shape - symmetric: symmetric if True, asymmetric if False - - TODO(luka) use QuantDescriptor once standardized: - https://github.com/vllm-project/vllm/issues/8913 - - """ - dtype: torch.dtype - static: bool - group_shape: GroupShape - symmetric: bool = True - - def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) - - return (f"QuantKey({'static' if self.static else 'dynamic'}," - f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," - f"{'a' if not self.symmetric else ''}symmetric)") - - -kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) -kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) -kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) - QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 @@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501 } @@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): @@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, @@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=False, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=True, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 1f77a26676138..f942afe6a28ee 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch import torch._inductor.pattern_matcher as pm from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -11,44 +13,41 @@ from torch._subclasses.fake_tensor import (FakeTensorMode, from vllm.attention import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from vllm.utils import round_up -from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default -class AttentionStaticQuantPattern: +class AttentionQuantPattern(ABC): """ - Fusion for Attention+StaticQuant. - - Only triggers when the attention implementation returns True in - `fused_output_quant_supported()`. If the pattern is found, the StaticQuant - op will be removed from the graph, and its scale will be passed into - Attention op as the `output_scale` argument. + The base class for Attn+Quant fusions. + Should not be used directly. """ def __init__( self, layer: Attention, - quant_dtype: torch.dtype, - symmetric=True, + quant_key: QuantKey, ): self.layer = layer self.layer_name = layer.layer_name self.num_heads = layer.num_heads self.head_size = layer.head_size - self.quant_dtype = quant_dtype - self.quant_key = QuantKey(dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric) + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -57,12 +56,49 @@ class AttentionStaticQuantPattern: kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) + @staticmethod + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + def register_if_supported(self, pm_pass: PatternMatcherPass): - if self.layer.impl.fused_output_quant_supported( - self.quant_dtype, self.quant_key.static, - self.quant_key.group_shape): + if self.layer.impl.fused_output_quant_supported(self.quant_key): self._register(pm_pass) + @abstractmethod + def _register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError + + +class AttentionFp8StaticQuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Fp8StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Fp8StaticQuant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. + """ + + def __init__( + self, + layer: Attention, + symmetric: bool = True, + ): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(layer, quant_key) + def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -74,9 +110,10 @@ class AttentionStaticQuantPattern: value=v, output=output_attn, layer_name=self.layer_name, - output_scale=None) - attn_out_view = RESHAPE_OP(at1[1], - [-1, self.num_heads * self.head_size]) + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, result=output_quant, input=attn_out_view, @@ -98,7 +135,8 @@ class AttentionStaticQuantPattern: value=v, output=output_attn, layer_name=self.layer_name, - output_scale=scale) + output_scale=scale, + output_block_scale=None) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) # Need custom fake mode, otherwise tracing happens with real tensors. @@ -114,21 +152,94 @@ class AttentionStaticQuantPattern: empty_fp32(1, 1) # scale ] - def wrap_trace_fn(process_fx, trace_fn): + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) - def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) - return wrapped +class AttentionNvfp4QuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Nvfp4Quant. - def fx_view_to_reshape(gm: torch.fx.GraphModule): - from torch._inductor.fx_passes.post_grad import view_to_reshape - view_to_reshape(gm) - return gm + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Nvfp4Quant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. + """ + + def __init__(self, layer: Attention): + super().__init__(layer, kNvfp4Quant) + + def _register(self, pm_pass: PatternMatcherPass): + + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) + at2 = auto_functionalized(self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale) + output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) + return at2[1], output_scale_view + + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + # attention output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size // 2], + 0.0, + dtype=self.quant_dtype, + device=q.device) + # attention output block scale + output_scale_view = torch.ops.aten.view.dtype( + output_scale, FP8_DTYPE) + at2 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view) + output = RESHAPE_OP(at2[1], + [-1, self.num_heads * self.head_size // 2]) + return output, at2[2] + + # Need custom fake mode, otherwise tracing happens with real tensors. + # That would not work for the unified_attention custom op. + with unset_fake_temporarily(), FakeTensorMode(): + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # output_attn + self.empty_quant(5, self.num_heads * self.head_size // + 2), # output_quant + empty_i32(128, + round_up(self.num_heads * self.head_size // 16, + 4)), # output_scale + empty_fp32(1, 1), # input_scale + ] pm.register_replacement( pattern, replacement, inputs, - wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) class AttnFusionPass(VllmInductorPass): @@ -151,8 +262,12 @@ class AttnFusionPass(VllmInductorPass): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): - pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE) - pattern.register_if_supported(self.patterns) + pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8.register_if_supported(self.patterns) + + pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) + pattern_nvfp4.register_if_supported(self.patterns) + if len(attn_layers) == 0: logger.warning( "Attention + quant fusion is enabled, but no attention layers " @@ -175,4 +290,6 @@ class AttnFusionPass(VllmInductorPass): self.end_and_log() def uuid(self): - return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern) + return VllmInductorPass.hash_source(self, AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 97e5922ebd55f..6154fca2e416d 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,16 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" from collections.abc import Mapping +from dataclasses import dataclass from types import MappingProxyType from typing import ClassVar, NamedTuple, Optional import numpy import torch +from torch import fx from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): @@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) +@dataclass(frozen=True) +class ScaleDesc: + """ + Class for describing a single quantization scaling factor. + dtype: data type of the scale + static: static scale if True, dynamic if False + group_shape: group shape of the scale + """ + dtype: torch.dtype + static: bool + group_shape: GroupShape + + def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + + return (f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}") + + +@dataclass(frozen=True) +class QuantKey: + """ + Class for identifying the type of quantization. + dtype: quantized data type + scale: scale descriptor + scale2: second-level scale descriptor + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + scale: ScaleDesc + scale2: Optional[ScaleDesc] = None + symmetric: bool = True + + def __str__(self): + scale2_str = f"scale2({self.scale2})," if self.scale2 else "" + return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)") + + +kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) + +kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) + +kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) + +kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Quant = QuantKey(FP4_DTYPE, + scale=kNvfp4GroupScale, + scale2=kStaticTensorScale) + + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 9ed46331863c9..973979fdf7dfd 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TorchSDPABackendImpl") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index eca83b6d2ee45..6e7096de924ca 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -430,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -447,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 1e6e3f1d0abf4..50819bb2bb943 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.utils import FP4Tensor from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from vllm.platforms import current_platform from vllm.utils import cdiv, is_pin_memory_available from vllm.utils.flashinfer import (supports_trtllm_attention, @@ -40,6 +41,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 logger = init_logger(__name__) @@ -653,14 +655,12 @@ class FlashInferImpl(AttentionImpl): and num_heads % num_kv_heads == 0) self.bmm1_scale: Optional[float] = None self.bmm2_scale: Optional[float] = None + self.o_sf_scale: Optional[float] = None - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): - supported_quant_type = (dtype == FP8_DTYPE and static - and group_shape == GroupShape.PER_TENSOR) + def fused_output_quant_supported(self, quant_key: QuantKey): return (self.support_trtllm_attn and self.kv_cache_dtype.startswith("fp8") - and supported_quant_type) + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) def forward( self, @@ -672,6 +672,7 @@ class FlashInferImpl(AttentionImpl): attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -705,19 +706,32 @@ class FlashInferImpl(AttentionImpl): if output_scale is None: assert attn_metadata.q_data_type != FP8_DTYPE, \ "Query can only be FP8 if output fusion happened." + assert output_block_scale is None, "output_block_scale "\ + "is not supported when fusion has not happened" else: assert attn_metadata.q_data_type == FP8_DTYPE, \ "Query must be FP8 when attn+quant fusion happened." assert (attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" - assert output.dtype == FP8_DTYPE, \ - "Output must be FP8 when attn+quant fusion happened." - # TRTLLM attn kernel requires o scale as a host scalar, store the - # o scale to host scalar in warmup run with cuda graph not enabled + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, \ + "output_block_scale should not be provided for fp8 output" + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, \ + "output_block_scale is required for nvfp4 output" + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # TRTLLM attn kernel requires o scale to pass as a host scalar, + # store the o scale as a host scalar in warmup run with cuda graph + # not enabled if layer._o_scale_float is None: layer._o_scale_float = output_scale.cpu().item() - self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float # Insert FP8 quant for query num_tokens, num_heads, head_size = query.shape @@ -818,6 +832,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_prefill.is_contiguous() assert seq_lens_prefill.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape) + else: + assert self.o_sf_scale is None + out = output[num_decode_tokens:] + trtllm_batch_context_with_kv_cache( query=prefill_query, kv_cache=kv_cache_permute, @@ -833,7 +857,8 @@ class FlashInferImpl(AttentionImpl): cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, window_left=self.window_left, sinks=self.sinks, - out=output[num_decode_tokens:], + o_sf_scale=self.o_sf_scale, + out=out, ) if num_decode_tokens > 0: @@ -870,6 +895,16 @@ class FlashInferImpl(AttentionImpl): assert block_tables_decode.is_contiguous() assert seq_lens_decode.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape) + else: + assert self.o_sf_scale is None + out = output[:num_decode_tokens] + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -881,7 +916,8 @@ class FlashInferImpl(AttentionImpl): bmm2_scale=self.bmm2_scale, window_left=self.window_left, sinks=self.sinks, - out=output[:num_decode_tokens], + o_sf_scale=self.o_sf_scale, + out=out, ) return output_padded diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index abca981035d9e..f4aa54660a078 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl): attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl): shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlexAttentionImpl") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5c0ff943794a8..ce45b34f64355 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): attn_metadata: M, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLACommonImpl") diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 3eb4a0e7a574d..fd97db0abb84f 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl): attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for PallasAttentionBackendImpl") diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index b9ff113573a12..403ad8e88a958 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl): attn_metadata: AiterFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 2a0c52377cc7f..c93223a340839 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl): attn_metadata: TreeAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TreeAttentionImpl") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index c69dd8415f922..b12036c599799 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl): attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TritonAttentionImpl") diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index b305bc1539081..e0eb7d8be9746 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl): attn_metadata: XFormersAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with XFormers. @@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersAttentionImpl")