diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b739851cb905..ec7812820197 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -305,6 +305,7 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_fusion_attn.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 60334f5e4f68..ace4d25534cd 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence from copy import deepcopy from typing import Callable, Union from torch import fx +from torch._ops import OpOverload -from vllm.compilation.fx_utils import (find_specified_fn, - find_specified_fn_maybe) +from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass from vllm.config import get_current_vllm_config @@ -48,18 +49,19 @@ class TestBackend: # assign by reference, will reflect the final state of the graph self.final_graph = graph - def check_before_ops(self, ops, - find_fn=find_specified_fn, \ - find_fn_maybe=find_specified_fn_maybe, \ - ops_fully_replaced=True): + def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): for op in ops: - find_fn(self.graph_pre_pass.nodes, op) - if ops_fully_replaced: - assert find_fn_maybe(self.graph_post_pass.nodes, op) is None + num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) + num_post = len(list(find_op_nodes(op, self.graph_post_pass))) + assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph" + assert num_pre > num_post, f"All nodes remain for op {op.name()}" + if fully_replaced: + assert num_post == 0, \ + f"Unexpected op {op.name()} in post-pass graph" - def check_after_ops(self, ops, - find_fn=find_specified_fn, \ - find_fn_maybe=find_specified_fn_maybe): + def check_after_ops(self, ops: Sequence[OpOverload]): for op in ops: - find_fn(self.graph_post_pass.nodes, op) - assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None + num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) + num_post = len(list(find_op_nodes(op, self.graph_post_pass))) + assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" + assert num_post > 0, f"Op {op.name()} not found in post-pass graph" \ No newline at end of file diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 508056ea1914..62804e721e3d 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, # In pre-nodes, all gather or reduce scatter should exist, # fused_matmul_reduce_scatter or fused_all_gather_matmul should not - backend.check_before_ops(model.ops_in_model_before(), - ops_fully_replaced=False) + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) # In post-nodes, fused_matmul_reduce_scatter or \ # fused_all_gather_matmul should exist diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 0c25aae52d46..040fd176fec1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -7,8 +7,7 @@ import torch import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import 
(FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, QuantKey) -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe + FusionPass, GroupShape, QuantKey) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) @@ -30,9 +29,10 @@ class TestModel(torch.nn.Module): self.cutlass_fp8_enabled = cutlass_fp8_enabled self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN self.key = QuantKey(dtype=FP8_DTYPE, static=static, - per_tensor=static, + group_shape=group_shape, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] @@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) # In pre-nodes, fp8 quant should be there and fused kernels should not - backend.check_before_ops(model.ops_in_model_before(), find_auto_fn, - find_auto_fn_maybe) + backend.check_before_ops(model.ops_in_model_before()) # In post-nodes, fused kernels should be there and fp8 quant should not - backend.check_after_ops(model.ops_in_model_after(), find_auto_fn, - find_auto_fn_maybe) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py new file mode 100644 index 000000000000..5e6679adfbdc --- /dev/null +++ b/tests/compile/test_fusion_attn.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import pytest +import torch._dynamo + +from tests.compile.backend import TestBackend +from tests.models.utils import check_outputs_equal +from vllm import LLM, SamplingParams +from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym +from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass +from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.platforms import current_platform + +# globals needed for string-import custom Dynamo backend field +backend: Optional[TestBackend] = None +backend_unfused: Optional[TestBackend] = None + + +@pytest.mark.parametrize( + "model, quant_key", + [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) +@pytest.mark.parametrize( + "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) +@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test CUDA and ROCm") +def test_attention_fusion(example_prompts, monkeypatch, model: str, + quant_key: QuantKey, use_triton_fa: bool): + # Clean Dynamo cache to avoid reusing other test cases + # (for some reason the reset at the end is not enough) + torch._dynamo.reset() + + # Use global backends + global backend, backend_unfused + + use_v1 = False # can be made a param once V1 support added + monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1))) + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa))) + + # Prompt 4 seems too open-ended, differs between fused and unfused + # (both outputs look reasonable though) + prompts = example_prompts[:4] + example_prompts[5:] + + compile_config = CompilationConfig( + # DYNAMO_AS_IS triggers custom backend & 
does full Dynamo compilation + # DYNAMO_ONCE does not properly propagate shapes. + level=CompilationLevel.DYNAMO_AS_IS, + backend="tests.compile.test_fusion_attn.backend_unfused", + ) + vllm_config = VllmConfig(compilation_config=compile_config) + backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) + + llm = LLM(model, + enforce_eager=True, + compilation_config=compile_config, + gpu_memory_utilization=0.9, + max_model_len=2048) + + sampling_params = SamplingParams(temperature=0.0, + max_tokens=10, + top_p=0.95) + + unfused_output = llm.generate(prompts, sampling_params) + backend_unfused = None # Reset backend to make sure llm gets released + del llm + + compile_config = CompilationConfig( + # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation + # DYNAMO_ONCE does not properly propagate shapes. + level=CompilationLevel.DYNAMO_AS_IS, + backend="tests.compile.test_fusion_attn.backend", + ) + vllm_config = VllmConfig(compilation_config=compile_config) + + # AttnFusionPass needs attention layers to be registered in config upon init + # so we initialize it during compilation. + attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw) + backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) + llm2 = LLM(model, + enforce_eager=True, + compilation_config=compile_config, + gpu_memory_utilization=0.9, + max_model_len=2048) + + # check support + attn_fusion_supported = [ + layer.impl.fused_output_quant_supported(quant_key.dtype, + quant_key.static, + quant_key.group_shape) + for key, layer in compile_config.static_forward_context.items() + ] + + print(f"{attn_fusion_supported=}") + if any(attn_fusion_supported): + # Check quant ops + backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False) + + # attention ops present in both, just output_scale param changes + attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass)) + assert len(attn_nodes_pre) == len(attn_nodes_post) + + for i in range(len(attn_nodes_pre)): + assert attn_nodes_pre[i].kwargs["output_scale"] is None + fused = attn_nodes_post[i].kwargs["output_scale"] is not None + assert fused == attn_fusion_supported[i], \ + f"Node {i} {'' if fused else 'not '} expected " \ + f"to have fused output quant" + + # check outputs + fused_output = llm2.generate(prompts, sampling_params) + + # transform outputs to format expected by check_outputs_equal + sample_outs = lambda s: (list(s.token_ids), s.text) + outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros] + + check_outputs_equal( + outputs_0_lst=outs_lst(unfused_output), + outputs_1_lst=outs_lst(fused_output), + name_0="unfused", + name_1="fused", + ) + + # Clean Dynamo cache to avoid polluting other case(s) + torch._dynamo.reset() + + # Reset backend to make sure llm2 gets released + backend = None diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe5b386c4d25..e26c90bf70cb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1225,6 +1225,7 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, + output: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. 
@@ -1256,7 +1257,12 @@ def scaled_fp8_quant( out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0ba5a5bf94c9..990ea054f338 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]): kv_cache: torch.Tensor, attn_metadata: T, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError + def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, + group_shape: tuple[int, int]): + """ + Does this attention implementation support fused output quantization. + This is used by the AttnFusionPass to only fuse output quantization + onto implementations that support it. + + TODO(luka) merge parameters into QuantDescriptor + :param dtype: quantized dtype + :param static: static or dynamic quantization + :param group_shape: quant group shape. (-1, -1) for per-tensor. + :return: is fusion supported for this type of quantization + """ + return False + class MLAAttentionImpl(AttentionImpl[T], Generic[T]): @@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): kv_cache: torch.Tensor, attn_metadata: T, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index c1663516de35..71415f49372f 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for BlocksparseFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index 963bccdf21bc..55f57f37b100 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: DualChunkFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with DualChunkFlashAttention. 
Args: @@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ + assert output is None, "Output tensor not supported for DualChunk" + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + ( query, query_succ, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 73e3772682e6..47c25d136c67 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: assert ( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a3937760f03b..ff73104787ab 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashInferImpl") + # TODO: directly write to output tensor num_heads: int = self.num_heads head_size: int = self.head_size diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 9bd513fd894f..115e5ba1a20f 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for HPUAttentionImpl") + batch_size, seq_len, hidden_size = query.shape _, seq_len_kv, _ = key.shape diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 5051c6a7cc4f..21f61cf70b28 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): kv_cache: torch.Tensor, attn_metadata: IpexAttnMetadata, # type: ignore output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. 
@@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for IpexAttentionImpl") + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 78cf95288130..0c3ff26d04c8 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): kv_cache: torch.Tensor, attn_metadata: T, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLAImplBase") + if attn_metadata.is_profile_run and \ attn_metadata.context_chunk_workspace is not None: # During the profile run try to simulate to worse case output size diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 7ad67615d33d..c5c080297cea 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl): kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl): Returns: shape = [batch_size, seq_len, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for PallasAttentionImpl") + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7887ebf65f44..8f1da84cd483 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -38,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool: @cache def _get_paged_attn_module() -> PagedAttention: """ - Initializes the appropriate PagedAttention module from `attention/ops`, + Initializes the appropriate PagedAttention module from `attention/ops`, which is used as helper function by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. - The choice of attention module depends on whether + The choice of attention module depends on whether AITER paged attention is enabled: - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - Otherwise, it defaults to using the original `PagedAttention`. 
@@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) + def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, + group_shape: tuple[int, int]): + if self.use_triton_flash_attn: + return dtype == current_platform.fp8_dtype( + ) and static and group_shape == (-1, -1) # per-tensor + + # Only supported in the Triton backend + return False + def forward( self, layer: AttentionLayer, @@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None and not self.use_triton_flash_attn: + raise NotImplementedError( + "fused output quantization only supported for Triton" + " implementation in ROCMFlashAttentionImpl for now") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): attn_masks[0][None] if attn_masks is not None else None, full_scales, + output_scale, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: @@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len, self.sliding_window, self.kv_cache_dtype, self.alibi_slopes) + if use_custom: max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type != AttentionType.ENCODER_DECODER else @@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): assert _PARTITION_SIZE_ROCM % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, + dtype=query.dtype, device=output.device, ) exp_sums = torch.empty( @@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): self.kv_cache_dtype, layer._k_scale, layer._v_scale, + output_scale, ) else: - output[num_prefill_tokens:] = paged_attn.forward_decode( + # PagedAttention does not support fused quant, manually quantize + if output_scale is None: + out_pa = output[num_prefill_tokens:] + else: + out_pa = torch.empty_like(output[num_prefill_tokens:], + dtype=query.dtype) + + out_pa[:] = paged_attn.forward_decode( decode_query, key_cache, value_cache, @@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl): layer._v_scale, ) + # Manually perform quantization + if output_scale is not None: + out_uq = out_pa.view(-1, self.num_heads * self.head_size) + out_q = output.view(-1, self.num_heads * self.head_size) + ops.scaled_fp8_quant(out_uq, + output_scale, + output=out_q[num_prefill_tokens:]) + # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 23231c323f13..9d7e735dd41d 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. 
@@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TorchSDPABackendImpl") # For warming-up if attn_metadata is None: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 04ef928b7d7b..dfdc8ee6402d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): kv_cache: torch.Tensor, attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for XFormersImpl") + attn_type = self.attn_type # Check that appropriate attention metadata attributes are # selected for the desired attention type diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a5fbd1a1c016..3bbe276e0cbe 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -430,6 +430,7 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + output_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -444,7 +445,8 @@ def unified_attention_with_output( value, kv_cache, attn_metadata, - output=output) + output=output, + output_scale=output_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -455,6 +457,7 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, layer_name: str, + output_scale: Optional[torch.Tensor] = None, ) -> None: return diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 7e2c5b4fe66a..9d908fcae3df 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Callable, NamedTuple, Optional +from typing import Callable, ClassVar, NamedTuple, Optional import torch import torch._inductor.pattern_matcher as pm @@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default +# Use proxy as NamedTuple direct subclasses cannot have static members +class _GroupShape(NamedTuple): + row: int + col: int + + +class GroupShape(_GroupShape): + """ + This class describes the quantization group shape. + It includes static members for common shapes (per-tensor, per-token). + """ + + # Aliases for common quantization group shapes + PER_TENSOR: ClassVar['GroupShape'] + PER_TOKEN: ClassVar['GroupShape'] + + +GroupShape.PER_TENSOR = GroupShape(-1, -1) +GroupShape.PER_TOKEN = GroupShape(1, -1) + + class QuantKey(NamedTuple): """ Named tuple for identifying the type of quantization. 
dtype: quantized data type static: static quantization if True, dynamic if False - per_tensor: per-tensor quantization if True, per-token if False + group_shape: quantization group shape symmetric: symmetric if True, asymmetric if False + + TODO(luka) use QuantDescriptor once standardized: + https://github.com/vllm-project/vllm/issues/8913 + """ dtype: torch.dtype static: bool - per_tensor: bool = True + group_shape: GroupShape symmetric: bool = True def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + return (f"QuantKey({'static' if self.static else 'dynamic'}," - f"{fx.graph.dtype_abbrs[self.dtype]}," - f"{'per_tensor' if self.per_tensor else 'per_token'}," + f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," f"{'a' if not self.symmetric else ''}symmetric)") -kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True) -kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) -kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa + kFp8StaticTensorSym: + torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } @@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple): FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { FusedRMSQuantKey(kFp8StaticTensorSym, False): - torch.ops._C.rms_norm_static_fp8_quant.default, # noqa + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 FusedRMSQuantKey(kFp8StaticTensorSym, True): - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 FusedRMSQuantKey(kFp8DynamicTokenSym, False): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 FusedRMSQuantKey(kFp8DynamicTokenSym, True): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } @@ -177,10 +206,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - static=True, - per_tensor=True, - symmetric=symmetric)) + quant=QuantKey( + dtype=quant_dtype, + static=True, + group_shape=GroupShape.PER_TENSOR, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): @@ -233,10 +263,11 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): quant_dtype: torch.dtype, symmetric=True): key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - static=True, - per_tensor=True, - symmetric=symmetric)) + quant=QuantKey( + dtype=quant_dtype, + static=True, + group_shape=GroupShape.PER_TENSOR, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: 
PatternMatcherPass, @@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern): def __init__(self, epsilon: float, quant_dtype: torch.dtype, - per_tensor: bool, + group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): key = FusedRMSQuantKey(fused_add=False, quant=QuantKey(dtype=quant_dtype, static=False, - per_tensor=per_tensor, + group_shape=group_shape, symmetric=symmetric)) super().__init__(epsilon, key) @@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): def __init__(self, epsilon: float, quant_dtype: torch.dtype, - per_tensor: bool = True, + group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): key = FusedRMSQuantKey(fused_add=True, quant=QuantKey(dtype=quant_dtype, static=False, - per_tensor=per_tensor, + group_shape=group_shape, symmetric=symmetric)) super().__init__(epsilon, key) @@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass): self.patterns, self.record_match) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE, - per_tensor=False).register( - self.patterns, self.record_match) + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns, self.record_match) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant - FusedAddRMSNormDynamicQuantPattern(epsilon, - FP8_DTYPE, - per_tensor=False).register( - self.patterns, - self.record_match) + FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + self.patterns, self.record_match) # WARNING: This is a hack to clear the pattern matcher cache # and allow multiple values of epsilon. diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py new file mode 100644 index 000000000000..cf57e5ed282e --- /dev/null +++ b/vllm/compilation/fusion_attn.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch._inductor.pattern_matcher as pm +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._subclasses.fake_tensor import (FakeTensorMode, + unset_fake_temporarily) + +from vllm.attention import Attention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + +ATTN_OP = torch.ops.vllm.unified_attention_with_output.default +RESHAPE_OP = torch.ops.aten.reshape.default + + +class AttentionStaticQuantPattern: + + def __init__( + self, + layer_name: str, + num_heads: int, + head_size: int, + quant_dtype: torch.dtype, + symmetric=True, + ): + self.layer_name = layer_name + self.num_heads = num_heads + self.head_size = head_size + self.quant_dtype = quant_dtype + self.quant_key = QuantKey(dtype=quant_dtype, + static=True, + group_shape=GroupShape.PER_TENSOR, + symmetric=symmetric) + assert self.quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {self.quant_key}" + self.QUANT_OP = QUANT_OPS[self.quant_key] + + def empty_quant(self, *args, **kwargs): + kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + + def register_if_supported(self, pm_pass: PatternMatcherPass, + layer: Attention): + if layer.impl.fused_output_quant_supported(self.quant_dtype, + self.quant_key.static, + self.quant_key.group_shape): + self._register(pm_pass) + + def _register(self, pm_pass: PatternMatcherPass): 
+ + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + scale: torch.Tensor): + view_7 = RESHAPE_OP(output_attn, + [-1, self.num_heads, self.head_size]) + + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=view_7, + layer_name=self.layer_name, + output_scale=None) + attn_out_view = RESHAPE_OP(at1[1], + [-1, self.num_heads * self.head_size]) + + at2 = auto_functionalized(self.QUANT_OP, + result=output_quant, + input=attn_out_view, + scale=scale) + return at2[1] + + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + scale: torch.Tensor): + view_7 = RESHAPE_OP(output_quant, + [-1, self.num_heads, self.head_size]) + + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=view_7, + layer_name=self.layer_name, + output_scale=scale) + + return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) + + # Need custom fake mode, otherwise tracing happens with real tensors. + # That would not work for the unified_attention custom op. + with unset_fake_temporarily(), FakeTensorMode(): + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads * self.head_size), # attn_output + self.empty_quant(5, self.num_heads * + self.head_size), # quant_output + empty_fp32(1, 1) # scale + ] + + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped + + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + + pm.register_replacement( + pattern, replacement, inputs, + wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + + +class AttnFusionPass(VllmInductorPass): + """ + This pass fuses post-attention quantization onto attention if supported. + + It uses the pattern matcher and matches each layer manually, as strings + cannot be wildcarded. This also lets us check support on attention layers + upon registration instead of during pattern matching. + + Currently, only static fp8 quant is supported, but patterns could easily be + added for other quant schemes and dtypes. The bigger hurdle for wider + support are attention kernels, which need to support fusing output quant. + """ + + def __init__(self, config: VllmConfig): + super().__init__(config) + self.static_fwd_ctx = config.compilation_config.static_forward_context + + self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") + + for key, layer in self.static_fwd_ctx.items(): + pattern = AttentionStaticQuantPattern(key, layer.num_heads, + layer.head_size, + current_platform.fp8_dtype()) + pattern.register_if_supported(self.patterns, layer) + if len(self.static_fwd_ctx) == 0: + logger.warning( + "Attention + quant fusion is enabled, but " + "CompilationConfig.static_forward_context is empty. 
" + "Cannot access attention layers so no fusion " + "patterns were registered.") + + def __call__(self, graph: torch.fx.graph.Graph) -> None: + self.begin() + self.dump_graph(graph, "before_attn_fusion") + + count = self.patterns.apply(graph) + logger.debug("Fused quantization onto %s attention nodes", count) + self.dump_graph(graph, "after_attn_fusion") + self.end_and_log() diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 9ef388932388..2db8b5441bd6 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import operator -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from typing import Optional from torch import fx @@ -14,6 +14,10 @@ def is_func(node: fx.Node, target) -> bool: return node.op == "call_function" and node.target == target +def is_auto_func(node: fx.Node, op: OpOverload) -> bool: + return is_func(node, auto_functionalized) and node.args[0] == op + + # Returns the first specified node with the given op (if it exists) def find_specified_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: @@ -60,3 +64,21 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node: ret = find_getitem_maybe(node, idx) assert ret is not None, f"Could not find getitem {idx} in node {node}" return ret + + +# An auto-functionalization-aware utility for finding nodes with a specific op +def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: + if not op._schema.is_mutable: + yield from graph.find_nodes(op="call_function", target=op) + + for n in graph.find_nodes(op="call_function", target=auto_functionalized): + if n.args[0] == op: + yield n + + +# Asserts that the node only has one user and returns it +# Even if a node has only 1 user, it might share storage with another node, +# which might need to be taken into account. +def get_only_user(node: fx.Node) -> fx.Node: + assert len(node.users) == 1 + return next(iter(node.users)) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 46f70dcdc688..4888d4d1298e 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -23,7 +23,23 @@ class NoOpEliminationPass(VllmInductorPass): in the 2D-case. Additionally, torch internal no-op elimination pass does not handle certain slice variants. + Cases handled: + 1. A chain of reshapes is equivalent to the last reshape called on the + base tensor (input of the first reshape). + 2. A reshape that produces the shape of the input is redundant + 3. A slice that produces the shape of the input is redundant + Example graph 1: + mul_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32]) + view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096]) + view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32]) + + Can be replaced with: + mul_1: "f16[s0, 4096]" = ... + view_3: "f16[s0, 128, 32]" = ... + + Example graph 2: getitem_1: "f16[s0, 4096]" = ... view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) @@ -34,7 +50,7 @@ class NoOpEliminationPass(VllmInductorPass): at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) out: "f8e4m3fn[s0, 4096]" = at[1] - Example graph 2: + Example graph 3: arg0: "s0" = SymInt(s0) scaled_mm: "f16[s0, 4096]" = ... 
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) @@ -58,6 +74,18 @@ class NoOpEliminationPass(VllmInductorPass): # Remove no-op reshapes/views: for node in graph.nodes: if is_func(node, torch.ops.aten.reshape.default): + # Case 1: rewrite reshape chains to reshapes on the base tensor + input = node.args[0] + # If the input is a reshape, rebind to that node + if is_func(input, torch.ops.aten.reshape.default): + # The new input is guaranteed not to be a reshape, + # because we process nodes in order + node.update_arg(0, input.args[0]) + if len(input.users) == 0: + graph.erase_node(input) + count += 1 + + # Case 2: remove this reshape if it produces the original shape input, shape = node.args[:2] input_shape = input.meta["val"].shape if len(shape) != len(input_shape): diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 621c89a14487..28a59905ecf8 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -10,6 +10,7 @@ from .activation_quant_fusion import ActivationQuantFusionPass from .collective_fusion import AsyncTPPass from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass +from .fusion_attn import AttnFusionPass from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context from .noop_elimination import NoOpEliminationPass from .sequence_parallelism import SequenceParallelismPass @@ -59,6 +60,9 @@ class PostGradPassManager(CustomGraphPass): if self.pass_config.enable_async_tp: self.passes += [AsyncTPPass(config)] + if self.pass_config.enable_attn_fusion: + self.passes += [AttnFusionPass(config)] + self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 3ccbf52d9fd3..628e9e204c55 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,6 +4,7 @@ import time import torch +from torch._dynamo.utils import lazy_format_graph_code from vllm.config import PassConfig, VllmConfig # yapf: disable @@ -34,6 +35,8 @@ class VllmInductorPass(InductorPass): self.pass_name = self.__class__.__name__ def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + lazy_format_graph_code(stage, graph.owning_module) + if stage in self.pass_config.dump_graph_stages or always: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 diff --git a/vllm/config.py b/vllm/config.py index 5da44988bc5f..d2cfbc839252 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3804,9 +3804,10 @@ class PassConfig: its own stages (before, after, maybe in-between).""" dump_graph_dir: Path = Path(".") """Directory to dump the graphs.""" - # TODO(luka) better pass enabling system. enable_fusion: bool = True - """Whether to enable the custom fusion pass.""" + """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" + enable_attn_fusion: bool = False + """Whether to enable the custom attention+quant fusion pass.""" enable_noop: bool = True """Whether to enable the custom no-op elimination pass.""" enable_sequence_parallelism: bool = False @@ -3814,6 +3815,8 @@ class PassConfig: enable_async_tp: bool = False """Whether to enable async TP.""" + # TODO(luka) better pass enabling system. + def uuid(self): """ Produces a hash unique to the pass configuration. 
@@ -3821,18 +3824,20 @@ class PassConfig: Do not include dump_graph_* in the hash - they don't affect compilation. """ - include = { - "enable_fusion", "enable_noop", "enable_sequence_parallelism", - "enable_async_tp" - } - dict_ = {k: v for k, v in asdict(self).items() if k in include} + exclude = {"dump_graph_stages", "dump_graph_dir"} + dict_ = {k: v for k, v in asdict(self).items() if k not in exclude} return InductorPass.hash_dict(dict_) def __post_init__(self) -> None: - if not self.enable_noop and self.enable_fusion: - logger.warning_once( - "Fusion enabled but reshape elimination disabled. " - "RMSNorm + quant (fp8) fusion might not work") + if not self.enable_noop: + if self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm/SiluMul + quant (fp8) fusion might not work") + if self.enable_attn_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "Attention + quant (fp8) fusion might not work") @config diff --git a/vllm/envs.py b/vllm/envs.py index 80c5f289bba9..f24ae64396f3 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None - VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 91a7c43cd8d8..ebd9bd88dfd0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -569,6 +569,7 @@ class FlashAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -586,6 +587,11 @@ class FlashAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + if attn_metadata is None: # Profiling run. return output diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index b15bb4b3152a..277fc3ea5db9 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -547,6 +547,7 @@ class FlashInferImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -561,6 +562,11 @@ class FlashInferImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashInferImpl") + if attn_metadata is None: # Profiling run. 
return output diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 5b473b1461a6..1588839b685e 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -414,6 +414,7 @@ class FlexAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -427,6 +428,12 @@ class FlexAttentionImpl(AttentionImpl): shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlexAttentionImpl") + enable_gqa = self.num_kv_heads != self.num_heads if attn_metadata is None: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e6b4f6404632..86e78d7894a1 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -865,10 +865,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): kv_cache: torch.Tensor, attn_metadata: M, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl") + if attn_metadata is None: # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 0f956ba88b9c..62c72f43f147 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -161,6 +161,7 @@ class PallasAttentionBackendImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -173,6 +174,11 @@ class PallasAttentionBackendImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for PallasAttentionBackendImpl") + # For determine_available_memory case. if kv_cache.numel() == 0: if output is None: diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 5db592b15010..6b67d9932e9d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -142,6 +142,7 @@ class TritonAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -156,6 +157,11 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TritonAttentionImpl") + if attn_metadata is None: # Profiling run. return output
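
A few usage sketches for the new interfaces in this diff follow; they are not part of the patch itself. First, the fusion.py change replaces the per_tensor boolean on QuantKey with a GroupShape named tuple carrying PER_TENSOR/PER_TOKEN aliases. A minimal sketch of how such keys resolve to quant ops, assuming a vLLM build with the compiled _C custom ops (vllm.compilation.fusion dereferences torch.ops._C at import time) and that FP8_DTYPE is a module-level name there, as the diff indicates:

# Usage sketch (not part of the patch): GroupShape-based QuantKeys and the
# QUANT_OPS lookup table from vllm/compilation/fusion.py.
from vllm.compilation.fusion import (FP8_DTYPE, QUANT_OPS, GroupShape,
                                     QuantKey, kFp8StaticTensorSym)

# Static per-tensor fp8, identical to the kFp8StaticTensorSym alias above.
static_key = QuantKey(dtype=FP8_DTYPE,
                      static=True,
                      group_shape=GroupShape.PER_TENSOR,
                      symmetric=True)
assert static_key == kFp8StaticTensorSym
print(QUANT_OPS[static_key])  # _C.static_scaled_fp8_quant

# Dynamic per-token fp8 resolves to the per-token quant op.
token_key = QuantKey(dtype=FP8_DTYPE,
                     static=False,
                     group_shape=GroupShape.PER_TOKEN,
                     symmetric=True)
print(QUANT_OPS[token_key])  # _C.dynamic_per_token_scaled_fp8_quant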
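
The _custom_ops.py change adds an output parameter to scaled_fp8_quant so callers, such as the ROCm decode path above, can quantize into a preallocated fp8 buffer. A hedged sketch of that call pattern, assuming a CUDA/ROCm build of vLLM and using illustrative shapes; the in-place return behavior is inferred from the existing function:

# Sketch (illustrative shapes): static per-tensor quantization written
# directly into a caller-provided fp8 buffer via the new `output=` argument.
# `num_token_padding` must be left unset when `output` is given.
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.ones(1, dtype=torch.float32, device="cuda")
out = torch.empty_like(x, dtype=current_platform.fp8_dtype())

out_q, out_scale = ops.scaled_fp8_quant(x, scale, output=out)
assert out_q.data_ptr() == out.data_ptr()  # quantized in place into `out`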
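
The new AttentionImpl.fused_output_quant_supported hook defaults to False; a backend opts in by overriding it, as the ROCm Triton path does above. A sketch of such an override for a hypothetical backend (the rest of the AttentionImpl interface is omitted), advertising only static per-tensor fp8, which is all AttnFusionPass currently matches:

# Sketch: a hypothetical backend advertising fused output quantization,
# mirroring the ROCmFlashAttentionImpl override above.
import torch

from vllm.attention.backends.abstract import AttentionImpl
from vllm.platforms import current_platform


class MyQuantFusedAttentionImpl(AttentionImpl):
    # __init__/forward and the other AttentionImpl methods omitted here.

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        # (-1, -1) compares equal to GroupShape.PER_TENSOR
        return (dtype == current_platform.fp8_dtype() and static
                and group_shape == (-1, -1))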
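
Finally, the config.py change adds PassConfig.enable_attn_fusion (default False), which PostGradPassManager uses to register AttnFusionPass. A hedged sketch of turning it on; the pass_config field on CompilationConfig and the PIECEWISE level are assumptions beyond this diff, and fusion only fires for attention backends whose fused_output_quant_supported returns True (currently the ROCm Triton flash-attention path):

# Sketch: opting into attention+quant fusion via the new PassConfig flag.
# `CompilationConfig.pass_config` and the choice of CompilationLevel are
# assumptions not shown in this diff; adjust to the actual config surface.
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel, PassConfig

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    pass_config=PassConfig(
        enable_attn_fusion=True,
        enable_noop=True,  # keep reshape elimination on, or fusion may not fire
    ),
)

llm = LLM("amd/Llama-3.1-8B-Instruct-FP8-KV",
          compilation_config=compilation_config)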