[Model] Support dp on ViT on GLM-4.5V (#23168)

Signed-off-by: David Chen <530634352@qq.com>

commit 2f0bab3f26 (parent fad73be1a5)
@@ -174,6 +174,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u

 Known supported models:

+- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
 - Kimi-VL (<gh-pr:23817>)
 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
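Note: the doc line above refers to the `mm_encoder_tp_mode="data"` engine argument. A minimal offline-inference sketch of how it might be passed, assuming it is forwarded through `LLM(...)` like other engine arguments; the model id and tensor_parallel_size below are illustrative, not part of this diff:

    from vllm import LLM

    # Keep the language model tensor-parallel, but run the GLM-4.5V ViT
    # data-parallel across those ranks instead of sharding it.
    llm = LLM(
        model="zai-org/GLM-4.5V",
        tensor_parallel_size=4,
        mm_encoder_tp_mode="data",
    )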
@@ -45,15 +45,20 @@ from transformers.models.glm4v.video_processing_glm4v import (
 from transformers.video_utils import VideoMetadata

 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              parallel_state)
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -66,6 +71,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -153,7 +159,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):

 Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]

-# === Vision Encoder === #
+# ==== Vision Encoder ==== #


 class Glm4vVisionMLP(nn.Module):
@@ -165,19 +171,23 @@ class Glm4vVisionMLP(nn.Module):
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=in_features,
-            output_sizes=[hidden_features] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(input_size=in_features,
+                                        output_sizes=[hidden_features] * 2,
+                                        bias=bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.gate_up_proj")
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(hidden_features,
+                                  in_features,
+                                  bias=bias,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()

     def forward(self, x: torch.Tensor):
@@ -218,33 +228,54 @@ class Glm4vVisionAttention(nn.Module):
         projection_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = (1 if use_data_parallel else
+                        get_tensor_model_parallel_world_size())
         self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)

-        self.qkv = QKVParallelLinear(
-            hidden_size=embed_dim,
-            head_size=self.hidden_size_per_attention_head,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_heads,
-            bias=False,
-            quant_config=quant_config,
-            # Change qkv prefix to align with GLM-4.5V-FP8 quantization config
-            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
-        )
-        self.proj = RowParallelLinear(
-            input_size=projection_size,
-            output_size=embed_dim,
-            quant_config=quant_config,
-            prefix=f"{prefix}.proj",
-            bias=False,
-        )
+        if use_data_parallel:
+            self.qkv = ReplicatedLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = ReplicatedLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )
+        else:
+            self.qkv = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.hidden_size_per_attention_head,
+                total_num_heads=num_heads,
+                total_num_kv_heads=num_heads,
+                bias=False,
+                quant_config=quant_config,
+                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
+                prefix=f"{prefix}.qkv_proj"
+                if quant_config else f"{prefix}.qkv",
+            )
+            self.proj = RowParallelLinear(
+                input_size=projection_size,
+                output_size=embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+                bias=False,
+            )

         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
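Note: a quick sanity check of the head-partitioning arithmetic in the hunk above; a standalone sketch, not vLLM code, assuming `dist_utils.divide` is exact integer division. Forcing `tp_size` to 1 under data parallelism keeps every attention head, and the full 3 * projection_size QKV width, on each DP rank:

    num_heads = 16          # illustrative ViT config values
    projection_size = 2048

    for tp_size in (4, 1):  # 4: tensor-parallel ViT, 1: data-parallel ViT
        assert num_heads % tp_size == 0        # what dist_utils.divide enforces
        heads_per_rank = num_heads // tp_size
        head_dim = projection_size // num_heads
        qkv_out_per_rank = 3 * heads_per_rank * head_dim
        print(tp_size, heads_per_rank, qkv_out_per_rank)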
@@ -375,6 +406,7 @@ class Glm4vVisionBlock(nn.Module):
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -387,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
@@ -394,6 +427,7 @@ class Glm4vVisionBlock(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
+            use_data_parallel=use_data_parallel,
         )

     def forward(
@@ -456,24 +490,40 @@ class Glm4vPatchMerger(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        self.proj = ColumnParallelLinear(self.hidden_size,
-                                         self.hidden_size,
-                                         bias=bias,
-                                         gather_output=True,
-                                         quant_config=quant_config,
-                                         prefix=f"{prefix}.proj")
+        if use_data_parallel:
+            self.proj = ReplicatedLinear(
+                input_size=self.hidden_size,
+                output_size=self.hidden_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
+        else:
+            self.proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=bias,
+                gather_output=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.proj",
+            )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        self.gate_up_proj = MergedColumnParallelLinear(
+        cls_gate_up = (MergedReplicatedLinear
+                       if use_data_parallel else MergedColumnParallelLinear)
+        self.gate_up_proj = cls_gate_up(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
         )
-        self.down_proj = RowParallelLinear(
+        cls_down = (ReplicatedLinear
+                    if use_data_parallel else RowParallelLinear)
+        self.down_proj = cls_down(
             context_dim,
             self.hidden_size,
             bias=bias,
@@ -548,14 +598,33 @@ class Glm4vVisionEmbeddings(nn.Module):
                 dtype=torch.float32))

         # Calculate target dimensions for each patch
-        target_h = torch.cat([
-            image_shapes[i, 1].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
-        target_w = torch.cat([
-            image_shapes[i, 2].repeat(lengths[i])
-            for i in range(len(lengths))
-        ]).to(device=device, dtype=torch.float32)
+        # Add bounds checking for data parallel mode
+        if len(lengths) > image_shapes.shape[0]:
+            # In data parallel mode, some GPUs might not have all
+            # image shapes
+            # Use available image shapes, cycling if necessary
+            target_h_list = []
+            target_w_list = []
+            for i in range(len(lengths)):
+                # Cycle through available shapes
+                shape_idx = i % image_shapes.shape[0]
+                target_h_list.append(image_shapes[shape_idx,
+                                                  1].repeat(lengths[i]))
+                target_w_list.append(image_shapes[shape_idx,
+                                                  2].repeat(lengths[i]))
+            target_h = torch.cat(target_h_list).to(device=device,
+                                                   dtype=torch.float32)
+            target_w = torch.cat(target_w_list).to(device=device,
+                                                   dtype=torch.float32)
+        else:
+            target_h = torch.cat([
+                image_shapes[i, 1].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)
+            target_w = torch.cat([
+                image_shapes[i, 2].repeat(lengths[i])
+                for i in range(len(lengths))
+            ]).to(device=device, dtype=torch.float32)

         # Normalize coordinates to [-1, 1] range for grid_sample
         h_coords = h_coords.to(device=device, dtype=torch.float32)
@@ -629,6 +698,7 @@ class Glm4vVisionTransformer(nn.Module):
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()

@@ -638,6 +708,7 @@ class Glm4vVisionTransformer(nn.Module):
         depth = vision_config.depth
         self.hidden_size = vision_config.hidden_size
         self.num_heads = vision_config.num_heads
+        self.use_data_parallel = use_data_parallel

         self.patch_size = vision_config.patch_size
         self.spatial_merge_size = vision_config.spatial_merge_size
@@ -661,6 +732,7 @@ class Glm4vVisionTransformer(nn.Module):
                 norm_layer=norm_layer,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=self.use_data_parallel,
             ) for layer_idx in range(depth)
         ])
         self.merger = Glm4vPatchMerger(
@@ -669,6 +741,7 @@ class Glm4vVisionTransformer(nn.Module):
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.merger",
+            use_data_parallel=self.use_data_parallel,
         )
         self.embeddings = Glm4vVisionEmbeddings(vision_config)

@@ -731,8 +804,11 @@ class Glm4vVisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: torch.Tensor,
+        grid_thw: list[list[int]],
     ) -> torch.Tensor:
+        # Convert grid_thw to tensor (always expecting list format now)
+        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
@@ -1250,6 +1326,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         "model.visual.": "visual.",
     })

+    supports_encoder_tp_data = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -1267,12 +1345,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,

         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )

         if config.model_type == "glm4v":
@@ -1382,8 +1462,14 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values,
+                                           grid_thw=grid_thw.tolist())
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return image_embeds.split(sizes.tolist())
@@ -1393,23 +1479,22 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-
-        device = self.visual.device
-        flat_grid_thw = torch.cat([
-            torch.tensor([[1, h, w]] * t, device=device)
-            for t, h, w in grid_thw
-        ])
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos,
-                                       grid_thw=flat_grid_thw)
+            if self.use_data_parallel:
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw.tolist(),
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw.tolist())
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size

         return video_embeds.split(sizes.tolist())

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
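Note: `run_dp_sharded_mrope_vision_model` (imported above) does the actual sharding. The following is only a conceptual single-process sketch of the idea, not the vLLM implementation: split the images across DP ranks, let each rank run the full replicated ViT on its shard, then gather the per-image embeddings back in input order.

    from typing import Callable

    import torch


    def dp_shard_vision(images: list[torch.Tensor],
                        encode: Callable[[torch.Tensor], torch.Tensor],
                        dp_size: int) -> list[torch.Tensor]:
        # Assign image i to rank i % dp_size (the real helper balances load
        # across ranks; round-robin keeps the sketch short).
        shards: list[list[int]] = [[] for _ in range(dp_size)]
        for i in range(len(images)):
            shards[i % dp_size].append(i)

        # Each "rank" encodes only its shard with the full, replicated ViT.
        outputs: dict[int, torch.Tensor] = {}
        for rank in range(dp_size):
            for i in shards[rank]:
                outputs[i] = encode(images[i])

        # Gather results back in the original image order.
        return [outputs[i] for i in range(len(images))]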