[CLI][Doc] Formalize --mm-encoder-tp-mode (#23190)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-20 23:42:28 +08:00 committed by GitHub
parent b17109beea
commit 5efd6905bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 104 additions and 24 deletions

View File

@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
### Batch-level DP for Multi-Modal Encoders
By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.
However, since multi-modal encoders are very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because an all-reduce is performed after every layer.
Given this, it may be advantageous to instead shard the batched input data using TP, essentially
performing batch-level DP. This has been shown to improve the throughput by around 10% for
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide another 40% increase to throughput compared to regular TP.
Nevertheless, since the weights of the multi-modal encoder are replicated across TP ranks,
there is a minor increase in memory consumption, which may cause OOM if the model barely fits as it is.
You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:
```python
from vllm import LLM
llm = LLM(
model="Qwen/Qwen2.5-VL-72B-Instruct",
# Create two EngineCore instances, one per DP rank
data_parallel_size=2,
# Within each EngineCore instance:
# The vision encoder uses TP=4 (not DP=2) to shard the input data
# The language decoder uses TP=4 to shard the weights as usual
tensor_parallel_size=4,
mm_encoder_tp_mode="data",
)
```
!!! important
    Batch-level DP is not to be confused with API request-level DP
    (which is instead controlled by `data_parallel_size`).
The availability of batch-level DP depends on the model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"` (the model-side opt-in is sketched after this list):
- Llama4 (<gh-pr:18368>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)
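Support is a small opt-in inside each model's implementation. The following is a minimal sketch adapted from the model changes in this PR; `MyVisionLM` is a hypothetical class used only for illustration, while the check itself mirrors the diffs below:
```python
import torch.nn as nn


class MyVisionLM(nn.Module):
    """Hypothetical multi-modal model illustrating the opt-in."""

    def __init__(self, *, vllm_config, prefix: str = "") -> None:
        super().__init__()
        multimodal_config = vllm_config.model_config.multimodal_config
        # Batch-level DP: keep the full encoder weights on every TP rank and
        # split the batched multi-modal inputs across ranks instead.
        self.use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data")
```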
## Input Processing
### Parallel Processing

View File

@@ -258,6 +258,7 @@ TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
MMEncoderTPMode = Literal["weights", "data"]
@config
@@ -438,6 +439,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
@@ -856,8 +870,10 @@ class ModelConfig:
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
skip_mm_profiling=self.skip_mm_profiling,
)
return None
@@ -2547,6 +2563,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).
- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""
interleave_mm_strings: bool = False
"""
Enable fully interleaved support for multimodal prompts.
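As a quick illustration of the new field's shape, a minimal sketch (assuming `MultiModalConfig` can be constructed standalone with its defaults, which this diff does not show):
```python
from vllm.config import MultiModalConfig

# Default: encoder weights are sharded across TP ranks as usual.
assert MultiModalConfig().mm_encoder_tp_mode == "weights"

# Batch-level DP: replicate encoder weights and shard the batched inputs.
# Falls back to "weights" at runtime if the encoder does not support DP.
mm_config = MultiModalConfig(mm_encoder_tp_mode="data")
```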

View File

@@ -137,10 +137,6 @@ class ParallelConfig:
rank: int = 0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel: bool = False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world

View File

@@ -28,12 +28,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerMode, VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -352,6 +352,7 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
enable_lora: bool = False
@@ -434,16 +435,14 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
# DEPRECATED
enable_multimodal_encoder_data_parallel: bool = False
logits_processors: Optional[list[Union[
str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
"""Custom logitproc types"""
async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill
@@ -685,7 +684,8 @@ class EngineArgs:
**parallel_kwargs["worker_extension_cls"])
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
action="store_true",
deprecated=True)
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
@@ -735,6 +735,8 @@ class EngineArgs:
multimodal_group.add_argument("--disable-mm-preprocessor-cache",
action="store_true",
deprecated=True)
multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
multimodal_group.add_argument(
"--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"])
@@ -909,6 +911,14 @@ class EngineArgs:
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
if self.enable_multimodal_encoder_data_parallel:
logger.warning(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-encoder-tp-mode data` instead.")
self.mm_encoder_tp_mode = "data"
return ModelConfig(
model=self.model,
hf_config_path=self.hf_config_path,
@@ -947,6 +957,7 @@ class EngineArgs:
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
@@ -1258,8 +1269,6 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
)
if model_config.is_multimodal_model:
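For context, a minimal sketch of configuring the new option at the `EngineArgs` level; the CLI equivalent is `--mm-encoder-tp-mode data`. The import path and field names are assumed from this diff:
```python
from vllm.engine.arg_utils import EngineArgs

# Preferred: set the new field directly (equivalent to passing
# `--mm-encoder-tp-mode data` on the command line).
args = EngineArgs(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",
)

# Deprecated: the old boolean still parses, but it only emits a warning and
# is rewritten to mm_encoder_tp_mode="data" when the model config is built.
legacy_args = EngineArgs(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    tensor_parallel_size=4,
    enable_multimodal_encoder_data_parallel=True,
)
```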

View File

@@ -728,8 +728,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config

View File

@@ -877,8 +877,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.multimodal_config = multimodal_config

View File

@@ -882,8 +882,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config.get_limit_per_prompt("image"):
self.vision_model = Step3VisionTransformer(