[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Authored by Luka Govedič on 2025-06-12 11:31:04 -04:00; committed by GitHub
parent 96846bb360
commit f98548b9da
33 changed files with 622 additions and 79 deletions
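
For orientation, a hedged usage sketch (mine, not part of this diff): the new attention+quant fusion pass is gated behind PassConfig.enable_attn_fusion, added further down in config.py, so enabling it from user code would look roughly like this. The model name mirrors the one used in the new test; everything else is an assumption about how the existing CompilationConfig/PassConfig fields compose.

# Hypothetical sketch: enable the attention+quant fusion pass added by this commit.
from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

llm = LLM(
    "amd/Llama-3.1-8B-Instruct-FP8-KV",
    compilation_config=CompilationConfig(
        pass_config=PassConfig(
            enable_noop=True,         # reshape/no-op elimination helps the matcher
            enable_attn_fusion=True,  # new flag added in this commit
        ),
    ),
)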

View File

@@ -305,6 +305,7 @@ steps:
   commands:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
+  - pytest -v -s compile/test_fusion_attn.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
   - pytest -v -s compile/test_async_tp.py

View File

@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
 from copy import deepcopy
 from typing import Callable, Union

 from torch import fx
+from torch._ops import OpOverload

-from vllm.compilation.fx_utils import (find_specified_fn,
-                                       find_specified_fn_maybe)
+from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.config import get_current_vllm_config
@@ -48,18 +49,19 @@ class TestBackend:
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph

-    def check_before_ops(self, ops,
-                         find_fn=find_specified_fn, \
-                         find_fn_maybe=find_specified_fn_maybe, \
-                         ops_fully_replaced=True):
+    def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
         for op in ops:
-            find_fn(self.graph_pre_pass.nodes, op)
-            if ops_fully_replaced:
-                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
+            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
+            num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
+            assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
+            assert num_pre > num_post, f"All nodes remain for op {op.name()}"
+            if fully_replaced:
+                assert num_post == 0, \
+                    f"Unexpected op {op.name()} in post-pass graph"

-    def check_after_ops(self, ops,
-                        find_fn=find_specified_fn, \
-                        find_fn_maybe=find_specified_fn_maybe):
+    def check_after_ops(self, ops: Sequence[OpOverload]):
         for op in ops:
-            find_fn(self.graph_post_pass.nodes, op)
-            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
+            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
+            num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
+            assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
+            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"

View File

@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     # In pre-nodes, all gather or reduce scatter should exist,
     # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
-    backend.check_before_ops(model.ops_in_model_before(),
-                             ops_fully_replaced=False)
+    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

     # In post-nodes, fused_matmul_reduce_scatter or \
     # fused_all_gather_matmul should exist

View File

@@ -7,8 +7,7 @@ import torch
 import vllm.envs as envs
 import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
-                                     FusionPass, QuantKey)
-from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
+                                     FusionPass, GroupShape, QuantKey)
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                          VllmConfig)
@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
         self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
         self.key = QuantKey(dtype=FP8_DTYPE,
                             static=static,
-                            per_tensor=static,
+                            group_shape=group_shape,
                             symmetric=True)
         if static:
             self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

     # In pre-nodes, fp8 quant should be there and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
-                             find_auto_fn_maybe)
+    backend.check_before_ops(model.ops_in_model_before())

     # In post-nodes, fused kernels should be there and fp8 quant should not
-    backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
-                            find_auto_fn_maybe)
+    backend.check_after_ops(model.ops_in_model_after())

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform

# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # Use global backends
    global backend, backend_unfused

    use_v1 = False  # can be made a param once V1 support added
    monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.9,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.9,
               max_model_len=2048)

    # check support
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                quant_key.static,
                                                quant_key.group_shape)
        for key, layer in compile_config.static_forward_context.items()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i in range(len(attn_nodes_pre)):
        assert attn_nodes_pre[i].kwargs["output_scale"] is None
        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
        assert fused == attn_fusion_supported[i], \
            f"Node {i} {'' if fused else 'not '} expected " \
            f"to have fused output quant"

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None

View File

@@ -1225,6 +1225,7 @@ def scaled_fp8_quant(
     num_token_padding: Optional[int] = None,
     scale_ub: Optional[torch.Tensor] = None,
     use_per_token_if_dynamic: bool = False,
+    output: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -1256,7 +1257,12 @@ def scaled_fp8_quant(
     out_dtype: torch.dtype = current_platform.fp8_dtype()
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    if output is None:
+        output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    else:
+        assert num_token_padding is None, \
+            "padding not supported if output passed in"
+        assert output.dtype == out_dtype

     if scale is None:
         if use_per_token_if_dynamic:
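
For illustration, a hedged sketch (mine, not part of the diff) of the new output= argument: quantize into a caller-provided FP8 buffer, which is how the ROCm backend below writes the decode slice of its output in place. The shapes and scale value here are arbitrary.

# Sketch of the new output= kwarg; shapes and scale are made up for the example.
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
scale = torch.full((1, ), 0.02, device="cuda", dtype=torch.float32)
out = torch.empty(16, 4096, device="cuda", dtype=current_platform.fp8_dtype())

# Quantizes into `out` instead of allocating a fresh tensor; num_token_padding
# is not allowed in this mode (see the assert added above).
out_q, _ = ops.scaled_fp8_quant(x, scale, output=out)
assert out_q.data_ptr() == out.data_ptr()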

View File

@@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]):
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

+    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
+                                     group_shape: tuple[int, int]):
+        """
+        Does this attention implementation support fused output quantization.
+        This is used by the AttnFusionPass to only fuse output quantization
+        onto implementations that support it.
+
+        TODO(luka) merge parameters into QuantDescriptor
+
+        :param dtype: quantized dtype
+        :param static: static or dynamic quantization
+        :param group_shape: quant group shape. (-1, -1) for per-tensor.
+        :return: is fusion supported for this type of quantization
+        """
+        return False
+

 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
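
To make the hook concrete, here is a hedged sketch (mine, not from the diff) of the query the AttnFusionPass effectively makes per layer before registering a pattern; `layer` is assumed to be a vllm.attention.Attention instance taken from the static forward context, as in AttentionStaticQuantPattern.register_if_supported further down.

# Hypothetical helper mirroring the pass-side check; only static per-tensor FP8
# is registered in this commit.
from vllm.compilation.fusion import GroupShape
from vllm.platforms import current_platform

def can_fuse_static_fp8_output(layer) -> bool:
    return layer.impl.fused_output_quant_supported(
        current_platform.fp8_dtype(),  # quantized dtype
        True,                          # static quantization
        GroupShape.PER_TENSOR)         # (-1, -1), i.e. one scale per tensor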

View File

@@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for BlocksparseFlashAttentionImpl")
+
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

View File

@@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: DualChunkFlashAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with DualChunkFlashAttention.
         Args:
@@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is None, "Output tensor not supported for DualChunk"
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashAttentionImpl")
+
         (
             query,
             query_succ,

View File

@@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashAttentionImpl")
+
         # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
         if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
             assert (

View File

@@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashInferImpl")
+
         # TODO: directly write to output tensor
         num_heads: int = self.num_heads
         head_size: int = self.head_size

View File

@@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for HPUAttentionImpl")
+
         batch_size, seq_len, hidden_size = query.shape
         _, seq_len_kv, _ = key.shape

View File

@@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for IpexAttentionImpl")
+
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.

View File

@@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         kv_cache: torch.Tensor,
         attn_metadata: T,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if output is not None:
             raise NotImplementedError(
                 "output is not yet supported for MLAImplBase")

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for MLAImplBase")
+
         if attn_metadata.is_profile_run and \
                 attn_metadata.context_chunk_workspace is not None:
             # During the profile run try to simulate to worse case output size

View File

@@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for PallasAttentionImpl")
+
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)

View File

@@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                              head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                head_dim))

+    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
+                                     group_shape: tuple[int, int]):
+        if self.use_triton_flash_attn:
+            return dtype == current_platform.fp8_dtype(
+            ) and static and group_shape == (-1, -1)  # per-tensor
+
+        # Only supported in the Triton backend
+        return False
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None and not self.use_triton_flash_attn:
+            raise NotImplementedError(
+                "fused output quantization only supported for Triton"
+                " implementation in ROCMFlashAttentionImpl for now")
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     attn_masks[0][None]
                     if attn_masks is not None else None,
                     full_scales,
+                    output_scale,
                 )
             elif self.use_naive_attn:
                 if self.num_kv_heads != self.num_heads:
@@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 decode_query.dtype, head_size, block_size, gqa_ratio,
                 decode_meta.max_decode_seq_len, self.sliding_window,
                 self.kv_cache_dtype, self.alibi_slopes)

            if use_custom:
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
@@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                assert _PARTITION_SIZE_ROCM % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
-                   dtype=output.dtype,
+                   dtype=query.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
@@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
+                   output_scale,
                )
            else:
-               output[num_prefill_tokens:] = paged_attn.forward_decode(
+               # PagedAttention does not support fused quant, manually quantize
+               if output_scale is None:
+                   out_pa = output[num_prefill_tokens:]
+               else:
+                   out_pa = torch.empty_like(output[num_prefill_tokens:],
+                                             dtype=query.dtype)
+
+               out_pa[:] = paged_attn.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
@@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    layer._v_scale,
                )

+               # Manually perform quantization
+               if output_scale is not None:
+                   out_uq = out_pa.view(-1, self.num_heads * self.head_size)
+                   out_q = output.view(-1, self.num_heads * self.head_size)
+                   ops.scaled_fp8_quant(out_uq,
+                                        output_scale,
+                                        output=out_q[num_prefill_tokens:])
+
        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

View File

@@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for TorchSDPABackendImpl")
+
         # For warming-up
         if attn_metadata is None:

View File

@@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for XFormersImpl")
+
         attn_type = self.attn_type
         # Check that appropriate attention metadata attributes are
         # selected for the desired attention type

View File

@@ -430,6 +430,7 @@ def unified_attention_with_output(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    output_scale: Optional[torch.Tensor] = None,
 ) -> None:
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
@@ -444,7 +445,8 @@
                              value,
                              kv_cache,
                              attn_metadata,
-                             output=output)
+                             output=output,
+                             output_scale=output_scale)

     maybe_save_kv_layer_to_connector(layer_name, kv_cache)
@@ -455,6 +457,7 @@ def unified_attention_with_output_fake(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    output_scale: Optional[torch.Tensor] = None,
 ) -> None:
     return

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, NamedTuple, Optional
+from typing import Callable, ClassVar, NamedTuple, Optional

 import torch
 import torch._inductor.pattern_matcher as pm
@@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default
 RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

+# Use proxy as NamedTuple direct subclasses cannot have static members
+class _GroupShape(NamedTuple):
+    row: int
+    col: int
+
+
+class GroupShape(_GroupShape):
+    """
+    This class describes the quantization group shape.
+    It includes static members for common shapes (per-tensor, per-token).
+    """
+    # Aliases for common quantization group shapes
+    PER_TENSOR: ClassVar['GroupShape']
+    PER_TOKEN: ClassVar['GroupShape']
+
+
+GroupShape.PER_TENSOR = GroupShape(-1, -1)
+GroupShape.PER_TOKEN = GroupShape(1, -1)
+
+
 class QuantKey(NamedTuple):
     """
     Named tuple for identifying the type of quantization.
     dtype: quantized data type
     static: static quantization if True, dynamic if False
-    per_tensor: per-tensor quantization if True, per-token if False
+    group_shape: quantization group shape
     symmetric: symmetric if True, asymmetric if False
+
+    TODO(luka) use QuantDescriptor once standardized:
+     https://github.com/vllm-project/vllm/issues/8913
     """
     dtype: torch.dtype
     static: bool
-    per_tensor: bool = True
+    group_shape: GroupShape
     symmetric: bool = True

     def __str__(self):
+        group_shape = ('per_tensor'
+                       if self.group_shape == GroupShape.PER_TENSOR else
+                       ('per_token' if self.group_shape == GroupShape.PER_TOKEN
+                        else str(self.group_shape)))
         return (f"QuantKey({'static' if self.static else 'dynamic'},"
-                f"{fx.graph.dtype_abbrs[self.dtype]},"
-                f"{'per_tensor' if self.per_tensor else 'per_token'},"
+                f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
                 f"{'a' if not self.symmetric else ''}symmetric)")


-kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
-kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
-kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
+kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
+kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
+kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)

 QUANT_OPS: dict[QuantKey, OpOverload] = {
-    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
+    kFp8StaticTensorSym:
+    torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
     kFp8DynamicTensorSym:
-    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
+    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
     kFp8DynamicTokenSym:
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa
+    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
 }
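
A brief, hedged note on the QuantKey migration above (example is mine, not from the diff): callers that previously passed per_tensor booleans now pass a GroupShape, with the two common cases exposed as class-level aliases.

# Sketch of the per_tensor -> group_shape migration introduced above.
from vllm.compilation.fusion import GroupShape, QuantKey
from vllm.platforms import current_platform

FP8 = current_platform.fp8_dtype()

# Before: QuantKey(dtype=FP8, static=True, per_tensor=True, symmetric=True)
static_per_tensor = QuantKey(dtype=FP8, static=True,
                             group_shape=GroupShape.PER_TENSOR, symmetric=True)

# Per-token dynamic quant uses a group of one row spanning all columns.
dynamic_per_token = QuantKey(dtype=FP8, static=False,
                             group_shape=GroupShape.PER_TOKEN, symmetric=True)

# GroupShape is a NamedTuple, so it compares equal to plain tuples.
assert GroupShape.PER_TENSOR == (-1, -1) and GroupShape.PER_TOKEN == (1, -1)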
@@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple):
 FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(kFp8StaticTensorSym, False):
-    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
+    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
     FusedRMSQuantKey(kFp8StaticTensorSym, True):
-    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa
+    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
     FusedRMSQuantKey(kFp8DynamicTokenSym, False):
-    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
+    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
     FusedRMSQuantKey(kFp8DynamicTokenSym, True):
-    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
+    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
 }
@@ -177,10 +206,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
                  quant_dtype: torch.dtype,
                  symmetric=True):
         fused_key = FusedRMSQuantKey(fused_add=False,
-                                     quant=QuantKey(dtype=quant_dtype,
-                                                    static=True,
-                                                    per_tensor=True,
-                                                    symmetric=symmetric))
+                                     quant=QuantKey(
+                                         dtype=quant_dtype,
+                                         static=True,
+                                         group_shape=GroupShape.PER_TENSOR,
+                                         symmetric=symmetric))
         super().__init__(epsilon, fused_key)

     def register(self, pm_pass: PatternMatcherPass):
@@ -233,10 +263,11 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
                  quant_dtype: torch.dtype,
                  symmetric=True):
         key = FusedRMSQuantKey(fused_add=True,
-                               quant=QuantKey(dtype=quant_dtype,
-                                              static=True,
-                                              per_tensor=True,
-                                              symmetric=symmetric))
+                               quant=QuantKey(
+                                   dtype=quant_dtype,
+                                   static=True,
+                                   group_shape=GroupShape.PER_TENSOR,
+                                   symmetric=symmetric))
         super().__init__(epsilon, key)

     def register(self, pm_pass: PatternMatcherPass,
@@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
     def __init__(self,
                  epsilon: float,
                  quant_dtype: torch.dtype,
-                 per_tensor: bool,
+                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                  symmetric=True):
         key = FusedRMSQuantKey(fused_add=False,
                                quant=QuantKey(dtype=quant_dtype,
                                               static=False,
-                                              per_tensor=per_tensor,
+                                              group_shape=group_shape,
                                               symmetric=symmetric))
         super().__init__(epsilon, key)
@@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
     def __init__(self,
                  epsilon: float,
                  quant_dtype: torch.dtype,
-                 per_tensor: bool = True,
+                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                  symmetric=True):
         key = FusedRMSQuantKey(fused_add=True,
                                quant=QuantKey(dtype=quant_dtype,
                                               static=False,
-                                              per_tensor=per_tensor,
+                                              group_shape=group_shape,
                                               symmetric=symmetric))
         super().__init__(epsilon, key)
@@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass):
             self.patterns, self.record_match)

         # Fuse rms_norm + dynamic per-token fp8 quant
-        RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
-                                   per_tensor=False).register(
-                                       self.patterns, self.record_match)
+        RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
+            self.patterns, self.record_match)

         # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
-        FusedAddRMSNormDynamicQuantPattern(epsilon,
-                                           FP8_DTYPE,
-                                           per_tensor=False).register(
-                                               self.patterns,
-                                               self.record_match)
+        FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
+            self.patterns, self.record_match)

         # WARNING: This is a hack to clear the pattern matcher cache
         # and allow multiple values of epsilon.

View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
                                           unset_fake_temporarily)

from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)

ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


class AttentionStaticQuantPattern:

    def __init__(
        self,
        layer_name: str,
        num_heads: int,
        head_size: int,
        quant_dtype: torch.dtype,
        symmetric=True,
    ):
        self.layer_name = layer_name
        self.num_heads = num_heads
        self.head_size = head_size
        self.quant_dtype = quant_dtype
        self.quant_key = QuantKey(dtype=quant_dtype,
                                  static=True,
                                  group_shape=GroupShape.PER_TENSOR,
                                  symmetric=symmetric)
        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def register_if_supported(self, pm_pass: PatternMatcherPass,
                              layer: Attention):
        if layer.impl.fused_output_quant_supported(self.quant_dtype,
                                                   self.quant_key.static,
                                                   self.quant_key.group_shape):
            self._register(pm_pass)

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            view_7 = RESHAPE_OP(output_attn,
                                [-1, self.num_heads, self.head_size])

            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=view_7,
                                      layer_name=self.layer_name,
                                      output_scale=None)
            attn_out_view = RESHAPE_OP(at1[1],
                                       [-1, self.num_heads * self.head_size])

            at2 = auto_functionalized(self.QUANT_OP,
                                      result=output_quant,
                                      input=attn_out_view,
                                      scale=scale)
            return at2[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            view_7 = RESHAPE_OP(output_quant,
                                [-1, self.num_heads, self.head_size])

            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=view_7,
                                      layer_name=self.layer_name,
                                      output_scale=scale)

            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Need custom fake mode, otherwise tracing happens with real tensors.
        # That would not work for the unified_attention custom op.
        with unset_fake_temporarily(), FakeTensorMode():
            inputs = [
                empty_bf16(5, self.num_heads, self.head_size),  # q
                empty_bf16(5, self.num_heads, self.head_size),  # k
                empty_bf16(5, self.num_heads, self.head_size),  # v
                empty_bf16(5, self.num_heads * self.head_size),  # attn_output
                self.empty_quant(5, self.num_heads *
                                 self.head_size),  # quant_output
                empty_fp32(1, 1)  # scale
            ]

            def wrap_trace_fn(process_fx, trace_fn):

                def wrapped(*args, **kwargs):
                    return process_fx(trace_fn(*args, **kwargs))

                return wrapped

            def fx_view_to_reshape(gm: torch.fx.GraphModule):
                from torch._inductor.fx_passes.post_grad import view_to_reshape
                view_to_reshape(gm)
                return gm

            pm.register_replacement(
                pattern, replacement, inputs,
                wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)


class AttnFusionPass(VllmInductorPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, only static fp8 quant is supported, but patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.static_fwd_ctx = config.compilation_config.static_forward_context

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        for key, layer in self.static_fwd_ctx.items():
            pattern = AttentionStaticQuantPattern(key, layer.num_heads,
                                                  layer.head_size,
                                                  current_platform.fp8_dtype())
            pattern.register_if_supported(self.patterns, layer)
        if len(self.static_fwd_ctx) == 0:
            logger.warning(
                "Attention + quant fusion is enabled, but "
                "CompilationConfig.static_forward_context is empty. "
                "Cannot access attention layers so no fusion "
                "patterns were registered.")

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        self.begin()
        self.dump_graph(graph, "before_attn_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Fused quantization onto %s attention nodes", count)

        self.dump_graph(graph, "after_attn_fusion")
        self.end_and_log()

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import operator
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from typing import Optional

 from torch import fx
@@ -14,6 +14,10 @@ def is_func(node: fx.Node, target) -> bool:
     return node.op == "call_function" and node.target == target


+def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
+    return is_func(node, auto_functionalized) and node.args[0] == op
+
+
 # Returns the first specified node with the given op (if it exists)
 def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                             op: OpOverload) -> Optional[fx.Node]:
@@ -60,3 +64,21 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
     ret = find_getitem_maybe(node, idx)
     assert ret is not None, f"Could not find getitem {idx} in node {node}"
     return ret
+
+
+# An auto-functionalization-aware utility for finding nodes with a specific op
+def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
+    if not op._schema.is_mutable:
+        yield from graph.find_nodes(op="call_function", target=op)
+
+    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
+        if n.args[0] == op:
+            yield n
+
+
+# Asserts that the node only has one user and returns it
+# Even if a node has only 1 user, it might share storage with another node,
+# which might need to be taken into account.
+def get_only_user(node: fx.Node) -> fx.Node:
+    assert len(node.users) == 1
+    return next(iter(node.users))

View File

@@ -23,7 +23,23 @@ class NoOpEliminationPass(VllmInductorPass):
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

+    Cases handled:
+    1. A chain of reshapes is equivalent to the last reshape called on the
+       base tensor (input of the first reshape).
+    2. A reshape that produces the shape of the input is redundant
+    3. A slice that produces the shape of the input is redundant
+
     Example graph 1:
+    mul_1: "f16[s0, 4096]" = ...
+    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
+    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
+    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])
+
+    Can be replaced with:
+    mul_1: "f16[s0, 4096]" = ...
+    view_3: "f16[s0, 128, 32]" = ...
+
+    Example graph 2:
     getitem_1: "f16[s0, 4096]" = ...
     view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
     at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
@@ -34,7 +50,7 @@ class NoOpEliminationPass(VllmInductorPass):
     at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
     out: "f8e4m3fn[s0, 4096]" = at[1]

-    Example graph 2:
+    Example graph 3:
     arg0: "s0" = SymInt(s0)
     scaled_mm: "f16[s0, 4096]" = ...
     slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
@@ -58,6 +74,18 @@ class NoOpEliminationPass(VllmInductorPass):
         # Remove no-op reshapes/views:
         for node in graph.nodes:
             if is_func(node, torch.ops.aten.reshape.default):
+                # Case 1: rewrite reshape chains to reshapes on the base tensor
+                input = node.args[0]
+                # If the input is a reshape, rebind to that node
+                if is_func(input, torch.ops.aten.reshape.default):
+                    # The new input is guaranteed not to be a reshape,
+                    # because we process nodes in order
+                    node.update_arg(0, input.args[0])
+                    if len(input.users) == 0:
+                        graph.erase_node(input)
+                        count += 1
+
+                # Case 2: remove this reshape if it produces the original shape
                 input, shape = node.args[:2]
                 input_shape = input.meta["val"].shape
                 if len(shape) != len(input_shape):

View File

@@ -10,6 +10,7 @@ from .activation_quant_fusion import ActivationQuantFusionPass
 from .collective_fusion import AsyncTPPass
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
+from .fusion_attn import AttnFusionPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
 from .noop_elimination import NoOpEliminationPass
 from .sequence_parallelism import SequenceParallelismPass
@@ -59,6 +60,9 @@ class PostGradPassManager(CustomGraphPass):
         if self.pass_config.enable_async_tp:
             self.passes += [AsyncTPPass(config)]

+        if self.pass_config.enable_attn_fusion:
+            self.passes += [AttnFusionPass(config)]
+
         self.fix_functionalization = FixFunctionalizationPass(config)

     def add(self, pass_: InductorPass):

View File

@@ -4,6 +4,7 @@
 import time

 import torch
+from torch._dynamo.utils import lazy_format_graph_code

 from vllm.config import PassConfig, VllmConfig

 # yapf: disable
@@ -34,6 +35,8 @@ class VllmInductorPass(InductorPass):
         self.pass_name = self.__class__.__name__

     def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
+        lazy_format_graph_code(stage, graph.owning_module)
+
         if stage in self.pass_config.dump_graph_stages or always:
             # Make sure filename includes rank in the distributed setting
             parallel = p_is_init() and get_tp_world_size() > 1

View File

@@ -3804,9 +3804,10 @@ class PassConfig:
     its own stages (before, after, maybe in-between)."""
     dump_graph_dir: Path = Path(".")
     """Directory to dump the graphs."""
-    # TODO(luka) better pass enabling system.
     enable_fusion: bool = True
-    """Whether to enable the custom fusion pass."""
+    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
+    enable_attn_fusion: bool = False
+    """Whether to enable the custom attention+quant fusion pass."""
     enable_noop: bool = True
     """Whether to enable the custom no-op elimination pass."""
     enable_sequence_parallelism: bool = False
@@ -3814,6 +3815,8 @@ class PassConfig:
     enable_async_tp: bool = False
     """Whether to enable async TP."""

+    # TODO(luka) better pass enabling system.
+
     def uuid(self):
         """
         Produces a hash unique to the pass configuration.
@@ -3821,18 +3824,20 @@ class PassConfig:
         Do not include dump_graph_* in the hash - they don't affect
         compilation.
         """
-        include = {
-            "enable_fusion", "enable_noop", "enable_sequence_parallelism",
-            "enable_async_tp"
-        }
-        dict_ = {k: v for k, v in asdict(self).items() if k in include}
+        exclude = {"dump_graph_stages", "dump_graph_dir"}
+        dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
         return InductorPass.hash_dict(dict_)

     def __post_init__(self) -> None:
-        if not self.enable_noop and self.enable_fusion:
-            logger.warning_once(
-                "Fusion enabled but reshape elimination disabled. "
-                "RMSNorm + quant (fp8) fusion might not work")
+        if not self.enable_noop:
+            if self.enable_fusion:
+                logger.warning_once(
+                    "Fusion enabled but reshape elimination disabled. "
+                    "RMSNorm/SiluMul + quant (fp8) fusion might not work")
+            if self.enable_attn_fusion:
+                logger.warning_once(
+                    "Fusion enabled but reshape elimination disabled. "
+                    "Attention + quant (fp8) fusion might not work")


 @config

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
-    VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0

View File

@@ -569,6 +569,7 @@ class FlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -586,6 +587,11 @@ class FlashAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashAttentionImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output

View File

@@ -547,6 +547,7 @@ class FlashInferImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashInfer.
@@ -561,6 +562,11 @@ class FlashInferImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashInferImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output

View File

@@ -414,6 +414,7 @@ class FlexAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlexAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FLexAttention.
@@ -427,6 +428,12 @@ class FlexAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlexAttentionImpl")
+
         enable_gqa = self.num_kv_heads != self.num_heads

         if attn_metadata is None:

View File

@@ -865,10 +865,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for MLACommonImpl")
+
         if attn_metadata is None:
             # The zero fill is required when used with DP + EP
             # to ensure all ranks within a DP group compute the

View File

@@ -161,6 +161,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: PallasMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,6 +174,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for PallasAttentionBackendImpl")
+
         # For determine_available_memory case.
         if kv_cache.numel() == 0:
             if output is None:

View File

@@ -142,6 +142,7 @@ class TritonAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -156,6 +157,11 @@ class TritonAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for TritonAttentionImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output