mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Model] Support tensor parallel for timm ViT in Deepseek_vl2 (#21494)
Signed-off-by: wzqd <1057337859@qq.com>
This commit is contained in:
parent
adaf2c6d4f
commit
b38bc652ac
@ -14,9 +14,11 @@ from einops import rearrange, repeat
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.models.transformers import replace_linear_class
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
@ -379,6 +381,37 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
|
||||
"""Return (parent_module, final_attr_name) for a dotted module path."""
|
||||
names = dotted_name.split('.')
|
||||
parent = root
|
||||
for n in names[:-1]:
|
||||
parent = getattr(parent, n)
|
||||
return parent, names[-1]
|
||||
|
||||
#patch for timm ViT instance to support tensor parallel
|
||||
def patch_vit_for_tp(self, vit: torch.nn.Module,
|
||||
quant_config: QuantizationConfig):
|
||||
try:
|
||||
import timm
|
||||
except ImportError as e:
|
||||
raise ImportError("Please install timm") from e
|
||||
|
||||
for name, module in vit.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
parent, attr_name = self._get_parent_and_attr(vit, name)
|
||||
if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1":
|
||||
new_linear = replace_linear_class(module, "colwise",
|
||||
quant_config)
|
||||
setattr(parent, attr_name, new_linear)
|
||||
elif isinstance(parent,
|
||||
timm.layers.Mlp) and attr_name == "fc2":
|
||||
new_linear = replace_linear_class(module, "rowwise",
|
||||
quant_config)
|
||||
setattr(parent, attr_name, new_linear)
|
||||
|
||||
return vit
|
||||
|
||||
def _init_vision_module(
|
||||
self,
|
||||
vision_config: VisionEncoderConfig,
|
||||
@ -388,8 +421,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# TODO: refactor vision model through timm wrapper from transformers
|
||||
try:
|
||||
import timm
|
||||
except ImportError:
|
||||
raise ImportError("Please install timm") from ImportError
|
||||
except ImportError as e:
|
||||
raise ImportError("Please install timm") from e
|
||||
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
model = timm.create_model(
|
||||
@ -400,6 +433,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
dynamic_img_pad=True,
|
||||
)
|
||||
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
model = self.patch_vit_for_tp(model, quant_config)
|
||||
|
||||
model = model.to(dtype=torch.get_default_dtype())
|
||||
return model
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user