From b38bc652ac5111d96cfd41e3575a879e9b47efbd Mon Sep 17 00:00:00 2001
From: Jason Gu <1057337859@qq.com>
Date: Fri, 25 Jul 2025 13:45:16 +0800
Subject: [PATCH] [Model] Support tensor parallel for timm ViT in Deepseek_vl2 (#21494)

Signed-off-by: wzqd <1057337859@qq.com>
---
 vllm/model_executor/models/deepseek_vl2.py | 40 ++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index a222c4cbe9d0..0ca6b28073ec 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -14,9 +14,11 @@ from einops import rearrange, repeat
 from transformers import BatchFeature
 
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
@@ -379,6 +381,37 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
+        """Return (parent_module, final_attr_name) for a dotted module path."""
+        names = dotted_name.split('.')
+        parent = root
+        for n in names[:-1]:
+            parent = getattr(parent, n)
+        return parent, names[-1]
+
+    #patch for timm ViT instance to support tensor parallel
+    def patch_vit_for_tp(self, vit: torch.nn.Module,
+                         quant_config: QuantizationConfig):
+        try:
+            import timm
+        except ImportError as e:
+            raise ImportError("Please install timm") from e
+
+        for name, module in vit.named_modules():
+            if isinstance(module, nn.Linear):
+                parent, attr_name = self._get_parent_and_attr(vit, name)
+                if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1":
+                    new_linear = replace_linear_class(module, "colwise",
+                                                      quant_config)
+                    setattr(parent, attr_name, new_linear)
+                elif isinstance(parent,
+                                timm.layers.Mlp) and attr_name == "fc2":
+                    new_linear = replace_linear_class(module, "rowwise",
+                                                      quant_config)
+                    setattr(parent, attr_name, new_linear)
+
+        return vit
+
     def _init_vision_module(
         self,
         vision_config: VisionEncoderConfig,
@@ -388,8 +421,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # TODO: refactor vision model through timm wrapper from transformers
         try:
             import timm
-        except ImportError:
-            raise ImportError("Please install timm") from ImportError
+        except ImportError as e:
+            raise ImportError("Please install timm") from e
 
         with set_default_torch_dtype(torch.float16):
             model = timm.create_model(
@@ -400,6 +433,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 dynamic_img_pad=True,
             )
 
+        if get_tensor_model_parallel_world_size() > 1:
+            model = self.patch_vit_for_tp(model, quant_config)
+
         model = model.to(dtype=torch.get_default_dtype())
         return model
 