diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py
index 362e9daf5ae04..deefdf22ba06b 100644
--- a/tests/distributed/test_sequence_parallel.py
+++ b/tests/distributed/test_sequence_parallel.py
@@ -18,6 +18,7 @@ import pytest
 from vllm.config.compilation import CompilationMode
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger
+from vllm.utils import is_torch_equal_or_newer
 
 from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import compare_two_settings, create_new_process_for_each_test
@@ -159,6 +160,7 @@ def _compare_sp(
     runner: RunnerOption,
     test_options: SPTestOptions,
     num_gpus_available: int,
+    use_inductor_graph_partition: bool,
     *,
     method: Literal["generate", "encode"],
     is_multimodal: bool,
@@ -243,6 +245,7 @@ def _compare_sp(
             "enable_fusion": enable_fusion,
             "enable_noop": True,
         },
+        "use_inductor_graph_partition": use_inductor_graph_partition,
     }
 
     tp_sp_args = [
@@ -297,6 +300,7 @@ SP_TEST_MODELS = [
         if model_id in SP_TEST_MODELS
     ],
 )
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
 @create_new_process_for_each_test()
 def test_tp_sp_generation(
     model_id: str,
@@ -305,7 +309,11 @@ def test_tp_sp_generation(
     runner: RunnerOption,
     test_options: SPTestOptions,
     num_gpus_available,
+    use_inductor_graph_partition: bool,
 ):
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
     _compare_sp(
         model_id,
         parallel_setup,
@@ -313,6 +321,7 @@ def test_tp_sp_generation(
         runner,
         test_options,
         num_gpus_available,
+        use_inductor_graph_partition,
         method="generate",
         is_multimodal=False,
     )
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index d63978b32c187..f384ede066210 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -328,8 +328,12 @@ def is_residual_scattered_for_sp(
     """Check if the residual tensor is scattered for sequence parallelism.
 
     The residual tensor is scattered across tensor parallel ranks when sequence
-    parallelism and tensor parallelism is enabled, and the number of
-    input tokens is one of the compilation sizes.
+    parallelism and tensor parallelism is enabled.
+
+    This follows the same logic as SequenceParallelismPass.is_applicable():
+    - In full-graph compilation mode (no splitting ops or using inductor graph
+      partition), SP is always applied
+    - Otherwise, SP is only applied for specific shapes in compile_sizes
     """
     if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
         return False
@@ -343,5 +347,10 @@ def is_residual_scattered_for_sp(
     # to be a multiple of tensor_parallel_size (tp) earlier.
     assert num_input_tokens % tp == 0
 
-    # Currently, SP is only enabled for static size fx graphs.
+    if (
+        not vllm_config.compilation_config.splitting_ops
+        or vllm_config.compilation_config.use_inductor_graph_partition
+    ):
+        return True
+
     return num_input_tokens in vllm_config.compilation_config.compile_sizes
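
To make the new gating in `is_residual_scattered_for_sp` easier to follow, here is a minimal standalone sketch of the decision logic the patch introduces. The function name, flat parameters, and test values below are hypothetical illustrations, not vLLM's actual config API; they only mirror the branch structure added in the diff.

```python
# Sketch (assumed names/parameters) of the patched decision: with full-graph
# compilation (no splitting ops, or inductor graph partition enabled) the
# sequence-parallelism pass applies to every batch; otherwise it applies only
# to batches whose token count is one of the pre-compiled sizes.
def residual_is_scattered(
    enable_sequence_parallelism: bool,
    tp_size: int,
    num_input_tokens: int,
    splitting_ops: list[str],
    use_inductor_graph_partition: bool,
    compile_sizes: set[int],
) -> bool:
    if not enable_sequence_parallelism or tp_size == 1:
        return False
    # The caller is assumed to have padded num_input_tokens to a multiple of tp.
    assert num_input_tokens % tp_size == 0
    if not splitting_ops or use_inductor_graph_partition:
        # Full-graph compilation: SP is always applied.
        return True
    # Piecewise compilation: SP is only applied for statically compiled sizes.
    return num_input_tokens in compile_sizes


if __name__ == "__main__":
    # Inductor graph partition: scattered regardless of the token count.
    assert residual_is_scattered(True, 2, 6, ["attn"], True, {4, 8})
    # Piecewise compilation: only scattered when the size was compiled.
    assert not residual_is_scattered(True, 2, 6, ["attn"], False, {4, 8})
    assert residual_is_scattered(True, 2, 8, ["attn"], False, {4, 8})
```

This is also why the test matrix gains a `use_inductor_graph_partition` parameter: both branches of the new condition need coverage, with the partition case skipped on PyTorch older than 2.9.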