mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 10:17:56 +08:00
[bugfix] Fix SP + PP without specifying compile size (#26955)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
parent
582f2c6be7
commit
e19b16dde6
@ -18,6 +18,7 @@ import pytest
|
|||||||
from vllm.config.compilation import CompilationMode
|
from vllm.config.compilation import CompilationMode
|
||||||
from vllm.config.model import RunnerOption
|
from vllm.config.model import RunnerOption
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
from ..models.registry import HF_EXAMPLE_MODELS
|
from ..models.registry import HF_EXAMPLE_MODELS
|
||||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||||
@ -159,6 +160,7 @@ def _compare_sp(
|
|||||||
runner: RunnerOption,
|
runner: RunnerOption,
|
||||||
test_options: SPTestOptions,
|
test_options: SPTestOptions,
|
||||||
num_gpus_available: int,
|
num_gpus_available: int,
|
||||||
|
use_inductor_graph_partition: bool,
|
||||||
*,
|
*,
|
||||||
method: Literal["generate", "encode"],
|
method: Literal["generate", "encode"],
|
||||||
is_multimodal: bool,
|
is_multimodal: bool,
|
||||||
@ -243,6 +245,7 @@ def _compare_sp(
|
|||||||
"enable_fusion": enable_fusion,
|
"enable_fusion": enable_fusion,
|
||||||
"enable_noop": True,
|
"enable_noop": True,
|
||||||
},
|
},
|
||||||
|
"use_inductor_graph_partition": use_inductor_graph_partition,
|
||||||
}
|
}
|
||||||
|
|
||||||
tp_sp_args = [
|
tp_sp_args = [
|
||||||
@ -297,6 +300,7 @@ SP_TEST_MODELS = [
|
|||||||
if model_id in SP_TEST_MODELS
|
if model_id in SP_TEST_MODELS
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_tp_sp_generation(
|
def test_tp_sp_generation(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -305,7 +309,11 @@ def test_tp_sp_generation(
|
|||||||
runner: RunnerOption,
|
runner: RunnerOption,
|
||||||
test_options: SPTestOptions,
|
test_options: SPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
|
use_inductor_graph_partition: bool,
|
||||||
):
|
):
|
||||||
|
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
|
|
||||||
_compare_sp(
|
_compare_sp(
|
||||||
model_id,
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
@ -313,6 +321,7 @@ def test_tp_sp_generation(
|
|||||||
runner,
|
runner,
|
||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
|
use_inductor_graph_partition,
|
||||||
method="generate",
|
method="generate",
|
||||||
is_multimodal=False,
|
is_multimodal=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -328,8 +328,12 @@ def is_residual_scattered_for_sp(
|
|||||||
"""Check if the residual tensor is scattered for sequence parallelism.
|
"""Check if the residual tensor is scattered for sequence parallelism.
|
||||||
|
|
||||||
The residual tensor is scattered across tensor parallel ranks when sequence
|
The residual tensor is scattered across tensor parallel ranks when sequence
|
||||||
parallelism and tensor parallelism is enabled, and the number of
|
parallelism and tensor parallelism is enabled.
|
||||||
input tokens is one of the compilation sizes.
|
|
||||||
|
This follows the same logic as SequenceParallelismPass.is_applicable():
|
||||||
|
- In full-graph compilation mode (no splitting ops or using inductor graph
|
||||||
|
partition), SP is always applied
|
||||||
|
- Otherwise, SP is only applied for specific shapes in compile_sizes
|
||||||
"""
|
"""
|
||||||
if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
|
if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
return False
|
return False
|
||||||
@ -343,5 +347,10 @@ def is_residual_scattered_for_sp(
|
|||||||
# to be a multiple of tensor_parallel_size (tp) earlier.
|
# to be a multiple of tensor_parallel_size (tp) earlier.
|
||||||
assert num_input_tokens % tp == 0
|
assert num_input_tokens % tp == 0
|
||||||
|
|
||||||
# Currently, SP is only enabled for static size fx graphs.
|
if (
|
||||||
|
not vllm_config.compilation_config.splitting_ops
|
||||||
|
or vllm_config.compilation_config.use_inductor_graph_partition
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
return num_input_tokens in vllm_config.compilation_config.compile_sizes
|
return num_input_tokens in vllm_config.compilation_config.compile_sizes
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user