From 9a161307f5f096c63ae4134c5055d87a36d224a8 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:59:55 -0400 Subject: [PATCH] [torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Gregory Shtrasberg Signed-off-by: Luka Govedič Co-authored-by: Luka Govedič Co-authored-by: Luka Govedič --- tests/compile/test_fusion_attn.py | 129 ++++++++++------ .../ops/chunked_prefill_paged_decode.py | 18 ++- vllm/attention/ops/prefix_prefill.py | 14 +- .../attention/ops/triton_unified_attention.py | 139 ++++++++++-------- vllm/compilation/backends.py | 3 +- vllm/compilation/fusion_attn.py | 31 ++-- .../layers/quantization/utils/w8a8_utils.py | 38 ++--- vllm/v1/attention/backends/triton_attn.py | 12 +- 8 files changed, 249 insertions(+), 135 deletions(-) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index dba668cfa16a6..6baf4bf83f499 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None @pytest.mark.parametrize( "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) -@pytest.mark.parametrize( - "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) +@pytest.mark.parametrize("use_triton_fa", [True, False]) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test CUDA and ROCm") -def test_attention_fusion(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="V0 attn quant fusion only on ROCm") +def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, + quant_key: QuantKey, use_triton_fa: bool): # Clean Dynamo cache to avoid reusing other test cases # (for some reason the reset at the end is not enough) torch._dynamo.reset() @@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend_unfused", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config) + vllm_config = VllmConfig(compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + )) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) llm = LLM(model, enforce_eager=True, compilation_config=compile_config, - gpu_memory_utilization=0.9, + gpu_memory_utilization=0.5, max_model_len=2048) sampling_params = SamplingParams(temperature=0.0, @@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config) + vllm_config = VllmConfig(compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + )) # AttnFusionPass needs attention layers to be registered in config upon init # so we initialize it during compilation. @@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, llm2 = LLM(model, enforce_eager=True, compilation_config=compile_config, - gpu_memory_utilization=0.9, + gpu_memory_utilization=0.5, max_model_len=2048) # check support @@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module): cache_config=vllm_config.cache_config, prefix="model.layers.0.self_attn.attn", ) + self.attn._k_scale = self.attn._k_scale.to(device) + self.attn._v_scale = self.attn._v_scale.to(device) self.block_size = 16 @@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module): device=self.device, ) - def build_attn_metadata(self, batch_size: int): + def build_attn_metadata(self, batch_size: int, use_hnd: bool): """Initialize attention metadata.""" # Create common attn metadata @@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module): num_blocks = batch_size * max_blocks # Create dummy KV cache for FlashInfer TRTLLM - # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - # Create kv_cache in HND layout and permute to NHD layout - # (later will be permuted back to HND layout in forward pass) + # - NHD: [num_blocks, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, num_kv_heads, block_size, head_size] kv_cache = torch.zeros(num_blocks, 2, self.num_kv_heads, @@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module): self.head_size, dtype=self.kv_cache_dtype, device=self.device) - kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + if current_platform.is_rocm(): + # k/v as 1st dimention + if use_hnd: + kv_cache = kv_cache.permute(1, 0, 2, 3, 4) + else: + kv_cache = kv_cache.permute(1, 0, 3, 2, 4) + else: + # k/v as 2nd dimention + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) self.attn.kv_cache = [kv_cache] # Build attn metadata @@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): out_dtype=attn_output.dtype) -@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +if current_platform.is_cuda(): + MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)] + HEADS = [(64, 8), (40, 8)] +elif current_platform.is_rocm(): + MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV", + TestAttentionFp8StaticQuantPatternModel)] + HEADS = [(32, 8), (40, 8)] +else: + MODELS = [] + HEADS = [] + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) -@pytest.mark.parametrize("batch_size", [7, 256, 533]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("model_name, model_class", - [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - TestAttentionFp8StaticQuantPatternModel), - ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel)]) -@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) -@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.parametrize("batch_size", + [7, 256, 533] if current_platform.is_cuda() else [8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("model_name, model_class", MODELS) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if + current_platform.is_cuda() else [_Backend.ROCM_FLASH]) +@pytest.mark.parametrize( + "split_attention", + [False, True] if current_platform.is_rocm() else [False]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test ROCm or CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), - reason="Only test on SM100(Blackwell)") +@pytest.mark.skipif(current_platform.is_cuda() + and not current_platform.is_device_capability((10, 0)), + reason="On CUDA only test on SM100(Blackwell)") +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test ROCm or CUDA") def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, monkeypatch, dist_init): + backend: _Backend, split_attention: bool, + monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") + if split_attention: + monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1") device = torch.device("cuda:0") torch.manual_seed(42) @@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_config=ModelConfig( model=model_name, max_model_len=2048, + dtype=dtype, ), scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( @@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata( - batch_size) + batch_size, use_hnd=split_attention) # Run model directly without compilation and fusion result_unfused = model_unfused(q, k, v) @@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_fused = model_fused.to(device) forward_ctx = get_forward_context() - forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + forward_ctx.attn_metadata = model_fused.build_attn_metadata( + batch_size, use_hnd=split_attention) # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) @@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert model_compiled.attn._o_scale_float is None result_fused_1 = model_compiled(q, k, v) - # After the 1st round of the forward pass, output quant scale should be - # loaded into the attn layer's _o_scale_float, the 2nd round should - # reuse the loaded _o_scale_float - assert model_compiled.attn._o_scale_float is not None - result_fused_2 = model_compiled(q, k, v) - assert model_compiled.attn._o_scale_float is not None + if backend == _Backend.FLASHINFER: + # With the Flashinfer backend after the 1st round of the forward + # pass, output quant scale should be loaded into the attn layer's + # _o_scale_float, the 2nd round should reuse the loaded + # _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None + + torch.testing.assert_close(result_unfused, + result_fused_2, + atol=1e-2, + rtol=1e-2) # Check attn fusion support quant_key = model_class.quant_key @@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ "Attention should have output_block_scale after FP4 fusion" # noqa: E501 - # Check that results are closed + # Check that results are close torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2) - torch.testing.assert_close(result_unfused, - result_fused_2, - atol=1e-2, - rtol=1e-2) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index e5b90a8b27558..bf4b06512a3c1 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton from .prefix_prefill import context_attention_fwd +float8_info = torch.finfo(current_platform.fp8_dtype()) + @triton.jit def cdiv_fn(x, y): @@ -34,6 +36,7 @@ def kernel_paged_attention_2d( scale, # float32 k_scale, # float32 v_scale, # float32 + out_scale_inv, num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int num_queries_per_kv_padded: tl.constexpr, # int @@ -60,7 +63,9 @@ def kernel_paged_attention_2d( filter_by_query_len: tl.constexpr, # bool query_start_len_ptr, # [num_seqs+1] USE_SINKS: tl.constexpr, # bool -): + USE_FP8: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -204,6 +209,9 @@ def kernel_paged_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) output_offset = (cur_batch_in_all_start_index * output_stride_0 + query_head_idx * output_stride_1) @@ -234,6 +242,7 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + output_scale=None, # Optional tensor for sinks sinks=None, ): @@ -266,6 +275,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + fp8_out_scale=output_scale, sinks=sinks, ) @@ -316,7 +326,7 @@ def chunked_prefill_paged_decode( tmp_output = torch.empty( size=(total_num_seq, num_query_heads, max_num_partitions, head_size), - dtype=output.dtype, + dtype=query.dtype, device=output.device, ) exp_sums = torch.empty( @@ -345,6 +355,7 @@ def chunked_prefill_paged_decode( kv_cache_dtype=kv_cache_dtype, k_scale=k_scale, v_scale=v_scale, + fp8_out_scale=output_scale, ) else: kernel_paged_attention_2d[( @@ -362,6 +373,8 @@ def chunked_prefill_paged_decode( scale=sm_scale, k_scale=k_scale, v_scale=v_scale, + out_scale_inv=1.0 / + output_scale if output_scale is not None else 1.0, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, num_queries_per_kv_padded=num_queries_per_kv_padded, @@ -388,4 +401,5 @@ def chunked_prefill_paged_decode( filter_by_query_len=True, query_start_len_ptr=query_start_loc, USE_SINKS=sinks is not None, + USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index a70db89cdb76e..7e5c2b6c62e9b 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) +float8_info = torch.finfo(current_platform.fp8_dtype()) # Here's an example autotuner config for this kernel. This config does provide @@ -43,6 +44,7 @@ def _fwd_kernel(Q, sm_scale, k_scale, v_scale, + out_scale_inv, B_Start_Loc, B_Seqlen, x: tl.constexpr, @@ -82,8 +84,11 @@ def _fwd_kernel(Q, num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr, USE_SINKS: tl.constexpr, + USE_FP8: tl.constexpr, MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): + MAX_CTX_LEN: tl.constexpr = 0, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -284,6 +289,9 @@ def _fwd_kernel(Q, off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) tl.store(out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) @@ -743,6 +751,7 @@ def context_attention_fwd(q, sliding_window=None, sm_scale=None, skip_decode=False, + fp8_out_scale=None, sinks=None): q_dtype_is_f32 = q.dtype is torch.float32 @@ -793,6 +802,7 @@ def context_attention_fwd(q, if alibi_slopes is not None: assert sinks is None, "Sinks arg is not supported with alibi" + assert fp8_out_scale is None, "FP8 output not supported with alibi" # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -870,6 +880,7 @@ def context_attention_fwd(q, sm_scale, k_scale, v_scale, + 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0, b_start_loc, b_seq_len, k_cache.shape[4], @@ -905,6 +916,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, + USE_FP8=fp8_out_scale is not None, BLOCK_M=128, BLOCK_N=64, num_unroll_cache=4, diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 250e9b3890444..d2ad2f7e8d2aa 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -10,9 +10,11 @@ import torch from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton logger = init_logger(__name__) +float8_info = torch.finfo(current_platform.fp8_dtype()) @triton.jit @@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, @triton.jit def kernel_unified_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -281,6 +287,9 @@ def kernel_unified_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) output_offset = (query_offset_0[:, None] * output_stride_0 + query_offset_1[:, None] * output_stride_1 + @@ -552,23 +561,27 @@ def kernel_unified_attention_3d( @triton.jit def reduce_segments( - output_ptr, # [num_tokens, num_query_heads, head_size] - segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] - seq_lens_ptr, # [num_seqs] - num_seqs, # int - num_query_heads: tl.constexpr, # int - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + out_scale_inv, # float32 + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) @@ -624,6 +637,10 @@ def reduce_segments( # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + # write result output_offset = (query_token_idx * output_stride_0 + query_head_idx * output_stride_1 + @@ -649,6 +666,7 @@ def unified_attention( k_descale, v_descale, alibi_slopes=None, + output_scale=None, qq_bias=None, # Optional tensor for sinks sinks=None, @@ -707,6 +725,7 @@ def unified_attention( scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, @@ -736,6 +755,7 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, + USE_FP8=output_scale is not None, ) else: # for initial version, NUM_SEGMENTS = 16 is chosen as a default @@ -819,6 +839,8 @@ def unified_attention( seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, + out_scale_inv=1 / + output_scale if output_scale is not None else 1.0, output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), @@ -828,4 +850,5 @@ def unified_attention( query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + USE_FP8=output_scale is not None, ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3361b65a9b885..3cc0fc3106f5a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -454,11 +454,12 @@ class VllmBackend: inductor_config = config.inductor_compile_config PASS_KEY = "post_grad_custom_post_pass" if PASS_KEY in inductor_config: - # Config should automatically wrap all inductor passes if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + # PassManager already added to config, make sure it's correct assert (inductor_config[PASS_KEY].uuid() == self.post_grad_pass_manager.uuid()) else: + # Config should automatically wrap all inductor passes assert isinstance(inductor_config[PASS_KEY], InductorPass) self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 43c345695ef4e..e3677b3dd62d8 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC): self, layer: Attention, quant_key: QuantKey, + dtype: torch.dtype, ): self.layer = layer self.layer_name = layer.layer_name @@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC): self.head_size = layer.head_size self.quant_key = quant_key self.quant_dtype = quant_key.dtype + self.dtype = dtype assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] + def empty(self, *args, **kwargs): + kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + def empty_quant(self, *args, **kwargs): kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) @@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): def __init__( self, layer: Attention, + dtype: torch.dtype, symmetric: bool = True, ): quant_key = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric) - super().__init__(layer, quant_key) + super().__init__(layer, quant_key, dtype) def _register(self, pm_pass: PatternMatcherPass): @@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads, self.head_size), # attn_output + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # q + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # k + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # v + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # attn_output self.empty_quant(5, self.num_heads * self.head_size), # quant_output empty_fp32(1, 1) # scale @@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. """ - def __init__(self, layer: Attention): - super().__init__(layer, kNvfp4Quant) + def __init__(self, layer: Attention, dtype: torch.dtype): + super().__init__(layer, kNvfp4Quant, dtype) def _register(self, pm_pass: PatternMatcherPass): @@ -255,12 +266,14 @@ class AttnFusionPass(VllmInductorPass): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): - pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8 = AttentionFp8StaticQuantPattern( + layer, config.model_config.dtype) pattern_fp8.register_if_supported(self.patterns) if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): - pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) + pattern_nvfp4 = AttentionNvfp4QuantPattern( + layer, config.model_config.dtype) pattern_nvfp4.register_if_supported(self.patterns) if len(attn_layers) == 0: diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 8f6b7f83d47f8..e89a5e643b0e5 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, bias=bias) -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl( return output -def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor) -> torch.Tensor: return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) @@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + qinput, weight, out_dtype, scale_a, scale_b, bias) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) direct_register_custom_op( @@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: output = torch._scaled_mm(qinput, weight, @@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, if type(output) is tuple and len(output) == 2: output = output[0] - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, @@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list, + output_shape: list, **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. @@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b=scale_b.t(), bias=bias) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) output = output.view(*output_shape) return output @@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list, **kwargs) -> torch.Tensor: # Use unfused DQ due to limitations with scaled_mm @@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, if type(output) is tuple and len(output) == 2: output = output[0] # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) # DQ # C = sw * sx * (X * W) + bias @@ -430,7 +431,6 @@ class Fp8LinearOp: scale_a=x_scale, scale_b=weight_scale, bias=bias, - input_2d=input_2d, output_shape=output_shape) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index fe2894eaa0751..c294a5a73cbdd 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionCGSupport, @@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool: class TritonAttentionImpl(AttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + def __init__( self, num_heads: int, @@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None or output_block_scale is not None: + if output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" + "fused block_scale output quantization is not yet supported" " for TritonAttentionImpl") if attn_metadata is None: @@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl): alibi_slopes=self.alibi_slopes, sliding_window=self.sliding_window[0], sm_scale=self.scale, + output_scale=output_scale, sinks=self.sinks, ) @@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl): k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, - ) + output_scale=output_scale) return output