[Model] Support dp on ViT on GLM-4.5V (#23168)

Signed-off-by: David Chen <530634352@qq.com>
WeiQing Chen 2025-09-02 18:48:18 +08:00 committed by GitHub
parent fad73be1a5
commit 2f0bab3f26
2 changed files with 145 additions and 59 deletions


@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
 Known supported models:
 
+- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
 - Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
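
Illustrative usage sketch (the model id and GPU count are assumptions, not taken from this diff): the feature is enabled purely through the engine argument named in the docs above.

    from vllm import LLM

    # Keep the language model tensor-parallel across 4 GPUs, but run the
    # GLM-4.5V vision encoder data-parallel (a full ViT replica per GPU).
    llm = LLM(
        model="zai-org/GLM-4.5V",   # assumed HF model id
        tensor_parallel_size=4,
        mm_encoder_tp_mode="data",
    )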


@@ -45,15 +45,20 @@ from transformers.models.glm4v.video_processing_glm4v import (
 from transformers.video_utils import VideoMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              parallel_state)
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -66,6 +71,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
 Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
 
 
-# === Vision Encoder === #
+# ==== Vision Encoder ==== #
 
 
 class Glm4vVisionMLP(nn.Module):
@@ -165,19 +171,23 @@ class Glm4vVisionMLP(nn.Module):
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=in_features,
-            output_sizes=[hidden_features] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(input_size=in_features,
+                                        output_sizes=[hidden_features] * 2,
+                                        bias=bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.gate_up_proj")
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(hidden_features,
+                                  in_features,
+                                  bias=bias,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
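
For reference, a minimal standalone sketch of the gated-SiLU dataflow this MLP implements, written with plain torch.nn.Linear instead of vLLM's parallel/replicated linear layers (the sizes below are made up for illustration):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    in_features, hidden_features = 1536, 4096      # illustrative sizes
    gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=False)
    down_proj = nn.Linear(hidden_features, in_features, bias=False)

    x = torch.randn(4, in_features)
    gate, up = gate_up_proj(x).chunk(2, dim=-1)    # layout SiluAndMul expects
    out = down_proj(F.silu(gate) * up)             # silu(gate) * up, then project
    print(out.shape)                               # torch.Size([4, 1536])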
@@ -218,33 +228,54 @@ class Glm4vVisionAttention(nn.Module):
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)
-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=False,
-            quant_config=quant_config,
-            # Change qkv prefix to align with GLM-4.5V-FP8 quantization config
-            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
-        )
-        self.proj = RowParallelLinear(
-            input_size=projection_size,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            bias=False,
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = ReplicatedLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = RowParallelLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
 
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
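
The only arithmetic that changes under data parallel is the head partitioning: forcing tp_size to 1 means every rank keeps all attention heads. A standalone illustration (divide is re-implemented locally for the example; the head counts are hypothetical):

    def divide(numerator: int, denominator: int) -> int:
        # Same contract as vLLM's dist_utils.divide: exact division only.
        assert numerator % denominator == 0, "numerator must be divisible"
        return numerator // denominator

    num_heads, projection_size = 16, 1536          # hypothetical ViT sizes
    head_dim = divide(projection_size, num_heads)  # 96, independent of tp_size
    print(divide(num_heads, 1))                    # data parallel: 16 heads per rank
    print(divide(num_heads, 4))                    # tensor parallel (tp=4): 4 heads per rank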
@@ -375,6 +406,7 @@ class Glm4vVisionBlock(nn.Module):
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -387,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
@@ -394,6 +427,7 @@ class Glm4vVisionBlock(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
         )
 
     def forward(
@@ -456,24 +490,40 @@ class Glm4vPatchMerger(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(self.hidden_size,
-                                         self.hidden_size,
-                                         bias=bias,
-                                         gather_output=True,
-                                         quant_config=quant_config,
-                                         prefix=f"{prefix}.proj")
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                input_size=self.hidden_size,
+                output_size=self.hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
         )
-        self.down_proj = RowParallelLinear(
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(
             context_dim,
             self.hidden_size,
             bias=bias,
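
Swapping ColumnParallelLinear(gather_output=True) for ReplicatedLinear is numerically equivalent here: with gather_output the tensor-parallel layer already returns the full, unsharded projection, which is exactly what a replicated layer computes locally. A single-process sanity check of that equivalence (toy shapes, plain tensors):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 8)
    weight = torch.randn(8, 8)                    # full (in, out) weight

    full = x @ weight                             # replicated layer's output
    shards = weight.chunk(4, dim=1)               # column-shard across 4 "ranks"
    gathered = torch.cat([x @ s for s in shards], dim=1)   # all-gather of partials
    print(torch.allclose(full, gathered))         # True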
@@ -548,14 +598,33 @@ class Glm4vVisionEmbeddings(nn.Module):
                 dtype=torch.float32))
 
         # Calculate target dimensions for each patch
-        target_h = torch.cat([
-            image_shapes[i, 1].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
-        target_w = torch.cat([
-            image_shapes[i, 2].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
+        # Add bounds checking for data parallel mode
+        if len(lengths) > image_shapes.shape[0]:
+            # In data parallel mode, some GPUs might not have all
+            # image shapes
+            # Use available image shapes, cycling if necessary
+            target_h_list = []
+            target_w_list = []
+            for i in range(len(lengths)):
+                # Cycle through available shapes
+                shape_idx = i % image_shapes.shape[0]
+                target_h_list.append(image_shapes[shape_idx,
+                                                  1].repeat(lengths[i]))
+                target_w_list.append(image_shapes[shape_idx,
+                                                  2].repeat(lengths[i]))
+            target_h = torch.cat(target_h_list).to(device=device,
+                                                   dtype=torch.float32)
+            target_w = torch.cat(target_w_list).to(device=device,
+                                                   dtype=torch.float32)
+        else:
+            target_h = torch.cat([
+                image_shapes[i, 1].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
+            target_w = torch.cat([
+                image_shapes[i, 2].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
 
         # Normalize coordinates to [-1, 1] range for grid_sample
         h_coords = h_coords.to(device=device, dtype=torch.float32)
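
A toy trace of the shape-cycling fallback added above: with three patch groups but only two (t, h, w) rows available on a DP rank, shapes are reused round-robin (indices 0, 1, 0). The values are arbitrary:

    import torch

    image_shapes = torch.tensor([[1, 32, 48], [1, 16, 24]])   # rows of (t, h, w)
    lengths = [4, 2, 3]                                        # patches per group

    for i in range(len(lengths)):
        shape_idx = i % image_shapes.shape[0]
        h = image_shapes[shape_idx, 1].repeat(lengths[i])
        w = image_shapes[shape_idx, 2].repeat(lengths[i])
        print(i, shape_idx, h.tolist(), w.tolist())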
@@ -629,6 +698,7 @@ class Glm4vVisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -638,6 +708,7 @@ class Glm4vVisionTransformer(nn.Module):
         depth = vision_config.depth
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
+        self.use_data_parallel = use_data_parallel
 
         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -661,6 +732,7 @@ class Glm4vVisionTransformer(nn.Module):
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=self.use_data_parallel,
             ) for layer_idx in range(depth)
         ])
         self.merger = Glm4vPatchMerger(
@@ -669,6 +741,7 @@ class Glm4vVisionTransformer(nn.Module):
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.merger",
+            use_data_parallel=self.use_data_parallel,
         )
 
         self.embeddings = Glm4vVisionEmbeddings(vision_config)
@ -731,8 +804,11 @@ class Glm4vVisionTransformer(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
grid_thw: torch.Tensor, grid_thw: list[list[int]],
) -> torch.Tensor: ) -> torch.Tensor:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
# patchify # patchify
x = x.to(device=self.device, dtype=self.dtype) x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x) x = self.patch_embed(x)
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.visual.": "visual.",
         })
 
+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1267,12 +1345,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         if config.model_type == "glm4v":
@@ -1382,8 +1462,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw.tolist())
 
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return image_embeds.split(sizes.tolist())
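
Worked example of the split sizes computed above: each image contributes t*h*w patch embeddings, collapsed by spatial_merge_size in both spatial dimensions (the grid values below are illustrative):

    import torch

    grid_thw = torch.tensor([[1, 32, 48], [1, 16, 24]])   # one (t, h, w) row per image
    merge_size = 2                                         # spatial_merge_size
    sizes = grid_thw.prod(-1) // merge_size // merge_size
    print(sizes.tolist())                                  # [384, 96]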
@@ -1393,23 +1479,22 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
 
-        device = self.visual.device
-        flat_grid_thw = torch.cat([
-            torch.tensor([[1, h, w]] * t, device=device)
-            for t, h, w in grid_thw
-        ])
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos,
-                                       grid_thw=flat_grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw.tolist())
 
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
 
         return video_embeds.split(sizes.tolist())
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: