[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (#23091), fix test (#24376), and prep for custom op matching (#24604) (#24542)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Luka Govedič 2025-09-22 15:30:05 -04:00 committed by yewentao256
parent 6850bfe15c
commit 6dbbecd5b2
24 changed files with 404 additions and 461 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
@ -10,7 +11,26 @@ from torch._ops import OpOverload
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
class LazyInitPass(InductorPass):
"""
If there's a pass that we want to initialize lazily in a test,
we can wrap it in LazyInitPass, which will initialize the pass when invoked
and then immediately invoke it.
"""
def __init__(self, pass_cls: type[VllmInductorPass],
vllm_config: VllmConfig):
self.pass_cls = pass_cls
self.vllm_config = weakref.proxy(vllm_config) # avoid cycle
def __call__(self, graph: fx.Graph) -> None:
self.pass_ = self.pass_cls(self.vllm_config)
self.pass_(graph)
class TestBackend:
@ -40,10 +60,16 @@ class TestBackend:
example_inputs,
config_patches=self.inductor_config)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
pass_(graph)
VllmInductorPass.dump_prefix += 1
VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph)
# assign by reference, will reflect the final state of the graph

View File

@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
assert async_tp_pass.matched_count == 1
# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

View File

@ -4,7 +4,7 @@ import pytest
import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, VllmConfig
from vllm.utils import _is_torch_equal_or_newer
@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph
def test_custom_op():
# proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
with pytest.raises(ValueError, match="Invalid syntax '"):
_ = CompilationConfig(custom_ops=["quant_fp8"])
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends

View File

@ -8,9 +8,10 @@ import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, FusionPass
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
] if do_fusion else [noop_pass, cleanup_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)

View File

@ -4,11 +4,11 @@
import pytest
import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass)
RMSNormQuantFusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
@ -79,15 +79,15 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch):
@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic
@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2
# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass

View File

@ -6,18 +6,19 @@ from typing import Optional
import pytest
import torch._dynamo
from tests.compile.backend import TestBackend
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
@ -197,7 +198,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int, use_hnd: bool):
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
-> AttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
)
test_backend = TestBackend(noop_pass, attn_pass)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,

View File

@ -6,10 +6,12 @@ import torch
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
@ -104,7 +106,7 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
@ -137,8 +139,7 @@ class TestQuantModel(torch.nn.Module):
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# for static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
dtype=dtype,
seed=42)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend = [noop_pass, sequence_parallelism_pass]
passes_for_backend: list[VllmInductorPass] = \
[noop_pass, sequence_parallelism_pass]
if enable_fusion:
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())

View File

@ -15,6 +15,7 @@ from vllm.compilation.activation_quant_fusion import (
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
super().__init__()
from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported)
assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul()
# create nvfp4 weight
@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
passes = [
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size,
cuda_force_torch=cuda_force_torch,
x=x)
@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
atol=atol,
rtol=rtol)
assert fusion_pass.matched_count == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

View File

@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -152,7 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
class ActivationQuantFusionPass(VllmInductorPass):
class ActivationQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
@ -176,16 +176,12 @@ class ActivationQuantFusionPass(VllmInductorPass):
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_act_quant_fusion")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
count)
self.dump_graph(graph, "after_act_quant_fusion")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def uuid(self):
return VllmInductorPass.hash_source(self, ActivationQuantPattern,

View File

@ -20,7 +20,7 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
@ -348,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
pm.fwd_only, pm_pass)
class AsyncTPPass(VllmInductorPass):
class AsyncTPPass(VllmPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig):
@ -378,18 +378,17 @@ class AsyncTPPass(VllmInductorPass):
AllGatherCutlassScaledMMPattern(
self.model_dtype, self.device).register(self.patterns)
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
# only do replace for specific shapes
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_async_tp_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with async TP pass.", count)
self.dump_graph(graph, "after_async_tp_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
if flashinfer_comm is not None:
@ -1068,7 +1067,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
pm.fwd_only, pm_pass)
class AllReduceFusionPass(VllmInductorPass):
class AllReduceFusionPass(VllmPatternMatcherPass):
def __init__(self, config: VllmConfig):
super().__init__(config)
@ -1124,6 +1123,7 @@ class AllReduceFusionPass(VllmInductorPass):
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
self.register_patterns()
self.dump_patterns(config, self.patterns)
@enable_fake_mode
def register_patterns(self):
@ -1172,15 +1172,14 @@ class AllReduceFusionPass(VllmInductorPass):
self.disabled = False
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
if self.disabled:
logger.debug("AllReduceFusionPass disabled")
return
self.begin()
self.dump_graph(graph, "before_all_reduce_fusion_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_all_reduce_fusion_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
def __del__(self):
if getattr(self, "disabled", True):

View File

@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
@ -34,9 +35,6 @@ class FixFunctionalizationPass(VllmInductorPass):
"pass currently.")
return
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
@ -111,7 +109,7 @@ class FixFunctionalizationPass(VllmInductorPass):
count += 1
self.dump_graph(graph, "before_fix_functionalization_cleanup")
self.dump_graph(graph, "before_cleanup")
# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
@ -120,8 +118,7 @@ class FixFunctionalizationPass(VllmInductorPass):
logger.debug("De-functionalized %s nodes, removed %s nodes", count,
count_removed)
self.dump_graph(graph, "after_fix_functionalization")
self.end_and_log()
self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: Union[torch.fx.Node,
Iterable[torch.fx.Node]]):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, NamedTuple, Optional
from typing import Any, NamedTuple
import torch
import torch._inductor.pattern_matcher as pm
@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[
kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
class FusedRMSQuantKey(NamedTuple):
@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
}
class QuantMultiOutputMatch(MultiOutputMatch):
def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
assert isinstance(quant_op, OpOverload)
assert isinstance(fused_op, OpOverload)
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op
def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
int]],
**kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.
:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node
Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)
with
_, x1, y2, x2 = auto_fn(FUSED_OP)
we would call:
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
Note that the 0th element is None for auto-functionalized in-place ops.
Hence, others appear 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)
# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)
# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]
# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]
# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]
# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)
class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
**kwargs,
epsilon=rms_node.kwargs["epsilon"])
)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 1
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and scale.
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
del kwargs["result_rms"] # not used in the fused op
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
residual=None, # not used but required
**kwargs)
)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def register(self, pm_pass: PatternMatcherPass):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract result, scale, and residual.
# The auto_fn node returns a tuple (None, result, scale, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
fused_return_mapping = {
1: (quant_node, 1), # result
2: (quant_node, 2), # scale
3: (rms_node, 2), # residual
}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
**kwargs)
)
class FusionPass(VllmInductorPass):
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
It also manually processes multi-output matches, as those are broken in
the torch pattern matcher.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
It also supports fused_add_rms_norm.
"""
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: VllmConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
initialization is not repeated.
"""
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
@enable_fake_mode
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
self.matches: list[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")
pass_name="rmsnorm_quant_fusion_pass")
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
self.patterns)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
RMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)
# Return False to prevent automatic replacement.
return False
def process_matches(self, graph: fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
See MultiOutputMatch for more details.
"""
for match in self.matches:
match.process()
# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.match.nodes)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_pattern_match")
# Manually process multi-output matches (and run DCE)
self.process_matches(graph)
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()
self.end_and_log()
def uuid(self) -> Any:
return self.hash_source(self, RMSNormQuantPattern,
RMSNormStaticQuantPattern,
RMSNormDynamicQuantPattern,
FusedAddRMSNormStaticQuantPattern,
FusedAddRMSNormDynamicQuantPattern)

View File

@ -18,7 +18,7 @@ from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -245,7 +245,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
pm_pass)
class AttnFusionPass(VllmInductorPass):
class AttnFusionPass(VllmPatternMatcherPass):
"""
This pass fuses post-attention quantization onto attention if supported.
@ -282,20 +282,12 @@ class AttnFusionPass(VllmInductorPass):
"were found in CompilationConfig.static_forward_context "
"so no fusion patterns were registered.")
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.graph.Graph) -> None:
self.begin()
self.dump_graph(graph, "before_attn_fusion")
count = self.patterns.apply(graph)
# TODO: Move this to pass_manager.py after the fx graph broken issue
# has been resolved.
# see https://github.com/vllm-project/vllm/issues/23091
graph.eliminate_dead_code()
logger.debug("Fused quantization onto %s attention nodes", count)
self.dump_graph(graph, "after_attn_fusion")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
def uuid(self):
return VllmInductorPass.hash_source(self, AttentionQuantPattern,

View File

@ -1,109 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import operator
from abc import abstractmethod
from collections.abc import Iterable
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor import pattern_matcher as pm
from torch._ops import OpOverload
from torch.fx import Node
from vllm.compilation.fx_utils import find_auto_fn
class MultiOutputMatch(abc.ABC):
"""
This class provides utilities to process multi-output matches and
manually insert replacements.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
"""
def __init__(self, match: pm.Match):
self.match = match
@abstractmethod
def process(self):
"""
Process a multi-output match and manually insert the replacement.
This method should:
1. Insert the replacement nodes after the last node in the match.
2. Rebind the users of nodes in the match to use the new nodes.
3. Set meta["val"] for de-functionalization.
The result of an auto-functionalized node is a tuple of tensors.
The first element is the return value of the function, usually None.
The remaining elements are the mutated args of the function.
All auto-functionalized nodes must contain a proper meta["val"],
as it is used by de-functionalization. meta["val"] has to contain the
value of the node (tuple of tensors) that would be returned by the
functionalized node during tracing.
Existing nodes in the graph all have this property set, but we have
to set it manually for new nodes we insert.
Example:
# op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
# at.meta["val"] = (None, a, c)
"""
raise NotImplementedError
@property
def nodes(self) -> list[fx.Node]:
return self.match.nodes
@property
def graph(self) -> fx.Graph:
return self.match.graph
def find_auto_fn(self, op) -> fx.Node:
"""
Find the first auto_functionalized node with the given op in the match.
"""
return find_auto_fn(self.nodes, op)
def inserting_after_match(self):
"""
Insert nodes after the last node in the match.
This is done to avoid use-before-definition errors after inserting
replacement nodes.
"""
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(self.graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")
return self.graph.inserting_after(last_node_in_match)
def insert_getitems(self, tuple_node: fx.Node,
indices: Iterable[int]) -> tuple[fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.
:param tuple_node: The tuple node to extract elements from.
:param indices: The indices of the elements to extract.
:return: Tuple of the new getitem nodes, corresponding to the indices.
"""
with self.graph.inserting_after(tuple_node):
return tuple(
self.graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices)
def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
"""
Insert an auto_functionalized node with the given op and kwargs.
"""
return self.graph.call_function(auto_functionalized, (op, ),
kwargs=kwargs)

View File

@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass):
out: "f16[s0, 4096]" = at[1]
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_noop_elimination")
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
@ -121,8 +120,6 @@ class NoOpEliminationPass(VllmInductorPass):
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
self.dump_graph(graph, "after_noop_elimination")
self.end_and_log()
# ---------------------- Reshape helpers ----------------------
def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],

View File

@ -1,15 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from torch import fx as fx
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var
from .post_cleanup import PostCleanupPass
from .vllm_inductor_pass import VllmInductorPass
if current_platform.is_cuda_alike():
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import FusionPass
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
if current_platform.is_cuda():
@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def with_pattern_match_debug(fn):
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
Used to avoid logging builtin Inductor pattern matching.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
return fn(*args, **kwargs)
return fn(*args, **kwargs)
return wrapper
class PostGradPassManager(CustomGraphPass):
"""
The pass manager for post-grad passes.
@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass):
"""
def __init__(self):
self.passes: list[VllmInductorPass] = []
self.passes: list[InductorPass] = []
@with_pattern_match_debug
def __call__(self, graph: fx.Graph):
VllmInductorPass.dump_prefix = 0 # reset dump index
shape = get_pass_context().runtime_shape
for pass_ in self.passes:
if pass_.is_applicable_for_shape(shape):
pass_(graph)
VllmInductorPass.dump_prefix += 1
# post-cleanup goes before fix_functionalization
# because it requires a functional graph
self.post_cleanup(graph)
VllmInductorPass.dump_prefix += 1
# always run fix_functionalization last
self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):

View File

@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from torch import fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
class PostCleanupPass(VllmInductorPass):
"""
This pass performs cleanup after custom passes.
It topologically sorts the graph and removes unused nodes.
This is needed because the pattern matcher does not guarantee producing
a topologically sorted graph, and there may be unused nodes left around.
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
from torch._inductor.pattern_matcher import stable_topological_sort
stable_topological_sort(graph)
graph.eliminate_dead_code()

View File

@ -15,7 +15,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
pm.fwd_only, pm_pass)
class SequenceParallelismPass(VllmInductorPass):
class SequenceParallelismPass(VllmPatternMatcherPass):
"""
This pass enables sequence parallelism for models.
It identifies patterns where an AllReduce operation is followed by
@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass):
LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
self.device).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_sequence_parallelism_pass")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns with sequence parallelism", count)
self.dump_graph(graph, "after_sequence_parallelism_pass")
self.end_and_log()
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)

View File

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import operator
import time
from pathlib import Path
from typing import ClassVar, Optional
import regex as re
import torch
from torch._dynamo.utils import lazy_format_graph_code
from torch._inductor.pattern_matcher import (PatternMatcherPass,
PatternPrettyPrinter)
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass):
An inductor pass with access to vLLM PassConfig.
It provides timing, logging, and dumping utilities.
"""
dump_prefix: ClassVar[Optional[int]] = None
"""Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
@ -28,8 +36,24 @@ class VllmInductorPass(InductorPass):
else None
self.pass_name = self.__class__.__name__
@staticmethod
def time_and_log(call_fn):
@functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before")
call_fn(self, graph)
self.dump_graph(graph, "after")
self.end_and_log()
return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str):
lazy_format_graph_code(stage, graph.owning_module)
i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}"
lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
graph.owning_module)
def begin(self):
self._start_time = time.perf_counter_ns()
@ -40,6 +64,88 @@ class VllmInductorPass(InductorPass):
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class VllmPatternMatcherPass(VllmInductorPass):
"""
A VllmInductorPass that uses the Inductor pattern matcher.
Its main use is providing the dump_patterns utility that dumps the
Inductor pattern matcher patterns into a file, which greatly aids debugging.
TODO(luka) move more utilities to this pass.
"""
matched_count: int = 0
"""The number of matched patterns in the pass."""
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
This method does its best to print something that looks like Python code
for easier debugging and potentially navigation. If any errors appear in
the output, please add to this method.
TODO(luka): use pattern object to manually produce pattern graph
"""
debug_dump_path = config.compilation_config.debug_dump_path
if not debug_dump_path:
return
rank = config.parallel_config.rank
debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
debug_dump_path.mkdir(parents=True, exist_ok=True)
from vllm.utils import unique_filepath
file_path = unique_filepath(
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")
with file_path.open("w") as f:
print(
f'# This file was produced by VllmPatternMatcherPass.'
f'dump_patterns for {self.pass_name}.\n'
f'# It does its best to produce valid-Python-looking code but'
f' please add to dump_patterns if there are any errors.\n\n'
f'from torch._higher_order_ops.auto_functionalize import '
f'auto_functionalized as auto_functionalized\n'
f'from torch._inductor.pattern_matcher import *',
file=f)
for node, patterns in pm_pass.patterns.items():
# fix the operator.getitem repr
if node[1] == operator.getitem:
node_repr = f"({repr(node[0])}, operator.getitem)"
else:
node_repr = repr(node)
node_repr = self._replace_op_overloads(node_repr)
print(f"\n\n# Patterns for op: {node_repr}", file=f)
for i, pattern in enumerate(patterns):
# reserve auto_functionalized ahead of time
pp = PatternPrettyPrinter()
pp.namespace.create_name("auto_functionalized", None)
# Assemble pattern
out_node = pp.pretty_print(pattern.pattern)
pattern_repr = "\n".join([f"def pattern_{i}():"] + [
f"{pp.memoized_objs_names[key]} = "
f"{pp.memoized_objs_pp[key]}"
for key in pp.memoized_objs_names
] + [f"return {out_node}"]).replace("\n", "\n ")
pattern_repr = self._replace_op_overloads(pattern_repr)
print(f"{pattern_repr}\n", file=f)
class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig):

View File

@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig,
except Exception:
raise
else:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if check_compile:
vllm_config.compilation_config.custom_op_log_check()
if check_compile and \
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:

View File

@ -487,6 +487,12 @@ class CompilationConfig:
"supported with torch>=2.9.0.dev. Set "
"use_inductor_graph_partition=False instead.")
for op in self.custom_ops:
if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
raise ValueError(f"Invalid syntax '{op}' for custom op, "
"must be 'all', 'none', '+op' or '-op' "
"(where 'op' is the registered op name)")
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@ -532,8 +538,8 @@ class CompilationConfig:
for x in self.compile_sizes:
if isinstance(x, str):
assert x == "cudagraph_capture_sizes", \
"Unrecognized size type in compile_sizes, " \
f"expect 'cudagraph_capture_sizes', got {x}"
"Unrecognized size type in compile_sizes, " \
f"expect 'cudagraph_capture_sizes', got {x}"
computed_compile_sizes.extend(self.cudagraph_capture_sizes)
else:
assert isinstance(x, int)
@ -628,3 +634,41 @@ class CompilationConfig:
return use_fx_graph_piecewise_compilation or \
use_inductor_piecewise_compilation
def custom_op_log_check(self):
"""
This method logs the enabled/disabled custom ops and checks that the
passed custom_ops field only contains relevant ops.
It is called at the end of set_current_vllm_config,
after the custom ops have been instantiated.
"""
if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
logger.debug("No custom ops found in model.")
return
logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
for op in self.custom_ops:
if op in {"all", "none"}:
continue
assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
"(should be checked during init)"
# check if op name exists in model
op_name = op[1:]
if op_name not in all_ops_in_model:
from vllm.model_executor.custom_op import CustomOp
# Does op exist at all or is it just not present in this model?
# Note: Only imported op classes appear in the registry.
missing_str = "doesn't exist (or wasn't imported/registered)" \
if op_name not in CustomOp.op_registry \
else "not present in model"
enable_str = "enabling" if op[0] == '+' else "disabling"
logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
op_name, missing_str, enable_str, op)

View File

@ -190,6 +190,7 @@ if TYPE_CHECKING:
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
def get_default_cache_root():
@ -442,6 +443,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_STANDALONE_COMPILE":
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1",
# Debug pattern matching inside custom passes.
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
"VLLM_PATTERN_MATCH_DEBUG":
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":

View File

@ -3413,3 +3413,16 @@ def length_from_prompt_token_ids_or_embeds(
f" prompt_token_ids={prompt_token_len}"
f" prompt_embeds={prompt_embeds_len}")
return prompt_token_len
@contextlib.contextmanager
def set_env_var(key, value):
old = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old