[Model] Support tensor parallel for timm ViT in Deepseek_vl2 (#21494)
Signed-off-by: wzqd <1057337859@qq.com>
This commit is contained in: parent adaf2c6d4f, commit b38bc652ac
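The idea behind the change: Megatron-style tensor parallelism shards each ViT MLP by splitting fc1 column-wise (each rank keeps a slice of the intermediate activations) and fc2 row-wise (each rank produces a partial output that is summed across ranks by an all-reduce). This decomposition works because the activation between fc1 and fc2 is elementwise. Below is a minimal single-process sketch of that math using plain tensors rather than vLLM's parallel layers; the shapes and the `tp` world size are illustrative only.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2                        # hypothetical tensor-parallel world size
x = torch.randn(4, 8)         # (tokens, hidden)
w1 = torch.randn(8, 32)       # fc1 weight, hidden -> mlp
w2 = torch.randn(32, 8)       # fc2 weight, mlp -> hidden

# Dense reference MLP: y = gelu(x @ w1) @ w2
ref = F.gelu(x @ w1) @ w2

# "colwise" shards w1 along its output dim, "rowwise" shards w2 along its
# input dim; summing the per-rank partials (the all-reduce) recovers `ref`.
partials = []
for rank in range(tp):
    a = w1.chunk(tp, dim=1)[rank]   # (8, 16) column shard
    b = w2.chunk(tp, dim=0)[rank]   # (16, 8) row shard
    partials.append(F.gelu(x @ a) @ b)

assert torch.allclose(sum(partials), ref, atol=1e-4)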
@@ -14,9 +14,11 @@ from einops import rearrange, repeat
 from transformers import BatchFeature
 
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
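The two added imports carry the feature: get_tensor_model_parallel_world_size() reports the TP degree of the current process group, and replace_linear_class(...) swaps a plain nn.Linear for one of vLLM's parallel linear layers. Roughly, and only as a sketch (the real helper lives in vllm/model_executor/models/transformers.py and handles more details; the constructor kwargs below are an approximation), it behaves like:

import torch.nn as nn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

def replace_linear_class_sketch(linear: nn.Linear, style: str, quant_config):
    # "colwise" shards the output features across TP ranks; "rowwise" shards
    # the input features and all-reduces the partial outputs afterwards.
    cls = {"colwise": ColumnParallelLinear, "rowwise": RowParallelLinear}[style]
    return cls(input_size=linear.in_features,
               output_size=linear.out_features,
               bias=linear.bias is not None,
               quant_config=quant_config)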
@@ -379,6 +381,37 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str):
+        """Return (parent_module, final_attr_name) for a dotted module path."""
+        names = dotted_name.split('.')
+        parent = root
+        for n in names[:-1]:
+            parent = getattr(parent, n)
+        return parent, names[-1]
+
+    # Patch a timm ViT instance to support tensor parallelism.
+    def patch_vit_for_tp(self, vit: torch.nn.Module,
+                         quant_config: QuantizationConfig):
+        try:
+            import timm
+        except ImportError as e:
+            raise ImportError("Please install timm") from e
+
+        for name, module in vit.named_modules():
+            if isinstance(module, nn.Linear):
+                parent, attr_name = self._get_parent_and_attr(vit, name)
+                if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1":
+                    new_linear = replace_linear_class(module, "colwise",
+                                                      quant_config)
+                    setattr(parent, attr_name, new_linear)
+                elif isinstance(parent,
+                                timm.layers.Mlp) and attr_name == "fc2":
+                    new_linear = replace_linear_class(module, "rowwise",
+                                                      quant_config)
+                    setattr(parent, attr_name, new_linear)
+
+        return vit
+
     def _init_vision_module(
         self,
         vision_config: VisionEncoderConfig,
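A quick illustration of how _get_parent_and_attr resolves the dotted names produced by named_modules(). The toy module tree below stands in for a real timm ViT (which uses names like "blocks.0.mlp.fc1"); note that getattr with the string "0" works on nn.ModuleList, which is what makes the plain dotted walk sufficient:

import torch.nn as nn

# Toy stand-in for a ViT: "blocks.0.mlp.fc1" names a Linear two levels deep.
vit = nn.Module()
vit.blocks = nn.ModuleList([nn.Module()])
vit.blocks[0].mlp = nn.Module()
vit.blocks[0].mlp.fc1 = nn.Linear(8, 32)

names = "blocks.0.mlp.fc1".split('.')
parent = vit
for n in names[:-1]:
    parent = getattr(parent, n)      # getattr(module_list, "0") -> blocks[0]

assert parent is vit.blocks[0].mlp and names[-1] == "fc1"

# setattr on the resolved parent is how patch_vit_for_tp swaps layers in place.
setattr(parent, names[-1], nn.Linear(8, 32))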
@@ -388,8 +421,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # TODO: refactor vision model through timm wrapper from transformers
         try:
             import timm
-        except ImportError:
-            raise ImportError("Please install timm") from ImportError
+        except ImportError as e:
+            raise ImportError("Please install timm") from e
 
         with set_default_torch_dtype(torch.float16):
             model = timm.create_model(
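The only change in this hunk is exception chaining: `from ImportError` attached the bare exception class as the cause, discarding the caught instance and its traceback, while `from e` preserves both. A minimal demonstration of the corrected pattern:

try:
    import timm
except ImportError as e:
    # "from e" chains the caught instance, so the original message
    # ("No module named 'timm'") and its traceback survive in the report;
    # "from ImportError" would chain the class object and lose that context.
    raise ImportError("Please install timm") from e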
@@ -400,6 +433,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             dynamic_img_pad=True,
         )
 
+        if get_tensor_model_parallel_world_size() > 1:
+            model = self.patch_vit_for_tp(model, quant_config)
+
         model = model.to(dtype=torch.get_default_dtype())
         return model
 
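For context on the surrounding lines: set_default_torch_dtype makes timm allocate the ViT weights directly in float16, and the final `model.to(...)` casts them back to the runtime default dtype. vLLM's helper lives in vllm.model_executor.model_loader.utils; a minimal re-implementation of the pattern, for illustration only, looks like this:

from contextlib import contextmanager

import torch

@contextmanager
def set_default_torch_dtype_sketch(dtype: torch.dtype):
    # Temporarily switch the dtype used for newly created parameters,
    # restoring the old default even if model construction raises.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)

with set_default_torch_dtype_sketch(torch.float16):
    layer = torch.nn.Linear(8, 8)
assert layer.weight.dtype == torch.float16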