Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Support sequence parallelism combined with pipeline parallelism (#18243)

Signed-off-by: cascade812 <cascade812@outlook.com>

parent 66e63e86ec
commit 9ab2c02ff8
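With this change, tensor parallelism, pipeline parallelism, and the sequence-parallelism compilation pass can be used together instead of SP being silently disabled whenever PP > 1. A minimal launch sketch from the offline LLM entrypoint, assuming four GPUs and assuming the installed vLLM version accepts a compilation_config dict with a nested pass_config (the field names mirror the VllmConfig attributes touched below; the exact plumbing may differ between versions):

# Hedged sketch: TP=2 x PP=2 with the sequence-parallelism pass requested.
# Not the test harness from this commit; model name reused from the tests below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    compilation_config={
        # enable_sequence_parallelism lives under pass_config in VllmConfig.
        "pass_config": {"enable_sequence_parallelism": True},
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)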
@@ -26,6 +26,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

class ParallelSetup(NamedTuple):
    tp_size: int
    pp_size: int
    sp_enabled: bool
    eager_mode: bool
    chunked_prefill: bool
@@ -60,6 +61,7 @@ class SPTestSettings:
    def detailed(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        load_format: Optional[str] = None,
@@ -67,18 +69,42 @@ class SPTestSettings:
        return SPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              sp_enabled=True,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              sp_enabled=True,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              sp_enabled=True,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              sp_enabled=True,
                              eager_mode=True,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              sp_enabled=True,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              sp_enabled=True,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              sp_enabled=True,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              sp_enabled=True,
                              eager_mode=True,
                              chunked_prefill=True)
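The eight setups enumerated in detailed() above are the full cross product of pp_size in {pp_base, 2 * pp_base}, eager_mode in {False, True}, and chunked_prefill in {False, True}, always with sp_enabled=True. An equivalent, order-preserving construction as a standalone sketch (ParallelSetup is redefined here only so the snippet runs on its own):

from itertools import product
from typing import NamedTuple


class ParallelSetup(NamedTuple):
    tp_size: int
    pp_size: int
    sp_enabled: bool
    eager_mode: bool
    chunked_prefill: bool


def detailed_setups(tp_base: int = 2, pp_base: int = 1) -> list[ParallelSetup]:
    # Same ordering as the literal list above: pp_size varies slowest,
    # chunked_prefill fastest.
    return [
        ParallelSetup(tp_size=tp_base,
                      pp_size=pp_size,
                      sp_enabled=True,
                      eager_mode=eager_mode,
                      chunked_prefill=chunked_prefill)
        for pp_size, eager_mode, chunked_prefill in product(
            (pp_base, 2 * pp_base), (False, True), (False, True))
    ]


assert len(detailed_setups()) == 8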
@@ -94,6 +120,7 @@ class SPTestSettings:
    def fast(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
@@ -101,6 +128,12 @@ class SPTestSettings:
        return SPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              sp_enabled=True,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              sp_enabled=True,
                              eager_mode=False,
                              chunked_prefill=False),
@@ -136,6 +169,7 @@ def _compare_sp(
):
    (
        tp_size,
        pp_size,
        sp_enabled,
        eager_mode,
        chunked_prefill,
@@ -167,7 +201,6 @@ def _compare_sp(
     else:
         model_info.check_available_online(on_fail="skip")
 
-    pp_size = 1
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
     if VLLM_MULTI_NODE and distributed_backend == "mp":
@@ -256,7 +289,7 @@ def _compare_sp(
 
 SP_TEXT_GENERATION_MODELS = {
     # [Decoder-only]
-    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
+    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
 }
 
 SP_TEST_MODELS = [
@@ -4287,18 +4287,6 @@ class VllmConfig:
             self.compilation_config.level = CompilationLevel.PIECEWISE
             self.compilation_config.set_splitting_ops_for_v1()
 
-        if self.parallel_config is not None and \
-            self.parallel_config.tensor_parallel_size > 1 and \
-            self.parallel_config.pipeline_parallel_size > 1 and \
-            self.compilation_config is not None and \
-            self.compilation_config.pass_config is not None and \
-            self.compilation_config.pass_config.enable_sequence_parallelism:
-            logger.warning_once(
-                "Sequence parallelism is not supported with pipeline "
-                "parallelism. Disabling sequence parallelism.")
-            self.compilation_config.pass_config.\
-                enable_sequence_parallelism = False
-
         self._set_cudagraph_sizes()
 
         if self.cache_config is not None and \
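The block removed above is the part of VllmConfig initialization that previously forced the sequence-parallelism pass off whenever pipeline parallelism was active. For reference, the now-removed fallback condition restated as a standalone predicate over plain values (a sketch, not the VllmConfig API):

def sp_was_forced_off(tp_size: int, pp_size: int,
                      enable_sequence_parallelism: bool) -> bool:
    # Before this commit, this combination logged "Sequence parallelism is
    # not supported with pipeline parallelism." and reset the flag to False.
    return tp_size > 1 and pp_size > 1 and enable_sequence_parallelism


assert sp_was_forced_off(2, 2, True)       # used to fall back to SP off
assert not sp_was_forced_off(2, 1, True)   # TP-only SP was always allowed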
@@ -1056,6 +1056,40 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            indices=out_indices,
        )

    def sync_and_slice_intermediate_tensors(
            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
            sync_self: bool) -> IntermediateTensors:

        assert self.intermediate_tensors is not None

        tp = self.vllm_config.parallel_config.tensor_parallel_size
        enabled_sp = self.vllm_config.compilation_config.pass_config. \
            enable_sequence_parallelism
        if enabled_sp:
            # When sequence parallelism is enabled, we always pad num_tokens
            # to be a multiple of tensor_parallel_size (tp) earlier.
            assert num_tokens % tp == 0
        is_residual_scattered = tp > 1 and enabled_sp \
            and num_tokens % tp == 0

        # When sequence parallelism is enabled, the "residual" tensor is
        # sharded across tensor parallel ranks, so each rank only needs its
        # own slice.
        if sync_self:
            assert intermediate_tensors is not None
            for k, v in intermediate_tensors.items():
                is_scattered = k == "residual" and is_residual_scattered
                copy_len = num_tokens // tp if is_scattered else num_tokens
                self.intermediate_tensors[k][:copy_len].copy_(
                    v[:copy_len], non_blocking=True)

        return IntermediateTensors({
            k:
            v[:num_tokens // tp]
            if k == "residual" and is_residual_scattered else v[:num_tokens]
            for k, v in self.intermediate_tensors.items()
        })

    @torch.inference_mode()
    def execute_model(
        self,
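The rule implemented by sync_and_slice_intermediate_tensors above: with sequence parallelism enabled, only the "residual" entry is scattered across tensor-parallel ranks, so it is sliced to num_tokens // tp, while every other intermediate tensor keeps all num_tokens rows. A standalone sketch of that slicing decision with illustrative sizes (plain torch tensors standing in for the runner's persistent buffers):

import torch

tp = 4
num_tokens = 8      # already padded to a multiple of tp, as asserted above
enabled_sp = True
is_residual_scattered = tp > 1 and enabled_sp and num_tokens % tp == 0

buffers = {
    "hidden_states": torch.zeros(num_tokens, 16),
    "residual": torch.zeros(num_tokens, 16),
}

sliced = {
    k: v[:num_tokens // tp]
    if k == "residual" and is_residual_scattered else v[:num_tokens]
    for k, v in buffers.items()
}

assert sliced["residual"].shape[0] == 2        # 8 // 4: this rank's shard
assert sliced["hidden_states"].shape[0] == 8   # not scattered, kept in full

The two call sites below differ only in sync_self: the execute_model path on non-first pipeline ranks passes True so the received tensors are first copied into the persistent buffers, while the dummy-run path passes False and only slices the already-allocated buffers.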
@@ -1131,15 +1165,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
-            assert intermediate_tensors is not None
-            assert self.intermediate_tensors is not None
-            for k, v in intermediate_tensors.items():
-                self.intermediate_tensors[k][:num_input_tokens].copy_(
-                    v[:num_input_tokens], non_blocking=True)
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_input_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_input_tokens, intermediate_tensors, True)
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
@@ -1658,10 +1685,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         batch_size=self.max_num_tokens,
                         dtype=self.model_config.dtype,
                         device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_tokens, None, False)
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,