[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>

parent 96846bb360
commit f98548b9da
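
Editorial note: a minimal usage sketch (not part of the diff below) of the feature this commit adds. The new pass is gated by the PassConfig.enable_attn_fusion flag introduced further down; the exact wiring of pass_config into CompilationConfig is assumed here, while the other names follow the diff.

from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

# Enable attention+quant fusion; keep no-op elimination on so the
# reshape-heavy attention patterns still match (see the PassConfig warning
# added in this commit).
compile_config = CompilationConfig(
    pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True))

llm = LLM("amd/Llama-3.1-8B-Instruct-FP8-KV",
          compilation_config=compile_config)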
@@ -305,6 +305,7 @@ steps:
    commands:
    - pytest -v -s compile/test_pass_manager.py
    - pytest -v -s compile/test_fusion.py
    - pytest -v -s compile/test_fusion_attn.py
    - pytest -v -s compile/test_silu_mul_quant_fusion.py
    - pytest -v -s compile/test_sequence_parallelism.py
    - pytest -v -s compile/test_async_tp.py
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union

from torch import fx
from torch._ops import OpOverload

from vllm.compilation.fx_utils import (find_specified_fn,
                                       find_specified_fn_maybe)
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
@@ -48,18 +49,19 @@ class TestBackend:
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph

    def check_before_ops(self, ops,
                         find_fn=find_specified_fn, \
                         find_fn_maybe=find_specified_fn_maybe, \
                         ops_fully_replaced=True):
    def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True):
        for op in ops:
            find_fn(self.graph_pre_pass.nodes, op)
            if ops_fully_replaced:
                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
            num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
            assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph"
            assert num_pre > num_post, f"All nodes remain for op {op.name()}"
            if fully_replaced:
                assert num_post == 0, \
                    f"Unexpected op {op.name()} in post-pass graph"

    def check_after_ops(self, ops,
                        find_fn=find_specified_fn, \
                        find_fn_maybe=find_specified_fn_maybe):
    def check_after_ops(self, ops: Sequence[OpOverload]):
        for op in ops:
            find_fn(self.graph_post_pass.nodes, op)
            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
            num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
            num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
            assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
@@ -169,8 +169,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,

    # In pre-nodes, all gather or reduce scatter should exist,
    # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
    backend.check_before_ops(model.ops_in_model_before(),
                             ops_fully_replaced=False)
    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

    # In post-nodes, fused_matmul_reduce_scatter or \
    # fused_all_gather_matmul should exist
@@ -7,8 +7,7 @@ import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                     FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
                                     FusionPass, GroupShape, QuantKey)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                         VllmConfig)
@@ -30,9 +29,10 @@ class TestModel(torch.nn.Module):
        self.cutlass_fp8_enabled = cutlass_fp8_enabled
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
        self.key = QuantKey(dtype=FP8_DTYPE,
                            static=static,
                            per_tensor=static,
                            group_shape=group_shape,
                            symmetric=True)
        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
@@ -122,9 +122,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

    # In pre-nodes, fp8 quant should be there and fused kernels should not
    backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
                             find_auto_fn_maybe)
    backend.check_before_ops(model.ops_in_model_before())

    # In post-nodes, fused kernels should be there and fp8 quant should not
    backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
                            find_auto_fn_maybe)
    backend.check_after_ops(model.ops_in_model_after())
tests/compile/test_fusion_attn.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform

# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
    # Clean Dynamo cache to avoid reusing other test cases
    # (for some reason the reset at the end is not enough)
    torch._dynamo.reset()

    # Use global backends
    global backend, backend_unfused

    use_v1 = False  # can be made a param once V1 support added
    monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended, differs between fused and unfused
    # (both outputs look reasonable though)
    prompts = example_prompts[:4] + example_prompts[5:]

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.9,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # Reset backend to make sure llm gets released
    del llm

    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    # AttnFusionPass needs attention layers to be registered in config upon init
    # so we initialize it during compilation.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.9,
               max_model_len=2048)

    # check support
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                quant_key.static,
                                                quant_key.group_shape)
        for key, layer in compile_config.static_forward_context.items()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # Check quant ops
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # attention ops present in both, just output_scale param changes
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)

    for i in range(len(attn_nodes_pre)):
        assert attn_nodes_pre[i].kwargs["output_scale"] is None
        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
        assert fused == attn_fusion_supported[i], \
            f"Node {i} {'' if fused else 'not '} expected " \
            f"to have fused output quant"

    # check outputs
    fused_output = llm2.generate(prompts, sampling_params)

    # transform outputs to format expected by check_outputs_equal
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean Dynamo cache to avoid polluting other case(s)
    torch._dynamo.reset()

    # Reset backend to make sure llm2 gets released
    backend = None
@@ -1225,6 +1225,7 @@ def scaled_fp8_quant(
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -1256,7 +1257,12 @@ def scaled_fp8_quant(
    out_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)
    if output is None:
        output = torch.empty(shape, device=input.device, dtype=out_dtype)
    else:
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == out_dtype

    if scale is None:
        if use_per_token_if_dynamic:
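
Editorial note: an illustrative sketch (not from the diff) of the new output parameter added above — quantizing into a caller-provided FP8 buffer, as the ROCm decode path later in this diff does. Shapes and the scale value here are made up.

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")  # static scale
out = torch.empty_like(x, dtype=current_platform.fp8_dtype())

# The quantized values are written into `out`; num_token_padding is not
# allowed when an output tensor is passed in (see the assert above).
out_q, out_scale = ops.scaled_fp8_quant(x, scale, output=out)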
@@ -284,9 +284,25 @@ class AttentionImpl(ABC, Generic[T]):
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        TODO(luka) merge parameters into QuantDescriptor
        :param dtype: quantized dtype
        :param static: static or dynamic quantization
        :param group_shape: quant group shape. (-1, -1) for per-tensor.
        :return: is fusion supported for this type of quantization
        """
        return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

@@ -300,6 +316,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
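
Editorial note: a hedged sketch of how a backend opts in to the new hook, modeled on the ROCm override later in this diff. The class name is hypothetical and the import path for AttentionImpl is assumed; only the relevant method is shown.

import torch

from vllm.attention.backends.abstract import AttentionImpl
from vllm.platforms import current_platform


class MyFusedQuantAttnImpl(AttentionImpl):
    """Hypothetical backend that can write quantized FP8 output directly."""

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        # Advertise only static, per-tensor FP8 so AttnFusionPass fuses the
        # output quant onto this implementation and nothing else.
        return (dtype == current_platform.fp8_dtype() and static
                and group_shape == (-1, -1))  # (-1, -1) means per-tensor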
@@ -374,6 +374,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

@@ -388,6 +389,11 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for BlocksparseFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
@@ -370,6 +370,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: DualChunkFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with DualChunkFlashAttention.
        Args:
@@ -383,6 +385,13 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is None, "Output tensor not supported for DualChunk"

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        (
            query,
            query_succ,
@@ -673,6 +673,7 @@ class FlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

@@ -692,6 +693,11 @@ class FlashAttentionImpl(AttentionImpl):
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
        if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
            assert (
@@ -975,8 +975,14 @@ class FlashInferImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        # TODO: directly write to output tensor
        num_heads: int = self.num_heads
        head_size: int = self.head_size
@@ -181,6 +181,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

@@ -193,6 +194,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for HPUAttentionImpl")

        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape
@@ -192,6 +192,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
        kv_cache: torch.Tensor,
        attn_metadata: IpexAttnMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

@@ -206,6 +207,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for IpexAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
@@ -1319,11 +1319,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if output is not None:
            raise NotImplementedError(
                "output is not yet supported for MLAImplBase")

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLAImplBase")

        if attn_metadata.is_profile_run and \
            attn_metadata.context_chunk_workspace is not None:
            # During the profile run try to simulate to worse case output size
@@ -172,6 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

@@ -187,6 +188,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
@@ -38,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@cache
def _get_paged_attn_module() -> PagedAttention:
    """
    Initializes the appropriate PagedAttention module from `attention/ops`,
    which is used as helper function
    by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.

    The choice of attention module depends on whether
    AITER paged attention is enabled:
    - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
    - Otherwise, it defaults to using the original `PagedAttention`.
@@ -598,6 +598,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                head_dim).reshape(tokens, n_kv_heads * n_rep,
                                  head_dim))

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        if self.use_triton_flash_attn:
            return dtype == current_platform.fp8_dtype(
            ) and static and group_shape == (-1, -1)  # per-tensor

        # Only supported in the Triton backend
        return False

    def forward(
        self,
        layer: AttentionLayer,
@@ -607,6 +616,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

@@ -660,6 +670,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None and not self.use_triton_flash_attn:
            raise NotImplementedError(
                "fused output quantization only supported for Triton"
                " implementation in ROCMFlashAttentionImpl for now")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
@@ -799,6 +814,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    attn_masks[0][None]
                    if attn_masks is not None else None,
                    full_scales,
                    output_scale,
                )
            elif self.use_naive_attn:
                if self.num_kv_heads != self.num_heads:
@@ -876,6 +892,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                decode_query.dtype, head_size, block_size, gqa_ratio,
                decode_meta.max_decode_seq_len, self.sliding_window,
                self.kv_cache_dtype, self.alibi_slopes)

            if use_custom:
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
@@ -887,7 +904,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                assert _PARTITION_SIZE_ROCM % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    dtype=query.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
@@ -921,9 +938,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                    output_scale,
                )
            else:
                output[num_prefill_tokens:] = paged_attn.forward_decode(
                # PagedAttention does not support fused quant, manually quantize
                if output_scale is None:
                    out_pa = output[num_prefill_tokens:]
                else:
                    out_pa = torch.empty_like(output[num_prefill_tokens:],
                                              dtype=query.dtype)

                out_pa[:] = paged_attn.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
@@ -944,6 +969,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                    layer._v_scale,
                )

                # Manually perform quantization
                if output_scale is not None:
                    out_uq = out_pa.view(-1, self.num_heads * self.head_size)
                    out_q = output.view(-1, self.num_heads * self.head_size)
                    ops.scaled_fp8_quant(out_uq,
                                         output_scale,
                                         output=out_q[num_prefill_tokens:])

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
@@ -459,6 +459,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

@@ -473,6 +474,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TorchSDPABackendImpl")

        # For warming-up
        if attn_metadata is None:
@@ -435,6 +435,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

@@ -487,6 +488,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersImpl")

        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
@@ -430,6 +430,7 @@ def unified_attention_with_output(
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
@@ -444,7 +445,8 @@ def unified_attention_with_output(
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output)
                      output=output,
                      output_scale=output_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)

@@ -455,6 +457,7 @@ def unified_attention_with_output_fake(
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    return
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable, NamedTuple, Optional
from typing import Callable, ClassVar, NamedTuple, Optional

import torch
import torch._inductor.pattern_matcher as pm
@@ -34,36 +33,66 @@ RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default


# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
    row: int
    col: int


class GroupShape(_GroupShape):
    """
    This class describes the quantization group shape.
    It includes static members for common shapes (per-tensor, per-token).
    """

    # Aliases for common quantization group shapes
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)


class QuantKey(NamedTuple):
    """
    Named tuple for identifying the type of quantization.
    dtype: quantized data type
    static: static quantization if True, dynamic if False
    per_tensor: per-tensor quantization if True, per-token if False
    group_shape: quantization group shape
    symmetric: symmetric if True, asymmetric if False

    TODO(luka) use QuantDescriptor once standardized:
    https://github.com/vllm-project/vllm/issues/8913

    """
    dtype: torch.dtype
    static: bool
    per_tensor: bool = True
    group_shape: GroupShape
    symmetric: bool = True

    def __str__(self):
        group_shape = ('per_tensor'
                       if self.group_shape == GroupShape.PER_TENSOR else
                       ('per_token' if self.group_shape == GroupShape.PER_TOKEN
                        else str(self.group_shape)))

        return (f"QuantKey({'static' if self.static else 'dynamic'},"
                f"{fx.graph.dtype_abbrs[self.dtype]},"
                f"{'per_tensor' if self.per_tensor else 'per_token'},"
                f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
                f"{'a' if not self.symmetric else ''}symmetric)")


kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
    kFp8StaticTensorSym:
    torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym:
    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym:
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
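
Editorial note: a brief sketch of the new key scheme (names are grounded in the hunk above; FP8_DTYPE is the module-level constant and the import line is assumed).

from vllm.compilation.fusion import (FP8_DTYPE, QUANT_OPS, GroupShape,
                                     QuantKey, kFp8StaticTensorSym)

# GroupShape replaces the old per_tensor flag on QuantKey.
key = QuantKey(dtype=FP8_DTYPE,
               static=True,
               group_shape=GroupShape.PER_TENSOR,
               symmetric=True)
assert key == kFp8StaticTensorSym
assert GroupShape.PER_TENSOR == (-1, -1)  # GroupShape is a NamedTuple proxy
quant_op = QUANT_OPS[key]  # torch.ops._C.static_scaled_fp8_quant.default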
@@ -83,13 +112,13 @@ class FusedRMSQuantKey(NamedTuple):

FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False):
    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8StaticTensorSym, True):
    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa
    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa: E501
}

@@ -177,10 +206,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
                 quant_dtype: torch.dtype,
                 symmetric=True):
        fused_key = FusedRMSQuantKey(fused_add=False,
                                     quant=QuantKey(dtype=quant_dtype,
                                                    static=True,
                                                    per_tensor=True,
                                                    symmetric=symmetric))
                                     quant=QuantKey(
                                         dtype=quant_dtype,
                                         static=True,
                                         group_shape=GroupShape.PER_TENSOR,
                                         symmetric=symmetric))
        super().__init__(epsilon, fused_key)

    def register(self, pm_pass: PatternMatcherPass):
@@ -233,10 +263,11 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
                 quant_dtype: torch.dtype,
                 symmetric=True):
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              static=True,
                                              per_tensor=True,
                                              symmetric=symmetric))
                               quant=QuantKey(
                                   dtype=quant_dtype,
                                   static=True,
                                   group_shape=GroupShape.PER_TENSOR,
                                   symmetric=symmetric))
        super().__init__(epsilon, key)

    def register(self, pm_pass: PatternMatcherPass,
@@ -323,12 +354,12 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 per_tensor: bool,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        key = FusedRMSQuantKey(fused_add=False,
                               quant=QuantKey(dtype=quant_dtype,
                                              static=False,
                                              per_tensor=per_tensor,
                                              group_shape=group_shape,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

@@ -421,12 +452,12 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    def __init__(self,
                 epsilon: float,
                 quant_dtype: torch.dtype,
                 per_tensor: bool = True,
                 group_shape: GroupShape = GroupShape.PER_TOKEN,
                 symmetric=True):
        key = FusedRMSQuantKey(fused_add=True,
                               quant=QuantKey(dtype=quant_dtype,
                                              static=False,
                                              per_tensor=per_tensor,
                                              group_shape=group_shape,
                                              symmetric=symmetric))
        super().__init__(epsilon, key)

@@ -566,16 +597,12 @@ class FusionPass(VllmInductorPass):
            self.patterns, self.record_match)

        # Fuse rms_norm + dynamic per-token fp8 quant
        RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
                                   per_tensor=False).register(
                                       self.patterns, self.record_match)
        RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
            self.patterns, self.record_match)

        # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
        FusedAddRMSNormDynamicQuantPattern(epsilon,
                                           FP8_DTYPE,
                                           per_tensor=False).register(
                                               self.patterns,
                                               self.record_match)
        FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
            self.patterns, self.record_match)

        # WARNING: This is a hack to clear the pattern matcher cache
        # and allow multiple values of epsilon.
vllm/compilation/fusion_attn.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses.fake_tensor import (FakeTensorMode,
                                           unset_fake_temporarily)

from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform

from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)

ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default


class AttentionStaticQuantPattern:

    def __init__(
        self,
        layer_name: str,
        num_heads: int,
        head_size: int,
        quant_dtype: torch.dtype,
        symmetric=True,
    ):
        self.layer_name = layer_name
        self.num_heads = num_heads
        self.head_size = head_size
        self.quant_dtype = quant_dtype
        self.quant_key = QuantKey(dtype=quant_dtype,
                                  static=True,
                                  group_shape=GroupShape.PER_TENSOR,
                                  symmetric=symmetric)
        assert self.quant_key in QUANT_OPS, \
            f"unsupported quantization scheme {self.quant_key}"
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty_quant(self, *args, **kwargs):
        kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
        return torch.empty(*args, **kwargs)

    def register_if_supported(self, pm_pass: PatternMatcherPass,
                              layer: Attention):
        if layer.impl.fused_output_quant_supported(self.quant_dtype,
                                                   self.quant_key.static,
                                                   self.quant_key.group_shape):
            self._register(pm_pass)

    def _register(self, pm_pass: PatternMatcherPass):

        def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    output_attn: torch.Tensor, output_quant: torch.Tensor,
                    scale: torch.Tensor):
            view_7 = RESHAPE_OP(output_attn,
                                [-1, self.num_heads, self.head_size])

            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=view_7,
                                      layer_name=self.layer_name,
                                      output_scale=None)
            attn_out_view = RESHAPE_OP(at1[1],
                                       [-1, self.num_heads * self.head_size])

            at2 = auto_functionalized(self.QUANT_OP,
                                      result=output_quant,
                                      input=attn_out_view,
                                      scale=scale)
            return at2[1]

        def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output_attn: torch.Tensor, output_quant: torch.Tensor,
                        scale: torch.Tensor):
            view_7 = RESHAPE_OP(output_quant,
                                [-1, self.num_heads, self.head_size])

            at1 = auto_functionalized(ATTN_OP,
                                      query=q,
                                      key=k,
                                      value=v,
                                      output=view_7,
                                      layer_name=self.layer_name,
                                      output_scale=scale)

            return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])

        # Need custom fake mode, otherwise tracing happens with real tensors.
        # That would not work for the unified_attention custom op.
        with unset_fake_temporarily(), FakeTensorMode():
            inputs = [
                empty_bf16(5, self.num_heads, self.head_size),  # q
                empty_bf16(5, self.num_heads, self.head_size),  # k
                empty_bf16(5, self.num_heads, self.head_size),  # v
                empty_bf16(5, self.num_heads * self.head_size),  # attn_output
                self.empty_quant(5, self.num_heads *
                                 self.head_size),  # quant_output
                empty_fp32(1, 1)  # scale
            ]

            def wrap_trace_fn(process_fx, trace_fn):

                def wrapped(*args, **kwargs):
                    return process_fx(trace_fn(*args, **kwargs))

                return wrapped

            def fx_view_to_reshape(gm: torch.fx.GraphModule):
                from torch._inductor.fx_passes.post_grad import view_to_reshape
                view_to_reshape(gm)
                return gm

            pm.register_replacement(
                pattern, replacement, inputs,
                wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)


class AttnFusionPass(VllmInductorPass):
    """
    This pass fuses post-attention quantization onto attention if supported.

    It uses the pattern matcher and matches each layer manually, as strings
    cannot be wildcarded. This also lets us check support on attention layers
    upon registration instead of during pattern matching.

    Currently, only static fp8 quant is supported, but patterns could easily be
    added for other quant schemes and dtypes. The bigger hurdle for wider
    support are attention kernels, which need to support fusing output quant.
    """

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.static_fwd_ctx = config.compilation_config.static_forward_context

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        for key, layer in self.static_fwd_ctx.items():
            pattern = AttentionStaticQuantPattern(key, layer.num_heads,
                                                  layer.head_size,
                                                  current_platform.fp8_dtype())
            pattern.register_if_supported(self.patterns, layer)
        if len(self.static_fwd_ctx) == 0:
            logger.warning(
                "Attention + quant fusion is enabled, but "
                "CompilationConfig.static_forward_context is empty. "
                "Cannot access attention layers so no fusion "
                "patterns were registered.")

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        self.begin()
        self.dump_graph(graph, "before_attn_fusion")

        count = self.patterns.apply(graph)
        logger.debug("Fused quantization onto %s attention nodes", count)
        self.dump_graph(graph, "after_attn_fusion")
        self.end_and_log()
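
Editorial note: besides the PostGradPassManager wiring below, the pass can be driven directly through a custom Dynamo backend, as the new test earlier in this diff does. A minimal sketch, assuming `vllm_config` is a VllmConfig whose compilation_config.static_forward_context has already been populated by model initialization:

from tests.compile.backend import TestBackend
from vllm.compilation.fusion_attn import AttnFusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass

# Construct AttnFusionPass lazily so attention layers are registered before
# the patterns are built (mirrors the test above).
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)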
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import operator
from collections.abc import Iterable
from collections.abc import Iterable, Iterator
from typing import Optional

from torch import fx

@@ -14,6 +14,10 @@ def is_func(node: fx.Node, target) -> bool:
    return node.op == "call_function" and node.target == target


def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    return is_func(node, auto_functionalized) and node.args[0] == op


# Returns the first specified node with the given op (if it exists)
def find_specified_fn_maybe(nodes: Iterable[fx.Node],
                            op: OpOverload) -> Optional[fx.Node]:

@@ -60,3 +64,21 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    ret = find_getitem_maybe(node, idx)
    assert ret is not None, f"Could not find getitem {idx} in node {node}"
    return ret


# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
    if not op._schema.is_mutable:
        yield from graph.find_nodes(op="call_function", target=op)

    for n in graph.find_nodes(op="call_function", target=auto_functionalized):
        if n.args[0] == op:
            yield n


# Asserts that the node only has one user and returns it
# Even if a node has only 1 user, it might share storage with another node,
# which might need to be taken into account.
def get_only_user(node: fx.Node) -> fx.Node:
    assert len(node.users) == 1
    return next(iter(node.users))
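
Editorial note: a usage sketch for the new find_op_nodes helper (assumes vLLM's custom ops are registered, e.g. after importing vllm). Mutable ops appear in the graph as auto_functionalized wrappers, and the helper yields both forms, which is what TestBackend.check_before_ops relies on above.

import torch
from torch import fx

from vllm.compilation.fx_utils import find_op_nodes


def count_static_quant_nodes(graph: fx.Graph) -> int:
    # static_scaled_fp8_quant mutates its result argument, so its nodes are
    # auto_functionalized wrappers; find_op_nodes still finds them.
    quant_op = torch.ops._C.static_scaled_fp8_quant.default
    return len(list(find_op_nodes(quant_op, graph)))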
@@ -23,7 +23,23 @@ class NoOpEliminationPass(VllmInductorPass):
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
    1. A chain of reshapes is equivalent to the last reshape called on the
       base tensor (input of the first reshape).
    2. A reshape that produces the shape of the input is redundant
    3. A slice that produces the shape of the input is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = ...

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
@@ -34,7 +50,7 @@ class NoOpEliminationPass(VllmInductorPass):
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 2:
    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
@@ -58,6 +74,18 @@ class NoOpEliminationPass(VllmInductorPass):
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                input = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(input, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, input.args[0])
                    if len(input.users) == 0:
                        graph.erase_node(input)
                        count += 1

                # Case 2: remove this reshape if it produces the original shape
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
@@ -10,6 +10,7 @@ from .activation_quant_fusion import ActivationQuantFusionPass
from .collective_fusion import AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .fusion_attn import AttnFusionPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
@@ -59,6 +60,9 @@ class PostGradPassManager(CustomGraphPass):
        if self.pass_config.enable_async_tp:
            self.passes += [AsyncTPPass(config)]

        if self.pass_config.enable_attn_fusion:
            self.passes += [AttnFusionPass(config)]

        self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass):
@@ -4,6 +4,7 @@
import time

import torch
from torch._dynamo.utils import lazy_format_graph_code

from vllm.config import PassConfig, VllmConfig
# yapf: disable

@@ -34,6 +35,8 @@ class VllmInductorPass(InductorPass):
        self.pass_name = self.__class__.__name__

    def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
        lazy_format_graph_code(stage, graph.owning_module)

        if stage in self.pass_config.dump_graph_stages or always:
            # Make sure filename includes rank in the distributed setting
            parallel = p_is_init() and get_tp_world_size() > 1
@@ -3804,9 +3804,10 @@ class PassConfig:
    its own stages (before, after, maybe in-between)."""
    dump_graph_dir: Path = Path(".")
    """Directory to dump the graphs."""
    # TODO(luka) better pass enabling system.
    enable_fusion: bool = True
    """Whether to enable the custom fusion pass."""
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
    enable_attn_fusion: bool = False
    """Whether to enable the custom attention+quant fusion pass."""
    enable_noop: bool = True
    """Whether to enable the custom no-op elimination pass."""
    enable_sequence_parallelism: bool = False
@@ -3814,6 +3815,8 @@ class PassConfig:
    enable_async_tp: bool = False
    """Whether to enable async TP."""

    # TODO(luka) better pass enabling system.

    def uuid(self):
        """
        Produces a hash unique to the pass configuration.
@@ -3821,18 +3824,20 @@ class PassConfig:
        Do not include dump_graph_* in the hash - they don't affect
        compilation.
        """
        include = {
            "enable_fusion", "enable_noop", "enable_sequence_parallelism",
            "enable_async_tp"
        }
        dict_ = {k: v for k, v in asdict(self).items() if k in include}
        exclude = {"dump_graph_stages", "dump_graph_dir"}
        dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
        return InductorPass.hash_dict(dict_)

    def __post_init__(self) -> None:
        if not self.enable_noop and self.enable_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "RMSNorm + quant (fp8) fusion might not work")
        if not self.enable_noop:
            if self.enable_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "RMSNorm/SiluMul + quant (fp8) fusion might not work")
            if self.enable_attn_fusion:
                logger.warning_once(
                    "Fusion enabled but reshape elimination disabled. "
                    "Attention + quant (fp8) fusion might not work")


@config
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
    VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
    VLLM_NCCL_SO_PATH: Optional[str] = None
    LD_LIBRARY_PATH: Optional[str] = None
    VLLM_USE_TRITON_FLASH_ATTN: bool = False
    VLLM_USE_TRITON_FLASH_ATTN: bool = True
    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
    VLLM_FLASH_ATTN_VERSION: Optional[int] = None
    LOCAL_RANK: int = 0
@@ -569,6 +569,7 @@ class FlashAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

@@ -586,6 +587,11 @@ class FlashAttentionImpl(AttentionImpl):
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output
@@ -547,6 +547,7 @@ class FlashInferImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

@@ -561,6 +562,11 @@ class FlashInferImpl(AttentionImpl):
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashInferImpl")

        if attn_metadata is None:
            # Profiling run.
            return output
@@ -414,6 +414,7 @@ class FlexAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlexAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FLexAttention.

@@ -427,6 +428,12 @@ class FlexAttentionImpl(AttentionImpl):
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlexAttentionImpl")

        enable_gqa = self.num_kv_heads != self.num_heads

        if attn_metadata is None:
@@ -865,10 +865,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
@@ -161,6 +161,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

@@ -173,6 +174,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionBackendImpl")

        # For determine_available_memory case.
        if kv_cache.numel() == 0:
            if output is None:
@@ -142,6 +142,7 @@ class TritonAttentionImpl(AttentionImpl):
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

@@ -156,6 +157,11 @@ class TritonAttentionImpl(AttentionImpl):
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TritonAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output