vllm/vllm/compilation/fusion_attn.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .fx_utils import is_func
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
# FP4 values are packed two per byte, so FP4 tensors are stored as uint8.
FP4_DTYPE = torch.uint8

ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


class AttentionQuantPattern(ABC):
"""
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer: Attention,
quant_key: QuantKey,
dtype: torch.dtype,
):
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
self.dtype = dtype
assert self.quant_key in QUANT_OPS, (
f"unsupported quantization scheme {self.quant_key}"
)
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args, **kwargs):
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def empty_quant(self, *args, **kwargs):
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    @staticmethod
    def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
        # Wrap a pattern-tracing function so each traced GraphModule is
        # post-processed (e.g. view->reshape normalization) before registration.
        def wrapped(*args, **kwargs):
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule):
        # Run Inductor's view->reshape normalization on the traced pattern so
        # its view ops match the reshape ops present in the post-grad graph.
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule):
        # Identity (no-op) permutes left in the traced pattern would prevent
        # matching; strip them.
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

    def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
            self._register(pm_pass)

    @abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
        raise NotImplementedError


class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
    Fusion for Attention+Fp8StaticQuant.

    Only triggers when the attention implementation returns True from
    `fused_output_quant_supported()`. If the pattern is found, the
    Fp8StaticQuant op is removed from the graph and its scale is passed
    into the attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
dtype: torch.dtype,
symmetric: bool = True,
):
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
        self.quant_matcher = MatcherQuantFP8(quant_key)

    def _register(self, pm_pass: PatternMatcherPass):
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
            return self.quant_matcher(attn_out_view, scale)[0]

        def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
scale: torch.Tensor,
):
            # Allocate the attention output buffer directly in quant_dtype; the
            # fused kernel writes already-quantized values into it.
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype,
device=q.device,
)
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale,
output_block_scale=None,
)
            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        inputs = [
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
        )


class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
    Fusion for Attention+Nvfp4Quant.

    Only triggers when the attention implementation returns True from
    `fused_output_quant_supported()`. If the pattern is found, the
    Nvfp4Quant op is removed from the graph; its input scale is passed to the
    attention op as `output_scale` and its block-scale output buffer as
    `output_block_scale`.
"""
def __init__(self, layer: Attention, dtype: torch.dtype):
        super().__init__(layer, kNvfp4Quant, dtype)

    def _register(self, pm_pass: PatternMatcherPass):
def pattern(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
):
at1 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None,
)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale,
)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
            return at2[1], output_scale_view

        def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
output_scale: torch.Tensor,
input_scale: torch.Tensor,
):
            # Allocate the attention output in quant_dtype; FP4 packs two values
            # per uint8 byte, hence the head_size // 2 last dimension.
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device,
)
            # Attention output block scales: reinterpret the int32 buffer as FP8
            # (four scales per int32 element) so the kernel can write them directly.
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
at2 = auto_functionalized(
ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view,
)
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
            return output, at2[2]

        inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale (int32-packed FP8 block scales, one per 16 elements)
empty_fp32(1, 1), # input_scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
        )


class AttnFusionPass(VllmPatternMatcherPass):
"""
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Static fp8 quant is always attempted, and NVFP4 quant is attempted on CUDA
    when the scaled_fp4_quant op is available; patterns could easily be added
    for other quant schemes and dtypes. The bigger hurdle for wider support is
    the attention kernels, which need to support fused output quantization.
"""
@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(
layer, config.model_config.dtype
)
pattern_fp8.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(
layer, config.model_config.dtype
)
                pattern_nvfp4.register_if_supported(self.patterns)

        if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but no attention layers "
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered."
)
        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self):
return VllmInductorPass.hash_source(
self,
AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern,
)