diff --git a/tests/compile/distributed/test_async_tp.py b/tests/compile/distributed/test_async_tp.py
index 86d409f1eadb..2eb18e25c98b 100644
--- a/tests/compile/distributed/test_async_tp.py
+++ b/tests/compile/distributed/test_async_tp.py
@@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
     vllm_config = VllmConfig()
     vllm_config.compilation_config = CompilationConfig(
         pass_config=PassConfig(
-            enable_async_tp=True,
+            fuse_gemm_comms=True,
         ),
     )
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
         "mode": CompilationMode.VLLM_COMPILE,
         "compile_sizes": [2, 4, 8],
         "splitting_ops": [],
-        "pass_config": {"enable_async_tp": async_tp_enabled},
+        "pass_config": {"fuse_gemm_comms": async_tp_enabled},
     }
 
     async_tp_args = [
diff --git a/tests/compile/distributed/test_fusion_all_reduce.py b/tests/compile/distributed/test_fusion_all_reduce.py
index d401d5703275..fc8d1f98ebf8 100644
--- a/tests/compile/distributed/test_fusion_all_reduce.py
+++ b/tests/compile/distributed/test_fusion_all_reduce.py
@@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
         )
     )
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_fi_allreduce_fusion=True, enable_noop=True
+        fuse_allreduce_rms=True, eliminate_noops=True
     )
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
     vllm_config.parallel_config.rank = local_rank  # Setup rank for debug path
diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py
index 661172e1965b..5d2786e122a6 100644
--- a/tests/compile/distributed/test_fusions_e2e.py
+++ b/tests/compile/distributed/test_fusions_e2e.py
@@ -192,7 +192,7 @@ def test_attn_quant(
         splitting_ops=splitting_ops,
         # Common
         mode=CompilationMode.VLLM_COMPILE,
-        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
+        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
     )
@@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
         # Common
         mode=CompilationMode.VLLM_COMPILE,
         pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_fi_allreduce_fusion=True,
+            fuse_attn_quant=True,
+            eliminate_noops=True,
+            fuse_allreduce_rms=True,
         ),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
@@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
         # Common
         level=CompilationMode.VLLM_COMPILE,
         pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_sequence_parallelism=True,
-            enable_async_tp=True,
+            fuse_attn_quant=True,
+            eliminate_noops=True,
+            enable_sp=True,
+            fuse_gemm_comms=True,
         ),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
diff --git a/tests/compile/distributed/test_sequence_parallelism.py b/tests/compile/distributed/test_sequence_parallelism.py
index 30084dfd5a95..d9fdc3acc3d6 100644
--- a/tests/compile/distributed/test_sequence_parallelism.py
+++ b/tests/compile/distributed/test_sequence_parallelism.py
@@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
         ]
 
     def ops_in_model(self):
-        if self.vllm_config.compilation_config.pass_config.enable_fusion:
+        if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
             return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
         elif RMSNorm.enabled():
             return [
@@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("fuse_norm_quant", [True, False])
 @pytest.mark.parametrize("dynamic", [False, True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_fusion: bool,
+    fuse_norm_quant: bool,
     dynamic: bool,
 ):
     num_processes = 2
@@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
             seq_len,
             hidden_size,
             dtype,
-            enable_fusion,
+            fuse_norm_quant,
            dynamic,
         ),
         nprocs=nprocs,
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_fusion: bool,
+    fuse_norm_quant: bool,
     dynamic: bool,
 ):
     current_platform.seed_everything(0)
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
         cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
         custom_ops=custom_ops_list,
         pass_config=PassConfig(
-            enable_sequence_parallelism=True,
-            enable_fusion=enable_fusion,
-            enable_noop=True,
+            enable_sp=True,
+            fuse_norm_quant=fuse_norm_quant,
+            eliminate_noops=True,
         ),
     )  # NoOp needed for fusion
     device_config = DeviceConfig(device=torch.device("cuda"))
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
         sequence_parallelism_pass,
     ]
 
-    if enable_fusion:
+    if fuse_norm_quant:
         fusion_pass = RMSNormQuantFusionPass(vllm_config)
         passes_for_backend.append(fusion_pass)
 
diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py
index 2c11ecef7f02..3cd1d4be2ebd 100644
--- a/tests/compile/fullgraph/test_full_graph.py
+++ b/tests/compile/fullgraph/test_full_graph.py
@@ -122,7 +122,9 @@ def test_full_graph(
             CompilationConfig(
                 mode=CompilationMode.VLLM_COMPILE,
                 custom_ops=["+rms_norm"],
-                pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+                pass_config=PassConfig(
+                    fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
+                ),
             ),
             *model_info,
         )
diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index a9e5ccee520e..9e912c6d810d 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+import logging
 from contextlib import nullcontext
 from unittest.mock import patch
 
@@ -10,8 +11,9 @@
 from pydantic import ValidationError
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.config.compilation import CompilationMode
+from vllm.config.compilation import CompilationMode, PassConfig
 from vllm.engine.arg_utils import EngineArgs
+from vllm.logger import _print_warning_once
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer
@@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, ) @@ -206,7 +208,7 @@ def test_splitting_ops_dynamic(): config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config={"enable_attn_fusion": True, "enable_noop": True}, + pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True), custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, # work around for accessing all attntion ops @@ -219,7 +221,7 @@ def test_splitting_ops_dynamic(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, - pass_config={"enable_attn_fusion": True, "enable_noop": True}, + pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True), custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, ) @@ -227,7 +229,7 @@ def test_splitting_ops_dynamic(): # With inductor graph partition, attn_fusion and splitting_ops # work together. Default splitting_ops include attention ops. assert config.compilation_config.splitting_ops_contain_attention() - # enable_attn_fusion is directly supported under + # fuse_attn_quant is directly supported under # use_inductor_graph_partition=True, and cudagraph_mode # is unchanged. assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE @@ -301,7 +303,7 @@ def test_should_split(): "cudagraph_capture_sizes", "max_cudagraph_capture_size", "tp_size", - "enable_sequence_parallelism", + "enable_sp", "max_num_batched_tokens", "cudagraph_mode", "expected_max_size", @@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init( cudagraph_capture_sizes, max_cudagraph_capture_size, tp_size, - enable_sequence_parallelism, + enable_sp, max_num_batched_tokens, cudagraph_mode, expected_max_size, @@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init( compilation_config = CompilationConfig( cudagraph_capture_sizes=cudagraph_capture_sizes, max_cudagraph_capture_size=max_cudagraph_capture_size, - pass_config={ - "enable_sequence_parallelism": enable_sequence_parallelism, - "enable_fusion": True, - "enable_noop": True, - }, + pass_config=PassConfig( + enable_sp=enable_sp, + fuse_norm_quant=True, + fuse_act_quant=True, + eliminate_noops=True, + ), cudagraph_mode=cudagraph_mode, ) engine_args = EngineArgs( @@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init( vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size ) + + +def test_pass_config_deprecation(caplog_vllm): + caplog_vllm.set_level(logging.WARNING) + + # Clear cache to ensure warnings are re-issued + _print_warning_once.cache_clear() + + # Test enable_fusion -> fuse_norm_quant, fuse_act_quant + caplog_vllm.clear() + config = PassConfig(enable_fusion=True) + assert "enable_fusion is deprecated" in caplog_vllm.text + assert config.fuse_norm_quant is True + assert config.fuse_act_quant is True + assert config.enable_fusion is None + + # Test enable_attn_fusion -> fuse_attn_quant + caplog_vllm.clear() + config = PassConfig(enable_attn_fusion=True) + assert "enable_attn_fusion is deprecated" in caplog_vllm.text + assert config.fuse_attn_quant is True + assert config.enable_attn_fusion is None + + # Test enable_noop -> eliminate_noops + caplog_vllm.clear() + config = PassConfig(enable_noop=True) + assert "enable_noop is deprecated" in caplog_vllm.text + assert config.eliminate_noops is True + assert config.enable_noop is None + + # Test enable_sequence_parallelism -> enable_sp + caplog_vllm.clear() + config = PassConfig(enable_sequence_parallelism=True) + assert 
"enable_sequence_parallelism is deprecated" in caplog_vllm.text + assert config.enable_sp is True + assert config.enable_sequence_parallelism is None + + # Test enable_async_tp -> fuse_gemm_comms + caplog_vllm.clear() + config = PassConfig(enable_async_tp=True) + assert "enable_async_tp is deprecated" in caplog_vllm.text + assert config.fuse_gemm_comms is True + assert config.enable_async_tp is None + + # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms + caplog_vllm.clear() + config = PassConfig(enable_fi_allreduce_fusion=True) + assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text + assert config.fuse_allreduce_rms is True + assert config.enable_fi_allreduce_fusion is None diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 515e0a93ac2a..758591589270 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -223,7 +223,11 @@ def test_fix_functionalization( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( custom_ops=["all"], - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), + pass_config=PassConfig( + fuse_norm_quant=do_fusion, + fuse_act_quant=do_fusion, + eliminate_noops=True, + ), ), ) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 286f2276367a..d0ba8385f4a0 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), + pass_config=PassConfig( + fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True + ), ), ) with vllm.config.set_current_vllm_config(vllm_config): diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 4d213e030edb..9b4486e56c73 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -373,7 +373,7 @@ def test_attention_quant_pattern( # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( - enable_attn_fusion=True, enable_noop=True + fuse_attn_quant=True, eliminate_noops=True ) with ( set_current_vllm_config(vllm_config), diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py index 0ccc1a016162..bfe08382fd94 100644 --- a/tests/compile/test_noop_elimination.py +++ b/tests/compile/test_noop_elimination.py @@ -51,7 +51,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(enable_noop=True), + pass_config=PassConfig(eliminate_noops=True), ) ) with vllm.config.set_current_vllm_config(vllm_config): @@ -99,7 +99,7 @@ def test_non_noop_slice_preserved(): vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(enable_noop=True), + pass_config=PassConfig(eliminate_noops=True), ) ) with vllm.config.set_current_vllm_config(vllm_config): diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 1c40c599f748..6d0ba6b65503 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -64,8 +64,11 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - 
-    config2.compilation_config.pass_config.enable_fusion = (
-        not config2.compilation_config.pass_config.enable_fusion
+    config2.compilation_config.pass_config.fuse_norm_quant = (
+        not config2.compilation_config.pass_config.fuse_norm_quant
+    )
+    config2.compilation_config.pass_config.fuse_act_quant = (
+        not config2.compilation_config.pass_config.fuse_act_quant
     )
     pass_manager3 = PostGradPassManager()
     pass_manager3.configure(config2)
diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py
index 5ebb95b6db33..e0968ac79925 100644
--- a/tests/compile/test_qk_norm_rope_fusion.py
+++ b/tests/compile/test_qk_norm_rope_fusion.py
@@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion(
             custom_ops=custom_ops,
             pass_config=PassConfig(
                 enable_qk_norm_rope_fusion=True,
-                enable_noop=True,
+                eliminate_noops=True,
             ),
         ),
     )
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 0ddb82b7c3fc..c336a45955cb 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -168,7 +168,7 @@ def test_fusion_silu_and_mul_quant(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops,
-            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+            pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
         ),
     )
 
diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py
index f38c509775ed..0a7907aadeab 100644
--- a/tests/distributed/test_sequence_parallel.py
+++ b/tests/distributed/test_sequence_parallel.py
@@ -32,7 +32,8 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
 class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
-    enable_fusion: bool
+    fuse_norm_quant: bool
+    fuse_act_quant: bool
     eager_mode: bool
     chunked_prefill: bool
@@ -66,7 +67,8 @@ class SPTestSettings:
             ParallelSetup(
                 tp_size=tp_base,
                 pp_size=pp_multiplier * pp_base,
-                enable_fusion=False,
+                fuse_norm_quant=False,
+                fuse_act_quant=False,
                 eager_mode=eager_mode_val,
                 chunked_prefill=chunked_prefill_val,
             )
@@ -97,7 +99,8 @@ class SPTestSettings:
             ParallelSetup(
                 tp_size=tp_base,
                 pp_size=pp_multiplier * pp_base,
-                enable_fusion=False,
+                fuse_norm_quant=False,
+                fuse_act_quant=False,
                 eager_mode=eager_mode_val,
                 chunked_prefill=chunked_prefill_val,
             )
@@ -126,7 +129,8 @@ class SPTestSettings:
             ParallelSetup(
                 tp_size=tp_base,
                 pp_size=pp_base,
-                enable_fusion=fusion_val,
+                fuse_norm_quant=fusion_val,
+                fuse_act_quant=fusion_val,
                 eager_mode=True,
                 chunked_prefill=False,
             )
@@ -162,7 +166,7 @@ def _compare_sp(
     test_options: SPTestOptions,
     num_gpus_available: int,
     use_inductor_graph_partition: bool,
-    enable_async_tp: bool,
+    fuse_gemm_comms: bool,
     *,
     method: Literal["generate", "encode"],
     is_multimodal: bool,
@@ -170,7 +174,8 @@ def _compare_sp(
     (
         tp_size,
         pp_size,
-        enable_fusion,
+        fuse_norm_quant,
+        fuse_act_quant,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -248,10 +253,11 @@ def _compare_sp(
         "mode": CompilationMode.VLLM_COMPILE,
         "compile_sizes": [4, 8],
         "pass_config": {
-            "enable_sequence_parallelism": True,
-            "enable_async_tp": enable_async_tp,
-            "enable_fusion": enable_fusion,
-            "enable_noop": True,
+            "enable_sp": True,
+            "fuse_gemm_comms": fuse_gemm_comms,
+            "fuse_norm_quant": fuse_norm_quant,
+            "fuse_act_quant": fuse_act_quant,
+            "eliminate_noops": True,
         },
         "use_inductor_graph_partition": use_inductor_graph_partition,
     }
@@ -309,7 +315,7 @@ SP_TEST_MODELS = [
     ],
 )
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) -@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP +@pytest.mark.parametrize("fuse_gemm_comms", [False]) # TODO: enable async TP @create_new_process_for_each_test() def test_tp_sp_generation( model_id: str, @@ -319,7 +325,7 @@ def test_tp_sp_generation( test_options: SPTestOptions, num_gpus_available, use_inductor_graph_partition: bool, - enable_async_tp: bool, + fuse_gemm_comms: bool, ): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") @@ -328,7 +334,7 @@ def test_tp_sp_generation( if ( "fp8" in model_id.lower() and current_platform.get_device_capability() < (9, 0) - and (not enable_async_tp) + and (not fuse_gemm_comms) ): pytest.skip("FP8 reduction support begins with sm90 capable devices.") @@ -340,7 +346,7 @@ def test_tp_sp_generation( test_options, num_gpus_available, use_inductor_graph_partition, - enable_async_tp=enable_async_tp, + fuse_gemm_comms=fuse_gemm_comms, method="generate", is_multimodal=False, ) diff --git a/tests/test_config.py b/tests/test_config.py index b7ed68fea92a..019c0d6d8733 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1023,17 +1023,17 @@ def test_vllm_config_explicit_overrides(): assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE # Explicit pass config flags to override defaults - pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True) + pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True) compilation_config = CompilationConfig(pass_config=pass_config) config = VllmConfig( optimization_level=OptimizationLevel.O0, compilation_config=compilation_config, ) - assert config.compilation_config.pass_config.enable_noop is True - assert config.compilation_config.pass_config.enable_attn_fusion is True + assert config.compilation_config.pass_config.eliminate_noops is True + assert config.compilation_config.pass_config.fuse_attn_quant is True # Explicit cudagraph mode override on quantized model at O2 - pass_config = PassConfig(enable_async_tp=True) + pass_config = PassConfig(fuse_gemm_comms=True) compilation_config = CompilationConfig( cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config ) @@ -1043,7 +1043,7 @@ def test_vllm_config_explicit_overrides(): compilation_config=compilation_config, ) assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE - assert config.compilation_config.pass_config.enable_async_tp is True + assert config.compilation_config.pass_config.fuse_gemm_comms is True # Mode should still use default for O2 assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE @@ -1093,7 +1093,7 @@ def test_vllm_config_explicit_overrides(): compilation_config=compilation_config, ) # Explicit override should be respected - assert config.compilation_config.pass_config.enable_noop is False + assert config.compilation_config.pass_config.eliminate_noops is False # Other fields should still use defaults assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index fe2547d7feca..37f48721ea20 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -92,22 +92,23 @@ class PostGradPassManager(CustomGraphPass): # Set the current vllm config to allow tracing CustomOp instances 
         with set_current_vllm_config(config, check_compile=False):
-            if self.pass_config.enable_noop:
+            if self.pass_config.eliminate_noops:
                 self.passes += [NoOpEliminationPass(config)]
 
-            if self.pass_config.enable_sequence_parallelism:
+            if self.pass_config.enable_sp:
                 self.passes += [SequenceParallelismPass(config)]
-            if self.pass_config.enable_async_tp:
+            if self.pass_config.fuse_gemm_comms:
                 self.passes += [AsyncTPPass(config)]
 
-            if self.pass_config.enable_fi_allreduce_fusion:
+            if self.pass_config.fuse_allreduce_rms:
                 self.passes += [AllReduceFusionPass(config)]
 
-            if self.pass_config.enable_fusion:
+            if self.pass_config.fuse_norm_quant:
                 self.passes += [RMSNormQuantFusionPass(config)]
+            if self.pass_config.fuse_act_quant:
                 self.passes += [ActivationQuantFusionPass(config)]
 
-            if self.pass_config.enable_attn_fusion:
+            if self.pass_config.fuse_attn_quant:
                 self.passes += [AttnFusionPass(config)]
 
             if self.pass_config.enable_qk_norm_rope_fusion:
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 0f876c38169a..963b091939e0 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
-from vllm.config.utils import config
+from vllm.config.utils import config, handle_deprecated
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -105,18 +105,43 @@ class PassConfig:
     improper state.
     """
 
+    # New flags
+    fuse_norm_quant: bool = Field(default=None)
+    """Fuse the custom RMSNorm + quant ops."""
+    fuse_act_quant: bool = Field(default=None)
+    """Fuse the custom SiluMul + quant ops."""
+    fuse_attn_quant: bool = Field(default=None)
+    """Fuse the custom attention + quant ops."""
+    eliminate_noops: bool = Field(default=None)
+    """Eliminate no-op operations (e.g. redundant reshapes)."""
+    enable_sp: bool = Field(default=None)
+    """Enable sequence parallelism."""
+    fuse_gemm_comms: bool = Field(default=None)
+    """Fuse GEMMs with communication collectives (async TP)."""
+    fuse_allreduce_rms: bool = Field(default=None)
+    """Fuse allreduce + RMSNorm (FlashInfer allreduce fusion)."""
+
+    # Deprecated flags
     enable_fusion: bool = Field(default=None)
-    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
+    """Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
+    instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
+    """
     enable_attn_fusion: bool = Field(default=None)
-    """Whether to enable the custom attention+quant fusion pass."""
+    """Deprecated in: v0.12.0. Use fuse_attn_quant instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_noop: bool = Field(default=None)
-    """Whether to enable the custom no-op elimination pass."""
+    """Deprecated in: v0.12.0. Use eliminate_noops instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_sequence_parallelism: bool = Field(default=None)
-    """Whether to enable sequence parallelism."""
+    """Deprecated in: v0.12.0. Use enable_sp instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_async_tp: bool = Field(default=None)
-    """Whether to enable async TP."""
+    """Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_fi_allreduce_fusion: bool = Field(default=None)
-    """Whether to enable flashinfer allreduce fusion."""
+    """Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
+
     fi_allreduce_fusion_max_size_mb: float | None = None
     """The threshold of the communicated tensor sizes under which
     vllm should use flashinfer fused allreduce. Specified as a
@@ -136,7 +161,7 @@ class PassConfig:
         },
     }, where key is the device capability"""
 
     enable_qk_norm_rope_fusion: bool = False
-    """Whether to enable the fused Q/K RMSNorm + RoPE pass."""
+    """Enable fused Q/K RMSNorm + RoPE pass."""
 
     # TODO(luka) better pass enabling system.
@@ -174,6 +199,13 @@ class PassConfig:
         return InductorPass.hash_dict(asdict(self))
 
     @field_validator(
+        "fuse_norm_quant",
+        "fuse_act_quant",
+        "fuse_attn_quant",
+        "eliminate_noops",
+        "enable_sp",
+        "fuse_gemm_comms",
+        "fuse_allreduce_rms",
         "enable_fusion",
         "enable_attn_fusion",
         "enable_noop",
@@ -190,18 +222,71 @@ class PassConfig:
         return handler(value)
 
     def __post_init__(self) -> None:
-        if not self.enable_noop:
-            if self.enable_fusion:
+        # Handle deprecation and defaults
+
+        # Map old flags to new flags and issue warnings
+        handle_deprecated(
+            self,
+            "enable_fusion",
+            ["fuse_norm_quant", "fuse_act_quant"],
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_attn_fusion",
+            "fuse_attn_quant",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_sequence_parallelism",
+            "enable_sp",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_async_tp",
+            "fuse_gemm_comms",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_fi_allreduce_fusion",
+            "fuse_allreduce_rms",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_noop",
+            "eliminate_noops",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        # Force old flags to None to ensure they are not used
+        self.enable_fusion = None
+        self.enable_attn_fusion = None
+        self.enable_noop = None
+        self.enable_sequence_parallelism = None
+        self.enable_async_tp = None
+        self.enable_fi_allreduce_fusion = None
+
+        if not self.eliminate_noops:
+            if self.fuse_norm_quant or self.fuse_act_quant:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "RMSNorm/SiluMul + quant (fp8) fusion might not work"
                 )
-            if self.enable_attn_fusion:
+            if self.fuse_attn_quant:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "Attention + quant (fp8) fusion might not work"
                 )
-            if self.enable_fi_allreduce_fusion:
+            if self.fuse_allreduce_rms:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "Allreduce + rms norm + quant (fp8) fusion might not work"
                 )
@@ -873,7 +958,7 @@ class CompilationConfig:
             self.set_splitting_ops_for_inductor_graph_partition()
             return
 
-        if self.pass_config.enable_attn_fusion:
+        if self.pass_config.fuse_attn_quant:
             # here use_inductor_graph_partition is False
             self.set_splitting_ops_for_attn_fusion()
             return
@@ -915,12 +1000,12 @@ class CompilationConfig:
         self.splitting_ops = list(self._attention_ops)
 
     def set_splitting_ops_for_attn_fusion(self):
-        assert self.pass_config.enable_attn_fusion
+        assert self.pass_config.fuse_attn_quant
         if self.splitting_ops is None:
             self.splitting_ops = []
             if self.cudagraph_mode.has_piecewise_cudagraphs():
                 logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
+                    "fuse_attn_quant is incompatible with piecewise "
                     "cudagraph when use_inductor_graph_partition is off. "
                     "In this case, splitting_ops will be set to empty "
                     "list, and cudagraph_mode will be set to FULL. "
" @@ -931,8 +1016,7 @@ class CompilationConfig: self.cudagraph_mode = CUDAGraphMode.FULL assert not self.splitting_ops_contain_attention(), ( - "attention ops should not be in splitting_ops " - "when enable_attn_fusion is True" + "attention ops should not be in splitting_ops when fuse_attn_quant is True" ) def splitting_ops_contain_attention(self) -> bool: @@ -1008,7 +1092,7 @@ class CompilationConfig: self, uniform_decode_query_len: int, tensor_parallel_size: int ): multiple_of = uniform_decode_query_len - if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism: + if tensor_parallel_size > 1 and self.pass_config.enable_sp: multiple_of = max(uniform_decode_query_len, tensor_parallel_size) if ( multiple_of % uniform_decode_query_len != 0 diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 02f2b75f608f..3124fcf00739 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -19,6 +19,10 @@ import torch from pydantic.fields import FieldInfo from typing_extensions import runtime_checkable +from vllm.logger import init_logger + +logger = init_logger(__name__) + if TYPE_CHECKING: from _typeshed import DataclassInstance else: @@ -293,3 +297,28 @@ def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, ob def hash_factors(items: dict[str, object]) -> str: """Return a SHA-256 hex digest of the canonical items structure.""" return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest() + + +def handle_deprecated( + config: ConfigT, + old_name: str, + new_name_or_names: str | list[str], + removal_version: str, +) -> None: + old_val = getattr(config, old_name) + if old_val is None: + return + + if isinstance(new_name_or_names, str): + new_names = [new_name_or_names] + else: + new_names = new_name_or_names + + msg = ( + f"{old_name} is deprecated and will be removed in {removal_version}. " + f"Use {', '.join(new_names)} instead." + ) + logger.warning(msg) + + for new_name in new_names: + setattr(config, new_name, old_val) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 5b3a9c437662..735b0afbaaeb 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -83,22 +83,33 @@ IS_DENSE = False # See https://github.com/vllm-project/vllm/issues/25689. 
-def enable_fusion(cfg: "VllmConfig") -> bool: - """Returns True if RMS norm or quant FP8 is enabled.""" +def enable_norm_fusion(cfg: "VllmConfig") -> bool: + """Enable if either RMS norm or quant FP8 custom op is active; + otherwise Inductor handles fusion.""" + return cfg.compilation_config.is_custom_op_enabled( "rms_norm" ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") +def enable_act_fusion(cfg: "VllmConfig") -> bool: + """Enable if either SiLU+Mul or quant FP8 custom op is active; + otherwise Inductor handles fusion.""" + return cfg.compilation_config.is_custom_op_enabled( + "silu_and_mul" + ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") + + OPTIMIZATION_LEVEL_00 = { "compilation_config": { "pass_config": { - "enable_noop": False, - "enable_fusion": False, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": False, - "enable_sequence_parallelism": False, - "enable_async_tp": False, + "eliminate_noops": False, + "fuse_norm_quant": False, + "fuse_act_quant": False, + "fuse_allreduce_rms": False, + "fuse_attn_quant": False, + "enable_sp": False, + "fuse_gemm_comms": False, }, "cudagraph_mode": CUDAGraphMode.NONE, "use_inductor_graph_partition": False, @@ -107,12 +118,13 @@ OPTIMIZATION_LEVEL_00 = { OPTIMIZATION_LEVEL_01 = { "compilation_config": { "pass_config": { - "enable_noop": True, - "enable_fusion": enable_fusion, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": False, - "enable_sequence_parallelism": False, - "enable_async_tp": False, + "eliminate_noops": True, + "fuse_norm_quant": enable_norm_fusion, + "fuse_act_quant": enable_act_fusion, + "fuse_allreduce_rms": False, + "fuse_attn_quant": False, + "enable_sp": False, + "fuse_gemm_comms": False, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, "use_inductor_graph_partition": False, @@ -121,12 +133,13 @@ OPTIMIZATION_LEVEL_01 = { OPTIMIZATION_LEVEL_02 = { "compilation_config": { "pass_config": { - "enable_noop": True, - "enable_fusion": enable_fusion, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": IS_QUANTIZED, - "enable_sequence_parallelism": IS_DENSE, - "enable_async_tp": IS_DENSE, + "eliminate_noops": True, + "fuse_norm_quant": enable_norm_fusion, + "fuse_act_quant": enable_act_fusion, + "fuse_allreduce_rms": False, + "fuse_attn_quant": IS_QUANTIZED, + "enable_sp": IS_DENSE, + "fuse_gemm_comms": IS_DENSE, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -135,12 +148,13 @@ OPTIMIZATION_LEVEL_02 = { OPTIMIZATION_LEVEL_03 = { "compilation_config": { "pass_config": { - "enable_noop": True, - "enable_fusion": enable_fusion, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": IS_QUANTIZED, - "enable_sequence_parallelism": IS_DENSE, - "enable_async_tp": IS_DENSE, + "eliminate_noops": True, + "fuse_norm_quant": enable_norm_fusion, + "fuse_act_quant": enable_act_fusion, + "fuse_allreduce_rms": False, + "fuse_attn_quant": IS_QUANTIZED, + "enable_sp": IS_DENSE, + "fuse_gemm_comms": IS_DENSE, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -645,9 +659,9 @@ class VllmConfig: # async tp is built on top of sequence parallelism # and requires it to be enabled. 
-        if self.compilation_config.pass_config.enable_async_tp:
-            self.compilation_config.pass_config.enable_sequence_parallelism = True
-        if self.compilation_config.pass_config.enable_sequence_parallelism:
+        if self.compilation_config.pass_config.fuse_gemm_comms:
+            self.compilation_config.pass_config.enable_sp = True
+        if self.compilation_config.pass_config.enable_sp:
             if "-rms_norm" in self.compilation_config.custom_ops:
                 logger.warning(
                     "RMS norm force disabled, sequence parallelism might break"
                 )
@@ -797,7 +811,7 @@ class VllmConfig:
         # Do this after all the updates to compilation_config.mode
         self.compilation_config.set_splitting_ops_for_v1()
 
-        if self.compilation_config.pass_config.enable_sequence_parallelism:
+        if self.compilation_config.pass_config.enable_sp:
             # With pipeline parallelism or dynamo partitioning,
             # native rms norm tracing errors due to incorrect residual shape.
             # Use custom rms norm to unblock. In the future,
@@ -1062,7 +1076,7 @@ class VllmConfig:
 
         if (
             self.parallel_config.tensor_parallel_size > 1
-            and self.compilation_config.pass_config.enable_sequence_parallelism
+            and self.compilation_config.pass_config.enable_sp
         ):
             cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                 cudagraph_capture_sizes
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8c22ada029b1..1b250a8bd009 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2417,10 +2417,7 @@ class GPUModelRunner(
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if (
-            self.compilation_config.pass_config.enable_sequence_parallelism
-            and tp_size > 1
-        ):
+        if self.compilation_config.pass_config.enable_sp and tp_size > 1:
             return round_up(num_scheduled_tokens, tp_size)
         return num_scheduled_tokens
 
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index ed6fb32bcb2f..edba07a423cd 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -552,7 +552,7 @@ class Worker(WorkerBase):
 
         if (
             parallel_config.pipeline_parallel_size > 1
-            and compilation_config.pass_config.enable_sequence_parallelism
+            and compilation_config.pass_config.enable_sp
             and forward_pass
         ):
             # currently only supported by V1 GPUModelRunner
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index bd88cb1b253f..427a0d296b25 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -342,7 +342,7 @@ def is_residual_scattered_for_sp(
       partition), SP is always applied
     - Otherwise, SP is only applied for specific shapes in compile_sizes
     """
-    if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
+    if not vllm_config.compilation_config.pass_config.enable_sp:
        return False
 
    tp = vllm_config.parallel_config.tensor_parallel_size
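Illustration (not part of the patch): what the renamed PassConfig flags look like at the call site after this change, based on the tests above. Flag values are illustrative rather than recommended defaults; the old spellings still construct, but emit a deprecation warning, are copied onto the new flags, and are then reset to None.

    from vllm.config import CompilationConfig
    from vllm.config.compilation import CompilationMode, PassConfig

    # New spellings: one flag per pass.
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            eliminate_noops=True,      # was enable_noop
            fuse_norm_quant=True,      # was enable_fusion (RMSNorm + quant half)
            fuse_act_quant=True,       # was enable_fusion (SiluMul + quant half)
            fuse_attn_quant=False,     # was enable_attn_fusion
            enable_sp=False,           # was enable_sequence_parallelism
            fuse_gemm_comms=False,     # was enable_async_tp
            fuse_allreduce_rms=False,  # was enable_fi_allreduce_fusion
        ),
    )

    # Old spellings are mapped onto the new ones, then cleared.
    legacy = PassConfig(enable_async_tp=True)
    assert legacy.fuse_gemm_comms is True
    assert legacy.enable_async_tp is None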