Mirror of https://git.datalinker.icu/vllm-project/vllm.git — synced 2025-12-24 02:45:31 +08:00

[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (#23091), fix test (#24376), and prep for custom op matching (#24604) (#24542)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

This commit is contained in:
parent 6850bfe15c
commit 6dbbecd5b2
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import weakref
 from collections.abc import Sequence
 from copy import deepcopy
 from typing import Callable, Union
@@ -10,7 +11,26 @@ from torch._ops import OpOverload
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.inductor_pass import InductorPass
-from vllm.config import get_current_vllm_config
+from vllm.compilation.pass_manager import with_pattern_match_debug
+from vllm.compilation.vllm_inductor_pass import VllmInductorPass
+from vllm.config import VllmConfig, get_current_vllm_config
+
+
+class LazyInitPass(InductorPass):
+    """
+    If there's a pass that we want to initialize lazily in a test,
+    we can wrap it in LazyInitPass, which will initialize the pass when invoked
+    and then immediately invoke it.
+    """
+
+    def __init__(self, pass_cls: type[VllmInductorPass],
+                 vllm_config: VllmConfig):
+        self.pass_cls = pass_cls
+        self.vllm_config = weakref.proxy(vllm_config)  # avoid cycle
+
+    def __call__(self, graph: fx.Graph) -> None:
+        self.pass_ = self.pass_cls(self.vllm_config)
+        self.pass_(graph)
 
 
 class TestBackend:
@@ -40,10 +60,16 @@ class TestBackend:
             example_inputs,
             config_patches=self.inductor_config)
 
+    @with_pattern_match_debug
     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
+
+        VllmInductorPass.dump_prefix = 0
         for pass_ in self.custom_passes:
             pass_(graph)
+            VllmInductorPass.dump_prefix += 1
+
+        VllmInductorPass.dump_prefix = None
+
         self.graph_post_pass = deepcopy(graph)
         # assign by reference, will reflect the final state of the graph
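The new `LazyInitPass` replaces the ad-hoc lambdas the attention-fusion tests used below; because it stores the constructed pass on `self.pass_`, tests can reach through the wrapper after compilation. A minimal usage sketch, mirroring how the tests further down wire it up (`vllm_config` is assumed to be a fully-built `VllmConfig`):

```python
from tests.compile.backend import LazyInitPass, TestBackend
from vllm.compilation.fusion_attn import AttnFusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass

# AttnFusionPass reads the attention layers registered in the config, which
# only happens at model-load time, so construction must be deferred until the
# backend first invokes the pass during compilation.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)  # vllm_config assumed
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)

# ...compile and run the model; only then does attn_pass.pass_ exist:
# assert attn_pass.pass_.matched_count > 0
```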
@@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states)
 
+    assert async_tp_pass.matched_count == 1
+
     # In pre-nodes, all gather or reduce scatter should exist,
     # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
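`matched_count` is the new bookkeeping attribute these assertions rely on: each pattern-matcher pass now records how many replacements `PatternMatcherPass.apply` performed (visible in the pass bodies further down, e.g. `self.matched_count = self.patterns.apply(graph)`). A toy sketch of the idea in isolation:

```python
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass


class CountingPass:
    """Toy stand-in for vLLM's VllmPatternMatcherPass: remembers how many
    pattern replacements fired so tests can assert on the count directly."""

    def __init__(self, patterns: PatternMatcherPass):
        self.patterns = patterns
        self.matched_count = 0

    def __call__(self, graph: torch.fx.Graph) -> None:
        # PatternMatcherPass.apply returns the number of matches replaced.
        self.matched_count = self.patterns.apply(graph)
```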
@@ -4,7 +4,7 @@ import pytest
 
 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import VllmConfig
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.utils import _is_torch_equal_or_newer
 
 
@@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
     assert not vllm_config.compilation_config.use_cudagraph
 
 
+def test_custom_op():
+    # proper syntax
+    _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
+
+    with pytest.raises(ValueError, match="Invalid syntax '"):
+        _ = CompilationConfig(custom_ops=["quant_fp8"])
+
+
 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
 @pytest.mark.forked
 # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
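As the test above shows, `custom_ops` entries follow a `+name`/`-name` convention: `+` force-enables a custom op, `-` force-disables it, and the new validation rejects a bare name at construction time. A short sketch of the behavior (op names taken from the test):

```python
from vllm.config import CompilationConfig

# Enable the quant_fp8 custom op; disable silu_and_mul (fall back to torch).
cfg = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

# A bare name has no +/- prefix and now raises:
try:
    CompilationConfig(custom_ops=["quant_fp8"])
except ValueError as e:
    print(e)  # "Invalid syntax '..." per the test's match string
```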
@@ -8,9 +8,10 @@ import vllm.envs as envs
 from vllm import LLM, SamplingParams
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, FusionPass
+from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     vllm_config.compilation_config = CompilationConfig(
         pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
     noop_pass = NoOpEliminationPass(vllm_config)
-    fusion_pass = FusionPass.instance(vllm_config)
+    fusion_pass = RMSNormQuantFusionPass(vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)
     act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
 
-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass
-              ] if do_fusion else [noop_pass]
+    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
+              ] if do_fusion else [noop_pass, cleanup_pass]
     func_pass = FixFunctionalizationPass(vllm_config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
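`PostCleanupPass` appears in every test's pass list from here on, always last: pattern replacement leaves the unfused nodes dead in the graph, and this PR consolidates the cleanup (e.g. the `graph.eliminate_dead_code()` previously inlined in `AttnFusionPass`, removed further down) into a dedicated final pass. A sketch of the intended ordering, assuming a fully-built `vllm_config` as in the tests:

```python
from tests.compile.backend import TestBackend
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass

# Order matters: noop-elimination first, fusion passes in the middle, and
# cleanup last so it can sweep up whatever the fusions left behind.
passes = [
    NoOpEliminationPass(vllm_config),       # vllm_config assumed available
    RMSNormQuantFusionPass(vllm_config),
    ActivationQuantFusionPass(vllm_config),
    PostCleanupPass(vllm_config),
]
backend = TestBackend(*passes)
```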
@@ -4,11 +4,11 @@
 import pytest
 import torch
 
-import vllm.envs as envs
 import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
-                                     FusionPass)
+                                     RMSNormQuantFusionPass)
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                          VllmConfig)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -79,15 +79,15 @@ class TestModel(torch.nn.Module):
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
+@pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
+@pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
 @pytest.mark.parametrize("cuda_force_torch",
                          [True, False] if cutlass_fp8_supported() else [True])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                               cuda_force_torch):
@@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     with vllm.config.set_current_vllm_config(vllm_config):
         # Reshape pass is needed for the fusion pass to work
         noop_pass = NoOpEliminationPass(vllm_config)
-        fusion_pass = FusionPass.instance(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
 
-        backend = TestBackend(noop_pass, fusion_pass)
+        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
         model = TestModel(hidden_size, eps, static, cuda_force_torch)
 
         # First dimension dynamic
@@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
 
         torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
+        assert fusion_pass.matched_count == 2
+
         # In pre-nodes, fp8 quant should be there and fused kernels should not
         backend.check_before_ops(model.ops_in_model_before())
@@ -9,6 +9,7 @@ import vllm.envs as envs
 from vllm.compilation.collective_fusion import AllReduceFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                          ModelConfig, PassConfig, VllmConfig)
 from vllm.distributed import tensor_model_parallel_all_reduce
@@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
     noop_pass = NoOpEliminationPass(vllm_config)
     func_pass = FixFunctionalizationPass(vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)
 
-    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
+    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
+                          cleanup_pass)
 
     token_num = batch_size * seq_len
     model = test_model_cls(hidden_size, token_num)
@@ -227,6 +230,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states, residual)
 
+    assert all_reduce_fusion_pass.matched_count == 1
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
     backend.check_after_ops(model.ops_in_model_after())
     del all_reduce_fusion_pass
@@ -6,18 +6,19 @@ from typing import Optional
 import pytest
 import torch._dynamo
 
-from tests.compile.backend import TestBackend
+from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata)
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                          ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
                          set_current_vllm_config)
@@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
 
     # AttnFusionPass needs attention layers to be registered in config upon init
     # so we initialize it during compilation.
-    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
+    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
     backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
     llm2 = LLM(model,
                enforce_eager=True,
@@ -197,7 +198,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
             device=self.device,
         )
 
-    def build_attn_metadata(self, batch_size: int, use_hnd: bool):
+    def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
+            -> AttentionMetadata:
         """Initialize attention metadata."""
 
         # Create common attn metadata
@@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
 
     # Create test backend with fusion passes enabled
     noop_pass = NoOpEliminationPass(vllm_config)
-    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
-                                                                )
-    test_backend = TestBackend(noop_pass, attn_pass)
+    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)
+
+    test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
 
     # Compile model with fusion enabled
     model_compiled = torch.compile(model_fused,
@@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
     test_backend.check_before_ops([QUANT_OPS[quant_key]],
                                   fully_replaced=True)
 
+    # access the underlying `AttnFusionPass` on the `LazyInitPass`
+    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
+
     # Check attention ops in the graph before and after fusion
     attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
     attn_nodes_post = list(find_op_nodes(ATTN_OP,
@@ -6,10 +6,12 @@ import torch
 
 import vllm.envs as envs
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
+from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)
 from vllm.distributed import tensor_model_parallel_all_reduce
@@ -104,7 +106,7 @@ class TestQuantModel(torch.nn.Module):
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+        self.fp8_linear = Fp8LinearOp(act_quant_static=True)
 
         self.scale = torch.rand(1, dtype=torch.float32)
         # Create a weight that is compatible with torch._scaled_mm,
@@ -137,8 +139,7 @@ class TestQuantModel(torch.nn.Module):
         # layer normalization
         norm_output, residual_output = self.norm(all_reduce, residual)
 
-        # for static input quantization
-        # self.fp8_linear is initialized with use_per_token_if_dynamic=False
+        # scaled_mm with static input quantization
         fp8_linear_result = self.fp8_linear.apply(norm_output,
                                                   self.w,
                                                   self.wscale,
@@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
                              dtype=dtype,
                              seed=42)
 
-    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
     noop_pass = NoOpEliminationPass(vllm_config)
+    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
     func_pass = FixFunctionalizationPass(vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)
 
-    passes_for_backend = [noop_pass, sequence_parallelism_pass]
+    passes_for_backend: list[VllmInductorPass] = \
+        [noop_pass, sequence_parallelism_pass]
 
     if enable_fusion:
-        fusion_pass = FusionPass.instance(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
         passes_for_backend.append(fusion_pass)
 
+    passes_for_backend.append(cleanup_pass)
+
     backend_no_func = TestBackend(*passes_for_backend)
     backend_func = TestBackend(*passes_for_backend, func_pass)
 
@@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
     compiled_model_func = torch.compile(model, backend=backend_func)
     compiled_model_func(hidden_states, residual)
 
+    assert sequence_parallelism_pass.matched_count == 1
+
     # In pre-nodes, all reduce should be there,
     # reduce scatter and all gather should not
     backend_no_func.check_before_ops(model.ops_in_model_before())
@@ -15,6 +15,7 @@ from vllm.compilation.activation_quant_fusion import (
 # yapf: enable
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
 
     def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
         super().__init__()
+        from vllm.compilation.activation_quant_fusion import (
+            silu_and_mul_nvfp4_quant_supported)
+        assert silu_and_mul_nvfp4_quant_supported
+
         self.silu_and_mul = SiluAndMul()
 
         # create nvfp4 weight
@@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
         pass_config=PassConfig(enable_fusion=True, enable_noop=True))
     fusion_pass = ActivationQuantFusionPass(config)
 
-    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
+    passes = [
+        NoOpEliminationPass(config), fusion_pass,
+        PostCleanupPass(config)
+    ]
+    backend = TestBackend(*passes)
     model = model_class(hidden_size=hidden_size,
                         cuda_force_torch=cuda_force_torch,
                         x=x)
@@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
                                atol=atol,
                                rtol=rtol)
 
+    assert fusion_pass.matched_count == 1
+
     # In pre-nodes, quant op should be present and fused kernels should not
     backend.check_before_ops(model.ops_in_model_before())
@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
 from .inductor_pass import enable_fake_mode
-from .vllm_inductor_pass import VllmInductorPass
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
 
@@ -152,7 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
         register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
 
 
-class ActivationQuantFusionPass(VllmInductorPass):
+class ActivationQuantFusionPass(VllmPatternMatcherPass):
     """
     This pass fuses a pre-defined set of custom ops into fused ops.
     It uses the torch pattern matcher to find the patterns and replace them.
@@ -176,16 +176,12 @@ class ActivationQuantFusionPass(VllmInductorPass):
         pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
         pattern_silu_mul_nvfp4.register(self.patterns)
 
+        self.dump_patterns(config, self.patterns)
+
+    @VllmInductorPass.time_and_log
     def __call__(self, graph: torch.fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_act_quant_fusion")
-
-        count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
-                     count)
-
-        self.dump_graph(graph, "after_act_quant_fusion")
-        self.end_and_log()
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
 
     def uuid(self):
         return VllmInductorPass.hash_source(self, ActivationQuantPattern,
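From here on, the refactor converges every pattern-matching pass on one shape: subclass `VllmPatternMatcherPass`, register patterns in `__init__`, call `dump_patterns` (a debug utility added by this PR), and let the `time_and_log` decorator handle the begin/dump/end choreography. A minimal sketch of that shape, where `MyPattern` is a hypothetical pattern class:

```python
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.compilation.vllm_inductor_pass import (VllmInductorPass,
                                                 VllmPatternMatcherPass)
from vllm.config import VllmConfig


class MyFusionPass(VllmPatternMatcherPass):

    def __init__(self, config: VllmConfig):
        super().__init__(config)
        self.patterns = PatternMatcherPass(pass_name="my_fusion_pass")
        MyPattern().register(self.patterns)        # hypothetical pattern class
        self.dump_patterns(config, self.patterns)  # debug util from this PR

    @VllmInductorPass.time_and_log  # timing + before/after graph dumps
    def __call__(self, graph):
        self.matched_count = self.patterns.apply(graph)
```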
@@ -20,7 +20,7 @@ from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 from .inductor_pass import enable_fake_mode
-from .vllm_inductor_pass import VllmInductorPass
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 FP8_DTYPE = current_platform.fp8_dtype()
 
@@ -348,7 +348,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
                                 pm.fwd_only, pm_pass)
 
 
-class AsyncTPPass(VllmInductorPass):
+class AsyncTPPass(VllmPatternMatcherPass):
 
     @enable_fake_mode
     def __init__(self, config: VllmConfig):
@@ -378,18 +378,17 @@ class AsyncTPPass(VllmInductorPass):
             AllGatherCutlassScaledMMPattern(
                 self.model_dtype, self.device).register(self.patterns)
 
+        self.dump_patterns(config, self.patterns)
+
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
         # only do replace for specific shapes
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
+    @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_async_tp_pass")
-        count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns with async TP pass.", count)
-        self.dump_graph(graph, "after_async_tp_pass")
-        self.end_and_log()
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
 
 
 if flashinfer_comm is not None:
@@ -1068,7 +1067,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
                                 pm.fwd_only, pm_pass)
 
 
-class AllReduceFusionPass(VllmInductorPass):
+class AllReduceFusionPass(VllmPatternMatcherPass):
 
     def __init__(self, config: VllmConfig):
         super().__init__(config)
@@ -1124,6 +1123,7 @@ class AllReduceFusionPass(VllmInductorPass):
             fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
 
         self.register_patterns()
+        self.dump_patterns(config, self.patterns)
 
     @enable_fake_mode
     def register_patterns(self):
@@ -1172,15 +1172,14 @@ class AllReduceFusionPass(VllmInductorPass):
 
         self.disabled = False
 
+    @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):
         if self.disabled:
+            logger.debug("AllReduceFusionPass disabled")
             return
-        self.begin()
-        self.dump_graph(graph, "before_all_reduce_fusion_pass")
-        count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns", count)
-        self.dump_graph(graph, "after_all_reduce_fusion_pass")
-        self.end_and_log()
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
 
     def __del__(self):
         if getattr(self, "disabled", True):
@@ -26,6 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
     To add new nodes to defunctionalize, add to the if-elif chain in __call__.
     """
 
+    @VllmInductorPass.time_and_log
     def __call__(self, graph: torch.fx.Graph):
         # XPU does not support auto-functionalization yet.
         # Will enable this when switch to vllm-xpu-kernels.
@@ -34,9 +35,6 @@ class FixFunctionalizationPass(VllmInductorPass):
                 "pass currently.")
             return
 
-        self.begin()
-        self.dump_graph(graph, "before_fix_functionalization")
-
         self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
@@ -111,7 +109,7 @@ class FixFunctionalizationPass(VllmInductorPass):
 
                 count += 1
 
-        self.dump_graph(graph, "before_fix_functionalization_cleanup")
+        self.dump_graph(graph, "before_cleanup")
 
         # Remove the nodes all at once
         count_removed = len(self.nodes_to_remove)
@@ -120,8 +118,7 @@ class FixFunctionalizationPass(VllmInductorPass):
 
         logger.debug("De-functionalized %s nodes, removed %s nodes", count,
                      count_removed)
-        self.dump_graph(graph, "after_fix_functionalization")
-        self.end_and_log()
+        self.nodes_to_remove.clear()
 
     def _remove(self, node_or_nodes: Union[torch.fx.Node,
                                            Iterable[torch.fx.Node]]):
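Across all of these passes, the manual `begin()` / `dump_graph()` / `end_and_log()` choreography is replaced by a single decorator. A rough sketch of what such a decorator does — an assumption inferred from the call sites above, not vLLM's exact implementation:

```python
import functools
import logging
import time

logger = logging.getLogger(__name__)


def time_and_log(call_fn):
    """Hypothetical stand-in for VllmInductorPass.time_and_log: wraps a pass's
    __call__ with the bookkeeping the passes used to inline by hand."""

    @functools.wraps(call_fn)
    def wrapped(self, graph):
        # dump_graph is assumed from VllmInductorPass, as used in the diff
        self.dump_graph(graph, "before_" + self.__class__.__name__)
        start = time.perf_counter()
        call_fn(self, graph)
        logger.debug("%s took %.1f ms", self.__class__.__name__,
                     (time.perf_counter() - start) * 1e3)
        self.dump_graph(graph, "after_" + self.__class__.__name__)

    return wrapped
```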
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, NamedTuple, Optional
+from typing import Any, NamedTuple
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -16,10 +16,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
 from vllm.platforms import current_platform
 
-from .fx_utils import find_getitem_maybe
 from .inductor_pass import enable_fake_mode
-from .multi_output_match import MultiOutputMatch
-from .vllm_inductor_pass import VllmInductorPass
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -50,8 +48,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
         torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
 }
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
-    QUANT_OPS[
-        kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
+    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
 
 
 class FusedRMSQuantKey(NamedTuple):
@@ -80,68 +77,6 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
 }
 
 
-class QuantMultiOutputMatch(MultiOutputMatch):
-
-    def __init__(self, match: pm.Match, quant_op, fused_op):
-        super().__init__(match)
-        assert isinstance(quant_op, OpOverload)
-        assert isinstance(fused_op, OpOverload)
-        self.QUANT_OP = quant_op  # in-place quant op
-        self.FUSED_OP = fused_op  # in-place fused quant op
-
-    def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
-                                                                      int]],
-                          **kwargs):
-        """
-        This utility function inserts an auto-functionalized node for FUSED_OP.
-        It also correctly sets its meta value and rebinds the users of the
-        unfused nodes to use the fused node instead.
-
-        :param fused_return_mapping: A dictionary, mapping from getitem indices
-        of the fused node result to a tuple of the old node and a getitem index.
-        :param kwargs: kwargs that get directly forwarded to the auto_fn node
-
-        Example:
-        If we want to replace this graph:
-        _, x1, x2 = auto_fn(op1)
-        _, y1, y2 = auto_fn(op2)
-
-        with
-        _, x1, y2, x2 = auto_fn(FUSED_OP)
-
-        we would call:
-        insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)}
-
-        Note that the 0th element is None for auto-functionalized in-place ops.
-        Hence, others appear 1-indexed.
-        """
-        fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
-        indices = fused_return_mapping.keys()
-        getitem_nodes = self.insert_getitems(fused_node, indices)
-
-        # Prepare the meta value, use a list so it's mutable
-        meta_val = [None] * (max(indices) + 1)
-
-        # Iterate through elements of the tuple produced by fused_node
-        for idx, getitem_node in zip(indices, getitem_nodes):
-            old_node, old_idx = fused_return_mapping[idx]
-
-            # If the old value was never used, the old_getitem might not exist
-            old_getitem = find_getitem_maybe(old_node, old_idx)
-            if old_getitem is not None:
-                # Rebind the users of match getitem nodes to use the new nodes.
-                # The old nodes will be removed by DCE at the end of the pass.
-                old_getitem.replace_all_uses_with(getitem_node)
-                getitem_node.meta["val"] = old_getitem.meta["val"]
-
-            # Extract the appropriate meta value
-            # It is present even if the getitem node does not exist
-            meta_val[idx] = old_node.meta["val"][old_idx]
-
-        # Fix the meta value on the new fused node
-        fused_node.meta["val"] = tuple(meta_val)
-
-
 class RMSNormQuantPattern:
 
     def __init__(self, epsilon: float, key: FusedRMSQuantKey):
@@ -224,8 +159,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
                              symmetric=symmetric))
         super().__init__(epsilon, key)
 
-    def register(self, pm_pass: PatternMatcherPass,
-                 record_match: Callable[[MultiOutputMatch], bool]):
+    def register(self, pm_pass: PatternMatcherPass):
 
         def pattern(result: torch.Tensor, input: torch.Tensor,
                     residual: torch.Tensor, weight: torch.Tensor,
@@ -271,36 +205,7 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
             inputs,
             pm.fwd_only,
             pm_pass,
-            extra_check=lambda m: record_match(
-                self.Match(m, self.QUANT_OP, self.FUSED_OP)))
-
-    class Match(QuantMultiOutputMatch):
-
-        def process(self):
-            # Find the nodes in the match that we need to rebind
-            rms_node = self.find_auto_fn(RMS_ADD_OP)
-            quant_node = self.find_auto_fn(self.QUANT_OP)
-
-            assert len(rms_node.users) == 2
-            assert len(quant_node.users) == 1
-
-            # First, insert a new auto_functionalized node for the fused op,
-            # as well as getitem nodes to extract the result and residual.
-            # The auto_fn node returns a tuple of (None, result, residual).
-            #
-            # The resulting graph looks like this:
-            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
-            # result_node_new = at[1]
-            # residual_node_new = at[2]
-            with self.inserting_after_match():
-                # Missing epsilon, scalars cannot be inputs to the pattern
-                kwargs = self.match.kwargs.copy()
-
-                # 0 is always None
-                fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
-                self.insert_fused_node(fused_return_mapping,
-                                       **kwargs,
-                                       epsilon=rms_node.kwargs["epsilon"])
+        )
 
 
 class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
@@ -317,8 +222,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
                              symmetric=symmetric))
         super().__init__(epsilon, key)
 
-    def register(self, pm_pass: PatternMatcherPass,
-                 record_match: Callable[[MultiOutputMatch], bool]):
+    def register(self, pm_pass: PatternMatcherPass):
 
         def pattern(result: torch.Tensor, result_rms: torch.Tensor,
                     input: torch.Tensor, weight: torch.Tensor,
@@ -366,39 +270,7 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
             inputs,
             pm.fwd_only,
             pm_pass,
-            extra_check=lambda m: record_match(
-                self.Match(m, self.QUANT_OP, self.FUSED_OP)))
-
-    class Match(QuantMultiOutputMatch):
-
-        def process(self):
-            # Find the nodes in the match that we need to rebind
-            rms_node = self.find_auto_fn(RMS_OP)
-            quant_node = self.find_auto_fn(self.QUANT_OP)
-
-            assert len(rms_node.users) == 1
-            assert len(quant_node.users) == 2
-
-            # First, insert a new auto_functionalized node for the fused op,
-            # as well as getitem nodes to extract the result and scale.
-            # The auto_fn node returns a tuple of (None, result, scale).
-            #
-            # The resulting graph looks like this:
-            # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...)  # noqa
-            # result_node_new = at[1]
-            # scale_node_new = at[2]
-            with self.inserting_after_match():
-                # Missing epsilon, scalars cannot be inputs to the pattern
-                kwargs = self.match.kwargs.copy()
-                del kwargs["result_rms"]  # not used in the fused op
-
-                fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
-                self.insert_fused_node(
-                    fused_return_mapping,
-                    epsilon=rms_node.kwargs["epsilon"],
-                    scale_ub=None,  # not used but required
-                    residual=None,  # not used but required
-                    **kwargs)
+        )
 
 
 class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
@@ -415,8 +287,7 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
                              symmetric=symmetric))
         super().__init__(epsilon, key)
 
-    def register(self, pm_pass: PatternMatcherPass,
-                 record_match: Callable[[MultiOutputMatch], bool]):
+    def register(self, pm_pass: PatternMatcherPass):
 
        def pattern(result: torch.Tensor, input: torch.Tensor,
                    residual: torch.Tensor, weight: torch.Tensor,
@@ -464,137 +335,49 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
             inputs,
             pm.fwd_only,
             pm_pass,
-            extra_check=lambda m: record_match(
-                self.Match(m, self.QUANT_OP, self.FUSED_OP)))
-
-    class Match(QuantMultiOutputMatch):
-
-        def process(self):
-            # Find the nodes in the match that we need to rebind
-            rms_node = self.find_auto_fn(RMS_ADD_OP)
-            quant_node = self.find_auto_fn(self.QUANT_OP)
-
-            assert len(rms_node.users) == 2
-            assert len(quant_node.users) == 2
-
-            # First, insert a new auto_functionalized node for the fused op,
-            # as well as getitem nodes to extract result, scale, and residual.
-            # The auto_fn node returns a tuple (None, result, scale, residual).
-            #
-            # The resulting graph looks like this:
-            # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...)  # noqa
-            # result_node_new = at[1]
-            # scale_node_new = at[2]
-            # residual_node_new = at[3]
-            with self.inserting_after_match():
-                # Missing epsilon, scalars cannot be inputs to the pattern
-                kwargs = self.match.kwargs.copy()
-
-                fused_return_mapping = {
-                    1: (quant_node, 1),  # result
-                    2: (quant_node, 2),  # scale
-                    3: (rms_node, 2),  # residual
-                }
-                self.insert_fused_node(
-                    fused_return_mapping,
-                    epsilon=rms_node.kwargs["epsilon"],
-                    scale_ub=None,  # not used but required
-                    **kwargs)
+        )
 
 
-class FusionPass(VllmInductorPass):
+class RMSNormQuantFusionPass(VllmPatternMatcherPass):
     """
-    This pass fuses a pre-defined set of custom ops into fused ops.
-    It uses the torch pattern matcher to find the patterns and replace them.
-    It also manually processes multi-output matches, as those are broken in
-    the torch pattern matcher.
-
-    Because patterns can only be registered once, the pass is a singleton.
-    This will be addressed in a future version of PyTorch:
-    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
+    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
+    It also supports fused_add_rms_norm.
     """
 
-    _instance: 'Optional[FusionPass]' = None
-
-    @classmethod
-    def instance(cls, config: VllmConfig):
-        """
-        Get the singleton instance of the FusionPass.
-        If the instance exists, the config is updated but
-        initialization is not repeated.
-        """
-        if cls._instance is None:
-            cls._instance = FusionPass(config)
-        else:
-            cls._instance.pass_config = config.compilation_config.pass_config
-        return cls._instance
-
     @enable_fake_mode
     def __init__(self, config: VllmConfig):
-        assert self.__class__._instance is None, \
-            "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: list[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
-            pass_name="fusion_pass")
+            pass_name="rmsnorm_quant_fusion_pass")
 
         for epsilon in [1e-5, 1e-6]:
             # Fuse rms_norm + static fp8 quant
             RMSNormStaticQuantPattern(epsilon,
                                       FP8_DTYPE).register(self.patterns)
 
-            # Matches for patterns below have 2 or more outputs,
-            # so we need to process them manually (see process_matches)
-
-            # Fuse rms_norm + static fp8 quant
+            # Fuse fused_add_rms_norm + static fp8 quant
             FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns, self.record_match)
+                self.patterns)
 
             # Fuse rms_norm + dynamic per-token fp8 quant
-            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns, self.record_match)
+            RMSNormDynamicQuantPattern(epsilon,
+                                       FP8_DTYPE).register(self.patterns)
 
             # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
             FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
-                self.patterns, self.record_match)
+                self.patterns)
 
-        # WARNING: This is a hack to clear the pattern matcher cache
-        # and allow multiple values of epsilon.
-        torch._inductor.pattern_matcher._seen_patterns.clear()
-
-    def record_match(self, match: MultiOutputMatch) -> bool:
-        # Hijack the extra_check to record the match and
-        # save it for post-processing.
-        self.matches.append(match)
-
-        # Return False to prevent automatic replacement.
-        return False
-
-    def process_matches(self, graph: fx.Graph):
-        """
-        Manually process multi-output matches and replace them with fused nodes.
-        See MultiOutputMatch for more details.
-        """
-        for match in self.matches:
-            match.process()
-
-        # Finally, remove matched nodes
-        graph.eliminate_dead_code()
-        assert all(node not in graph.nodes for match in self.matches
-                   for node in match.match.nodes)
+        self.dump_patterns(config, self.patterns)
 
+    @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_fusion")
-
-        count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns", count)
-        self.dump_graph(graph, "after_pattern_match")
-
-        # Manually process multi-output matches (and run DCE)
-        self.process_matches(graph)
-        logger.debug("Post-processed %s matches", len(self.matches))
-        self.dump_graph(graph, "after_fusion")
-        self.matches.clear()
-        self.end_and_log()
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
+
+    def uuid(self) -> Any:
+        return self.hash_source(self, RMSNormQuantPattern,
+                                RMSNormStaticQuantPattern,
+                                RMSNormDynamicQuantPattern,
+                                FusedAddRMSNormStaticQuantPattern,
+                                FusedAddRMSNormDynamicQuantPattern)
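With the manual multi-output machinery gone, every pattern is registered through the stock Inductor API and replaced automatically, and the dead unfused nodes are swept up afterwards (see `PostCleanupPass` in the tests above). A minimal standalone sketch of that registration flow using only `torch._inductor.pattern_matcher` — the ops here are illustrative, not vLLM's fused kernels:

```python
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass(pass_name="demo_fusion")


def pattern(x: torch.Tensor):
    # the op chain we want to collapse
    return torch.relu(x) * 2


def replacement(x: torch.Tensor):
    # "fused" form; stands in for a fused custom op
    return torch.clamp(x, min=0) * 2


# Same call shape the vLLM patterns use: search fn, replace fn, example
# inputs for tracing, trace mode, and the pass to register into.
pm.register_replacement(pattern, replacement, [torch.empty(8, 8)],
                        pm.fwd_only, patterns)


def demo_fusion_pass(graph: torch.fx.Graph) -> int:
    # apply() performs the replacement automatically and returns the count —
    # exactly what the passes above store in self.matched_count.
    return patterns.apply(graph)
```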
@@ -18,7 +18,7 @@ from vllm.utils import round_up
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
 from .inductor_pass import enable_fake_mode
-from .vllm_inductor_pass import VllmInductorPass
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
 
@@ -245,7 +245,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
                                 pm_pass)
 
 
-class AttnFusionPass(VllmInductorPass):
+class AttnFusionPass(VllmPatternMatcherPass):
     """
     This pass fuses post-attention quantization onto attention if supported.
 
@@ -282,20 +282,12 @@ class AttnFusionPass(VllmInductorPass):
                 "were found in CompilationConfig.static_forward_context "
                 "so no fusion patterns were registered.")
 
+        self.dump_patterns(config, self.patterns)
+
+    @VllmInductorPass.time_and_log
     def __call__(self, graph: torch.fx.graph.Graph) -> None:
-        self.begin()
-        self.dump_graph(graph, "before_attn_fusion")
-
-        count = self.patterns.apply(graph)
-
-        # TODO: Move this to pass_manager.py after the fx graph broken issue
-        # has been resolved.
-        # see https://github.com/vllm-project/vllm/issues/23091
-        graph.eliminate_dead_code()
-
-        logger.debug("Fused quantization onto %s attention nodes", count)
-        self.dump_graph(graph, "after_attn_fusion")
-        self.end_and_log()
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Fused quant onto %s attention nodes", self.matched_count)
 
     def uuid(self):
         return VllmInductorPass.hash_source(self, AttentionQuantPattern,
@@ -1,109 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import abc
-import operator
-from abc import abstractmethod
-from collections.abc import Iterable
-
-from torch import fx
-from torch._higher_order_ops.auto_functionalize import auto_functionalized
-from torch._inductor import pattern_matcher as pm
-from torch._ops import OpOverload
-from torch.fx import Node
-
-from vllm.compilation.fx_utils import find_auto_fn
-
-
-class MultiOutputMatch(abc.ABC):
-    """
-    This class provides utilities to process multi-output matches and
-    manually insert replacements.
-
-    This is necessary because the automatic replacement for multi-output
-    matches is broken: https://github.com/pytorch/pytorch/issues/137280
-    """
-
-    def __init__(self, match: pm.Match):
-        self.match = match
-
-    @abstractmethod
-    def process(self):
-        """
-        Process a multi-output match and manually insert the replacement.
-
-        This method should:
-        1. Insert the replacement nodes after the last node in the match.
-        2. Rebind the users of nodes in the match to use the new nodes.
-        3. Set meta["val"] for de-functionalization.
-
-        The result of an auto-functionalized node is a tuple of tensors.
-        The first element is the return value of the function, usually None.
-        The remaining elements are the mutated args of the function.
-
-        All auto-functionalized nodes must contain a proper meta["val"],
-        as it is used by de-functionalization. meta["val"] has to contain the
-        value of the node (tuple of tensors) that would be returned by the
-        functionalized node during tracing.
-
-        Existing nodes in the graph all have this property set, but we have
-        to set it manually for new nodes we insert.
-
-        Example:
-        # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
-        at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
-        # at.meta["val"] = (None, a, c)
-        """
-        raise NotImplementedError
-
-    @property
-    def nodes(self) -> list[fx.Node]:
-        return self.match.nodes
-
-    @property
-    def graph(self) -> fx.Graph:
-        return self.match.graph
-
-    def find_auto_fn(self, op) -> fx.Node:
-        """
-        Find the first auto_functionalized node with the given op in the match.
-        """
-        return find_auto_fn(self.nodes, op)
-
-    def inserting_after_match(self):
-        """
-        Insert nodes after the last node in the match.
-        This is done to avoid use-before-definition errors after inserting
-        replacement nodes.
-        """
-
-        # match.nodes is not guaranteed to be sorted.
-        # Find the last node in the match.
-        for last_node_in_match in reversed(self.graph.nodes):
-            if last_node_in_match in self.match.nodes:
-                break
-        else:
-            raise ValueError("No nodes in graph")
-
-        return self.graph.inserting_after(last_node_in_match)
-
-    def insert_getitems(self, tuple_node: fx.Node,
-                        indices: Iterable[int]) -> tuple[fx.Node, ...]:
-        """
-        Insert operator.getitem nodes to extract elements from a tuple node.
-
-        :param tuple_node: The tuple node to extract elements from.
-        :param indices: The indices of the elements to extract.
-        :return: Tuple of the new getitem nodes, corresponding to the indices.
-        """
-        with self.graph.inserting_after(tuple_node):
-            return tuple(
-                self.graph.call_function(operator.getitem, (tuple_node, idx))
-                for idx in indices)
-
-    def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
-        """
-        Insert an auto_functionalized node with the given op and kwargs.
-        """
-        return self.graph.call_function(auto_functionalized, (op, ),
-                                        kwargs=kwargs)
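Note on the deletion above: MultiOutputMatch existed only to work around pytorch#137280 by hand-inserting replacements for multi-output matches. For context, a comment-only sketch of how its helpers were typically combined (the op and variable names here are hypothetical):

# with match.inserting_after_match():
#     fused = match.insert_auto_fn(torch.ops._C.fused_op.default, kwargs)
#     res, scale = match.insert_getitems(fused, (1, 2))
#     fused.meta["val"] = (None, res_val, scale_val)  # for de-functionalization
#     # ...then rebind users of the matched nodes to res/scale.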
@@ -64,9 +64,8 @@ class NoOpEliminationPass(VllmInductorPass):
         out: "f16[s0, 4096]" = at[1]
     """

+    @VllmInductorPass.time_and_log
     def __call__(self, graph: torch.fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_noop_elimination")
         count = 0
         # Remove no-op reshapes/views:
         for node in graph.nodes:
@@ -121,8 +120,6 @@ class NoOpEliminationPass(VllmInductorPass):
                     count += 1

         logger.debug("Removed %s no-op reshapes and slices", count)
-        self.dump_graph(graph, "after_noop_elimination")
-        self.end_and_log()

     # ---------------------- Reshape helpers ----------------------
     def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
@@ -1,15 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools

 from torch import fx as fx

+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import set_env_var

+from .post_cleanup import PostCleanupPass
+from .vllm_inductor_pass import VllmInductorPass

 if current_platform.is_cuda_alike():
     from .activation_quant_fusion import ActivationQuantFusionPass
-    from .fusion import FusionPass
+    from .fusion import RMSNormQuantFusionPass
     from .fusion_attn import AttnFusionPass

 if current_platform.is_cuda():
@@ -19,11 +25,28 @@ from .fix_functionalization import FixFunctionalizationPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
 from .noop_elimination import NoOpEliminationPass
 from .sequence_parallelism import SequenceParallelismPass
-from .vllm_inductor_pass import VllmInductorPass

 logger = init_logger(__name__)


+def with_pattern_match_debug(fn):
+    """
+    Function decorator that turns on inductor pattern match debug
+    for the duration of the call.
+    Used to avoid logging builtin Inductor pattern matching.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
+            # optionally check rank here
+            with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
+                return fn(*args, **kwargs)
+        return fn(*args, **kwargs)
+
+    return wrapper

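The decorator above only toggles the environment around the wrapped call. For illustration, a hedged usage sketch (the caller and node name below are hypothetical):

# import os
# os.environ["VLLM_PATTERN_MATCH_DEBUG"] = "getitem_34"  # fx.Node to inspect
#
# @with_pattern_match_debug
# def run_post_grad_passes(graph):
#     ...  # TORCHINDUCTOR_PATTERN_MATCH_DEBUG is set only for this call,
#          # so builtin Inductor pattern matching elsewhere stays quiet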
 class PostGradPassManager(CustomGraphPass):
     """
     The pass manager for post-grad passes.
@@ -40,16 +63,26 @@ class PostGradPassManager(CustomGraphPass):
     """

     def __init__(self):
-        self.passes: list[VllmInductorPass] = []
+        self.passes: list[InductorPass] = []

+    @with_pattern_match_debug
     def __call__(self, graph: fx.Graph):
+        VllmInductorPass.dump_prefix = 0  # reset dump index
+
         shape = get_pass_context().runtime_shape
         for pass_ in self.passes:
             if pass_.is_applicable_for_shape(shape):
                 pass_(graph)
+            VllmInductorPass.dump_prefix += 1
+
+        # post-cleanup goes before fix_functionalization
+        # because it requires a functional graph
+        self.post_cleanup(graph)
+        VllmInductorPass.dump_prefix += 1

         # always run fix_functionalization last
         self.fix_functionalization(graph)
+        VllmInductorPass.dump_prefix = None  # Cleanup index

     def configure(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
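With dump_prefix threaded through as above, graph dumps come out numbered by pass index, roughly as below (pass names illustrative); note the index advances even when a pass is skipped for a shape, so numbering stays aligned with the configured pass list:

# post_grad.0.NoOpEliminationPass.before / .after
# post_grad.1.RMSNormQuantFusionPass.before / .after
# ...
# post_grad.N.FixFunctionalizationPass.before / .after
# dump_prefix is reset to None at the end, so dump_graph calls made outside
# the manager fall back to un-numbered "post_grad.<PassName>.<stage>" names.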
@@ -61,14 +94,18 @@ class PostGradPassManager(CustomGraphPass):
         if self.pass_config.enable_async_tp:
             self.passes += [AsyncTPPass(config)]

+        if self.pass_config.enable_fi_allreduce_fusion:
+            self.passes += [AllReduceFusionPass(config)]
+
         if self.pass_config.enable_fusion:
-            self.passes += [FusionPass.instance(config)]
+            self.passes += [RMSNormQuantFusionPass(config)]
             self.passes += [ActivationQuantFusionPass(config)]

         if self.pass_config.enable_attn_fusion:
             self.passes += [AttnFusionPass(config)]
-        if self.pass_config.enable_fi_allreduce_fusion:
-            self.passes += [AllReduceFusionPass(config)]
+
+        # needs a functional graph
+        self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)

     def add(self, pass_: InductorPass):
vllm/compilation/post_cleanup.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from torch import fx
+
+from vllm.compilation.vllm_inductor_pass import VllmInductorPass
+
+
+class PostCleanupPass(VllmInductorPass):
+    """
+    This pass performs cleanup after custom passes.
+    It topologically sorts the graph and removes unused nodes.
+    This is needed because the pattern matcher does not guarantee producing
+    a topologically sorted graph, and there may be unused nodes left around.
+    """
+
+    @VllmInductorPass.time_and_log
+    def __call__(self, graph: fx.Graph) -> None:
+        from torch._inductor.pattern_matcher import stable_topological_sort
+        stable_topological_sort(graph)
+        graph.eliminate_dead_code()
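A minimal, self-contained sketch of what the new pass does, shown on a toy fx graph rather than vLLM code:

import torch
from torch import fx
from torch._inductor.pattern_matcher import stable_topological_sort


def f(x: torch.Tensor) -> torch.Tensor:
    y = x + 1
    unused = y - 3  # becomes dead, e.g. after a pattern replacement
    return y * 2


graph = fx.symbolic_trace(f).graph
graph.eliminate_dead_code()     # drops the unused node
stable_topological_sort(graph)  # restores def-before-use ordering
graph.lint()                    # sanity-check the cleaned graph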
@@ -15,7 +15,7 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform

 from .inductor_pass import enable_fake_mode
-from .vllm_inductor_pass import VllmInductorPass
+from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)

@@ -417,7 +417,7 @@ class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
             pm.fwd_only, pm_pass)


-class SequenceParallelismPass(VllmInductorPass):
+class SequenceParallelismPass(VllmPatternMatcherPass):
     """
     This pass enables sequence parallelism for models.
     It identifies patterns where an AllReduce operation is followed by
@@ -466,19 +466,13 @@ class SequenceParallelismPass(VllmInductorPass):

         LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                     self.device).register(self.patterns)
-        # WARNING: This is a hack to clear the pattern matcher cache
-        # and allow multiple values of epsilon.
-        torch._inductor.pattern_matcher._seen_patterns.clear()
+        self.dump_patterns(config, self.patterns)

     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0

+    @VllmInductorPass.time_and_log
     def __call__(self, graph: fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_sequence_parallelism_pass")
-        count = self.patterns.apply(graph)
-        logger.debug("Replaced %s patterns with sequence parallelism", count)
-        self.dump_graph(graph, "after_sequence_parallelism_pass")
-        self.end_and_log()
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
@@ -1,10 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import operator
 import time
+from pathlib import Path
+from typing import ClassVar, Optional

+import regex as re
 import torch
 from torch._dynamo.utils import lazy_format_graph_code
+from torch._inductor.pattern_matcher import (PatternMatcherPass,
+                                             PatternPrettyPrinter)

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -19,6 +25,8 @@ class VllmInductorPass(InductorPass):
     An inductor pass with access to vLLM PassConfig.
     It provides timing, logging, and dumping utilities.
     """
+    dump_prefix: ClassVar[Optional[int]] = None
+    """Keep track of pass index for debug dump ordering."""

     def __init__(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
@@ -28,8 +36,24 @@ class VllmInductorPass(InductorPass):
             else None
         self.pass_name = self.__class__.__name__

+    @staticmethod
+    def time_and_log(call_fn):
+
+        @functools.wraps(call_fn)
+        def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
+            self.begin()
+            self.dump_graph(graph, "before")
+            call_fn(self, graph)
+            self.dump_graph(graph, "after")
+            self.end_and_log()
+
+        return wrapped
+
     def dump_graph(self, graph: torch.fx.Graph, stage: str):
-        lazy_format_graph_code(stage, graph.owning_module)
+        i = VllmInductorPass.dump_prefix
+        i_str = "" if i is None else f".{i}"
+        lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}",
+                               graph.owning_module)

     def begin(self):
         self._start_time = time.perf_counter_ns()
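For illustration, what a pass looks like after adopting the decorator (MyPass is hypothetical):

# class MyPass(VllmInductorPass):
#
#     @VllmInductorPass.time_and_log
#     def __call__(self, graph: torch.fx.Graph):
#         ...  # only the pass logic remains; begin(), the before/after
#              # graph dumps, and end_and_log() are handled by the wrapper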
@@ -40,6 +64,88 @@ class VllmInductorPass(InductorPass):
         logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)


+class VllmPatternMatcherPass(VllmInductorPass):
+    """
+    A VllmInductorPass that uses the Inductor pattern matcher.
+    Its main use is providing the dump_patterns utility that dumps the
+    Inductor pattern matcher patterns into a file, which greatly aids debugging.
+
+    TODO(luka) move more utilities to this pass.
+    """
+    matched_count: int = 0
+    """The number of matched patterns in the pass."""
+
+    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
+        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>")
+
+    def _replace_op_overloads(self, string: str) -> str:
+        """Replace <OpOverload(..., ...)> with nicer formulations"""
+        return self._OP_OVERLOAD_PATTERN.sub(
+            lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
+            string,
+        )
+
+    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
+        """
+        If debug dumping is enabled, dump the Inductor pattern-matcher patterns
+        into the debug_dump_path folder next to the dumped fx graphs.
+
+        This method does its best to print something that looks like Python code
+        for easier debugging and potentially navigation. If any errors appear in
+        the output, please add to this method.
+
+        TODO(luka): use pattern object to manually produce pattern graph
+        """
+        debug_dump_path = config.compilation_config.debug_dump_path
+        if not debug_dump_path:
+            return
+
+        rank = config.parallel_config.rank
+        debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
+        debug_dump_path.mkdir(parents=True, exist_ok=True)
+
+        from vllm.utils import unique_filepath
+        file_path = unique_filepath(
+            lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py")
+
+        with file_path.open("w") as f:
+            print(
+                f'# This file was produced by VllmPatternMatcherPass.'
+                f'dump_patterns for {self.pass_name}.\n'
+                f'# It does its best to produce valid-Python-looking code but'
+                f' please add to dump_patterns if there are any errors.\n\n'
+                f'from torch._higher_order_ops.auto_functionalize import '
+                f'auto_functionalized as auto_functionalized\n'
+                f'from torch._inductor.pattern_matcher import *',
+                file=f)
+
+            for node, patterns in pm_pass.patterns.items():
+                # fix the operator.getitem repr
+                if node[1] == operator.getitem:
+                    node_repr = f"({repr(node[0])}, operator.getitem)"
+                else:
+                    node_repr = repr(node)
+
+                node_repr = self._replace_op_overloads(node_repr)
+
+                print(f"\n\n# Patterns for op: {node_repr}", file=f)
+                for i, pattern in enumerate(patterns):
+                    # reserve auto_functionalized ahead of time
+                    pp = PatternPrettyPrinter()
+                    pp.namespace.create_name("auto_functionalized", None)
+
+                    # Assemble pattern
+                    out_node = pp.pretty_print(pattern.pattern)
+                    pattern_repr = "\n".join([f"def pattern_{i}():"] + [
+                        f"{pp.memoized_objs_names[key]} = "
+                        f"{pp.memoized_objs_pp[key]}"
+                        for key in pp.memoized_objs_names
+                    ] + [f"return {out_node}"]).replace("\n", "\n    ")
+
+                    pattern_repr = self._replace_op_overloads(pattern_repr)
+                    print(f"{pattern_repr}\n", file=f)
+
+
 class PrinterInductorPass(VllmInductorPass):

     def __init__(self, name: str, config: VllmConfig):
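Roughly what a dumped patterns.<PassName>.<i>.py file contains; the op and arguments below are illustrative, not taken from a real dump:

# # Patterns for op: torch.ops.higher_order.auto_functionalized
# def pattern_0():
#     rms_norm = CallFunction(auto_functionalized,
#                             torch.ops._C.rms_norm.default, ...)
#     return MultiOutputPattern([rms_norm])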
@@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig,
     except Exception:
         raise
     else:
-        logger.debug("enabled custom ops: %s",
-                     vllm_config.compilation_config.enabled_custom_ops)
-        logger.debug("disabled custom ops: %s",
-                     vllm_config.compilation_config.disabled_custom_ops)
+        if check_compile:
+            vllm_config.compilation_config.custom_op_log_check()
         if check_compile and \
             vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
             and compilation_counter.num_models_seen == num_models_seen:
@@ -487,6 +487,12 @@ class CompilationConfig:
                 "supported with torch>=2.9.0.dev. Set "
                 "use_inductor_graph_partition=False instead.")

+        for op in self.custom_ops:
+            if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
+                raise ValueError(f"Invalid syntax '{op}' for custom op, "
+                                 "must be 'all', 'none', '+op' or '-op' "
+                                 "(where 'op' is the registered op name)")
+
     def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
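Examples of what the new validation accepts and rejects (op names illustrative):

# CompilationConfig(custom_ops=["all"])                # ok: enable everything
# CompilationConfig(custom_ops=["none", "+rms_norm"])  # ok: opt one back in
# CompilationConfig(custom_ops=["-quant_fp8"])         # ok: opt one out
# CompilationConfig(custom_ops=["rms_norm"])           # ValueError: no '+'/'-'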
@@ -628,3 +634,41 @@ class CompilationConfig:

         return use_fx_graph_piecewise_compilation or \
             use_inductor_piecewise_compilation
+
+    def custom_op_log_check(self):
+        """
+        This method logs the enabled/disabled custom ops and checks that the
+        passed custom_ops field only contains relevant ops.
+        It is called at the end of set_current_vllm_config,
+        after the custom ops have been instantiated.
+        """
+
+        if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
+            logger.debug("No custom ops found in model.")
+            return
+
+        logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
+        logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
+
+        all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
+        for op in self.custom_ops:
+            if op in {"all", "none"}:
+                continue
+
+            assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
+                "(should be checked during init)"
+
+            # check if op name exists in model
+            op_name = op[1:]
+            if op_name not in all_ops_in_model:
+                from vllm.model_executor.custom_op import CustomOp
+
+                # Does op exist at all or is it just not present in this model?
+                # Note: Only imported op classes appear in the registry.
+                missing_str = "doesn't exist (or wasn't imported/registered)" \
+                    if op_name not in CustomOp.op_registry \
+                    else "not present in model"
+
+                enable_str = "enabling" if op[0] == '+' else "disabling"
+                logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
+                                    op_name, missing_str, enable_str, op)
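Given the format string above, a typo'd op would produce a warning along these lines (example only):

# Op 'rms_nrom' doesn't exist (or wasn't imported/registered), enabling with
# '+rms_nrom' has no effect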
@@ -190,6 +190,7 @@ if TYPE_CHECKING:
    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
    VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
    VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
+    VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None


 def get_default_cache_root():
@@ -442,6 +443,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_STANDALONE_COMPILE":
     lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1",

+    # Debug pattern matching inside custom passes.
+    # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
+    "VLLM_PATTERN_MATCH_DEBUG":
+    lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
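A typical debugging flow with the new variable (node name illustrative):

# VLLM_PATTERN_MATCH_DEBUG=scaled_mm_3 python your_script.py
# While the custom passes run, Inductor's pattern matcher can then report
# why each pattern did or did not match at the node named 'scaled_mm_3'.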
@@ -3413,3 +3413,16 @@ def length_from_prompt_token_ids_or_embeds(
             f" prompt_token_ids={prompt_token_len}"
             f" prompt_embeds={prompt_embeds_len}")
     return prompt_token_len
+
+
+@contextlib.contextmanager
+def set_env_var(key, value):
+    old = os.environ.get(key)
+    os.environ[key] = value
+    try:
+        yield
+    finally:
+        if old is None:
+            del os.environ[key]
+        else:
+            os.environ[key] = old
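A short, runnable usage sketch of the new helper (key and value illustrative):

import os
from vllm.utils import set_env_var

# The override is visible only inside the block and is undone even on error.
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", "getitem_34"):
    assert os.environ["TORCHINDUCTOR_PATTERN_MATCH_DEBUG"] == "getitem_34"
# Afterwards the key is deleted if it was unset, else restored to its old value.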