From 5efd6905bc8469a30664de83bdafaad56aa92903 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 20 Aug 2025 23:42:28 +0800
Subject: [PATCH] [CLI][Doc] Formalize `--mm-encoder-tp-mode` (#23190)

Signed-off-by: DarkLight1337
---
 docs/configuration/optimization.md       | 45 ++++++++++++++++++++++++
 vllm/config/__init__.py                  | 34 +++++++++++++++++-
 vllm/config/parallel.py                  |  4 ---
 vllm/engine/arg_utils.py                 | 35 +++++++++++-------
 vllm/model_executor/models/mllama4.py    |  4 +--
 vllm/model_executor/models/qwen2_5_vl.py |  3 +-
 vllm/model_executor/models/step3_vl.py   |  3 +-
 7 files changed, 104 insertions(+), 24 deletions(-)

diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md
index c7f50497d6ff..db9dfb313fb8 100644
--- a/docs/configuration/optimization.md
+++ b/docs/configuration/optimization.md
@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and processes
 Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
 Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
 
+### Batch-level DP for Multi-Modal Encoders
+
+By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
+in order to reduce the memory and compute load on each GPU.
+
+However, since multi-modal encoders are very small compared to language decoders,
+there is relatively little gain from TP. On the other hand, TP incurs significant communication
+overhead because an all-reduce is performed after every layer.
+
+Given this, it may be advantageous to instead shard the batched input data across the TP ranks,
+essentially performing batch-level DP. This has been shown to improve throughput by around 10% for
+`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
+batch-level DP can provide another 40% increase in throughput compared to regular TP.
+
+Nevertheless, since the weights of the multi-modal encoder are replicated on each TP rank,
+there will be a minor increase in memory consumption, which may cause OOM if you can barely fit the model already.
+
+You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:
+
+```python
+from vllm import LLM
+
+llm = LLM(
+    model="Qwen/Qwen2.5-VL-72B-Instruct",
+    # Create two EngineCore instances, one per DP rank
+    data_parallel_size=2,
+    # Within each EngineCore instance:
+    # The vision encoder uses TP=4 (not DP=2) to shard the input data
+    # The language decoder uses TP=4 to shard the weights as usual
+    tensor_parallel_size=4,
+    mm_encoder_tp_mode="data",
+)
+```
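+
+When serving online, the same behavior is exposed through the `--mm-encoder-tp-mode` CLI flag.
+A rough equivalent of the offline example above (a sketch assuming the standard `vllm serve` entrypoint):
+
+```bash
+vllm serve Qwen/Qwen2.5-VL-72B-Instruct \
+  --data-parallel-size 2 \
+  --tensor-parallel-size 4 \
+  --mm-encoder-tp-mode data
+```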
+
+!!! important
+    Batch-level DP is not to be confused with API request-level DP
+    (which is instead controlled by `data_parallel_size`).
+
+The availability of batch-level DP depends on the model implementation.
+Currently, the following models support `mm_encoder_tp_mode="data"`:
+
+- Llama4
+- Qwen2.5-VL
+- Step3
+
 ## Input Processing
 
 ### Parallel Processing
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 801fa97fe5da..5b5d477ef066 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -258,6 +258,7 @@
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
 LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]
+MMEncoderTPMode = Literal["weights", "data"]
 
 
 @config
@@ -438,6 +439,19 @@ class ModelConfig:
     `mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
     Set to `0` to disable this cache completely (not recommended)."""
+    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
+    """Indicates how to optimize multi-modal encoder inference using
+    tensor parallelism (TP).
+
+    - `"weights"`: Within the same vLLM engine, split the weights of
+        each layer across TP ranks. (default TP behavior)
+    - `"data"`: Within the same vLLM engine, split the batched input data
+        across TP ranks to process the data in parallel, while hosting
+        the full weights on each TP rank.
+        This batch-level DP is not to be confused with API request-level
+        DP (which is controlled by `--data-parallel-size`).
+        This is only supported on a per-model basis and falls back to
+        `"weights"` if the encoder does not support DP."""
     override_neuron_config: dict[str, Any] = field(default_factory=dict)
     """Initialize non-default neuron config or override default neuron config
     that are specific to Neuron devices, this argument will be used to
@@ -856,8 +870,10 @@ class ModelConfig:
                 media_io_kwargs=self.media_io_kwargs,
                 mm_processor_kwargs=self.mm_processor_kwargs,
                 mm_processor_cache_gb=self.mm_processor_cache_gb,
+                mm_encoder_tp_mode=self.mm_encoder_tp_mode,
                 interleave_mm_strings=self.interleave_mm_strings,
-                skip_mm_profiling=self.skip_mm_profiling)
+                skip_mm_profiling=self.skip_mm_profiling,
+            )
 
         return None
 
@@ -2547,6 +2563,22 @@ class MultiModalConfig:
     Set to `0` to disable this cache completely (not recommended).
     """
 
+    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
+    """
+    Indicates how to optimize multi-modal encoder inference using
+    tensor parallelism (TP).
+
+    - `"weights"`: Within the same vLLM engine, split the weights of
+        each layer across TP ranks. (default TP behavior)
+    - `"data"`: Within the same vLLM engine, split the batched input data
+        across TP ranks to process the data in parallel, while hosting
+        the full weights on each TP rank.
+        This batch-level DP is not to be confused with API request-level
+        DP (which is controlled by `--data-parallel-size`).
+        This is only supported on a per-model basis and falls back to
+        `"weights"` if the encoder does not support DP.
+    """
+
     interleave_mm_strings: bool = False
     """
     Enable fully interleaved support for multimodal prompts.
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index bac1e63800d7..7a9e68f0ea33 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -137,10 +137,6 @@ class ParallelConfig:
     rank: int = 0
     """Global rank in distributed setup."""
 
-    enable_multimodal_encoder_data_parallel: bool = False
-    """ Use data parallelism instead of tensor parallelism for vision encoder.
-    Only support LLama4 for now"""
-
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 48d9cd08af03..6869c3f23f31 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -28,12 +28,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
-                         LoRAConfig, MambaDType, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         RunnerOption, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerMode,
-                         VllmConfig, get_attr_docs, get_field)
+                         LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
+                         ModelDType, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -352,6 +352,7 @@ class EngineArgs:
         MultiModalConfig.mm_processor_kwargs
     disable_mm_preprocessor_cache: bool = False  # DEPRECATED
     mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
+    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     # LoRA fields
     enable_lora: bool = False
@@ -434,16 +435,14 @@ class EngineArgs:
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
     pt_load_map_location: str = LoadConfig.pt_load_map_location
 
-    enable_multimodal_encoder_data_parallel: bool = \
-        ParallelConfig.enable_multimodal_encoder_data_parallel
+    # DEPRECATED
+    enable_multimodal_encoder_data_parallel: bool = False
 
     logits_processors: Optional[list[Union[
         str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
     """Custom logitproc types"""
 
     async_scheduling: bool = SchedulerConfig.async_scheduling
-    # DEPRECATED
-    enable_prompt_adapter: bool = False
 
     kv_sharing_fast_prefill: bool = \
         CacheConfig.kv_sharing_fast_prefill
@@ -685,7 +684,8 @@ class EngineArgs:
                 **parallel_kwargs["worker_extension_cls"])
         parallel_group.add_argument(
             "--enable-multimodal-encoder-data-parallel",
-            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
+            action="store_true",
+            deprecated=True)
 
         # KV cache arguments
         cache_kwargs = get_kwargs(CacheConfig)
@@ -735,6 +735,8 @@ class EngineArgs:
         multimodal_group.add_argument("--disable-mm-preprocessor-cache",
                                       action="store_true",
                                       deprecated=True)
+        multimodal_group.add_argument(
+            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
         multimodal_group.add_argument(
             "--interleave-mm-strings",
             **multimodal_kwargs["interleave_mm_strings"])
@@ -909,6 +911,14 @@ class EngineArgs:
             self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
 
+        if self.enable_multimodal_encoder_data_parallel:
+            logger.warning(
+                "`--enable-multimodal-encoder-data-parallel` is deprecated "
+                "and will be removed in v0.13. "
" + "Please use `--mm-encoder-tp-mode data` instead.") + + self.mm_encoder_tp_mode = "data" + return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, @@ -947,6 +957,7 @@ class EngineArgs: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1258,8 +1269,6 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, - enable_multimodal_encoder_data_parallel=self. - enable_multimodal_encoder_data_parallel, ) if model_config.is_multimodal_model: diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 35103eac8fb5..595bdd17cf2c 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -728,8 +728,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.config = config self.quant_config = quant_config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 34eec10296b5..811ecffcc1e4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -877,8 +877,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5d41a9e569f5..f8877b584b19 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -882,8 +882,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = (vllm_config.parallel_config. - enable_multimodal_encoder_data_parallel) + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" if multimodal_config.get_limit_per_prompt("image"): self.vision_model = Step3VisionTransformer(