From 42efe609ba75eb1b0bc06ae635778b2bc0aa4e7a Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Fri, 24 Oct 2025 15:32:47 +0800
Subject: [PATCH] [MM][Bugfix] Replace `PatchEmbed`'s conv3d with linear layer
 (#27418)

Signed-off-by: Isotr0py
Co-authored-by: Roger Wang
---
 vllm/model_executor/models/glm4_1v.py        | 21 +++++++++------
 vllm/model_executor/models/qwen2_5_vl.py     | 22 +++++++++++++------
 vllm/model_executor/models/qwen2_vl.py       | 27 ++++++++++++-------
 .../models/qwen3_omni_moe_thinker.py         | 27 ++++++++++++-------
 vllm/model_executor/models/qwen3_vl.py       | 27 ++++++++++++-------
 vllm/model_executor/models/vision.py         | 16 +++++++++++
 6 files changed, 97 insertions(+), 43 deletions(-)

diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 85f489837701..9f1439e21ef7 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +99,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
 
@@ -478,18 +483,15 @@ class Glm4vVisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
 
 
@@ -887,6 +889,9 @@ class Glm4vVisionTransformer(nn.Module):
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 1b3ce3edd47b..c657b06d4355 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -26,6 +26,7 @@
 # limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" +import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import lru_cache, partial from typing import Annotated, Any, Literal, TypeAlias @@ -56,6 +57,7 @@ from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -98,7 +100,11 @@ from .utils import ( init_vllm_registered_model, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -532,18 +538,15 @@ class Qwen2_5_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), hidden_size, - kernel_size=kernel_size, - stride=kernel_size, bias=False, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x @@ -950,6 +953,9 @@ class Qwen2_5_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 94436fe009f1..61f7970d56f6 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,6 +25,7 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias @@ -53,7 +54,11 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.common import ( dispatch_rotary_emb_function, @@ -100,7 +105,11 @@ from .utils import ( init_vllm_registered_model, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -561,18 +570,15 @@ class Qwen2VisionPatchEmbed(nn.Module): self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), embed_dim, - kernel_size=kernel_size, - stride=kernel_size, bias=False, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.embed_dim) + x = self.proj(x) return x @@ -835,6 +841,9 @@ class Qwen2VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index cdef1bdaedc5..3485adea6ac8 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -22,6 +22,7 @@ # limitations under the License. 
"""Inference-only Qwen3-Omni-Moe model (thinker part).""" +import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Any @@ -53,7 +54,11 @@ from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -98,7 +103,11 @@ from .utils import ( _merge_multimodal_embeddings, maybe_prefix, ) -from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend +from .vision import ( + conv3d_to_linear_weight, + get_llm_pos_ids_for_vision, + get_vit_attn_backend, +) try: import flash_attn @@ -131,18 +140,16 @@ class Qwen3_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), hidden_size, - kernel_size=kernel_size, - stride=kernel_size, bias=True, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x @@ -559,6 +566,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index e9e16762e525..fb2af187ebf1 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -24,6 +24,7 @@ # limitations under the License. 
"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" +import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from itertools import islice @@ -56,7 +57,11 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -107,7 +112,11 @@ from .utils import ( _merge_multimodal_embeddings, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -129,18 +138,15 @@ class Qwen3_VisionPatchEmbed(nn.Module): self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), hidden_size, - kernel_size=kernel_size, - stride=kernel_size, bias=True, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x @@ -576,6 +582,9 @@ class Qwen3_VisionTransformer(nn.Module): loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 8bbb06f72772..b5f6c60514c0 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision( llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) return llm_pos_ids + + +# Due to a performance regression with Conv3D in PyTorch2.9, we reshape +# Conv3D weights to Linear weights for better performance. +# See: https://github.com/vllm-project/vllm/issues/27406 +# and https://github.com/pytorch/pytorch/issues/166122 +# FIXME(Isotr0py): Revert the PR introduces this workaround +# (https://github.com/vllm-project/vllm/pull/27418), +# once the performance issue is resolved in PyTorch. +def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor: + """ + Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride. + """ + out_channels, in_channels, kt, kh, kw = conv3d_weight.shape + linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw) + return linear_weight