[MM][Bugfix] Replace PatchEmbed's Conv3d with Linear layer (#27418)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <hey@rogerw.io>
Committed by Isotr0py on 2025-10-24 15:32:47 +08:00
parent 88d3141ec6
commit 42efe609ba
6 changed files with 97 additions and 42 deletions

vllm/model_executor/models/glm4v.py

@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +99,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)
@@ -478,18 +483,15 @@ class Glm4vVisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
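The same substitution repeats in each vision tower below. It works because the patch-embed Conv3d uses kernel_size == stride: every output feature is one dot product over a single flattened patch, which is exactly a linear layer. ReplicatedLinear with return_bias=False returns only the output tensor (not an (output, bias) tuple), which is why forward() reduces to x = self.proj(x). A minimal equivalence sketch with illustrative sizes, using plain nn.Linear as a stand-in for vLLM's ReplicatedLinear:

import math

import torch
import torch.nn as nn

# Illustrative sizes only, not a real model config.
in_channels, t, p, hidden, n_patches = 3, 2, 14, 64, 5
kernel = (t, p, p)

conv = nn.Conv3d(in_channels, hidden, kernel_size=kernel, stride=kernel, bias=True)
linear = nn.Linear(in_channels * math.prod(kernel), hidden, bias=True)
with torch.no_grad():
    # Each Linear row is one output channel's (in, kT, kH, kW) kernel, flattened.
    linear.weight.copy_(conv.weight.reshape(hidden, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(n_patches, in_channels * math.prod(kernel))  # flattened patches
old = conv(x.view(n_patches, in_channels, t, p, p)).view(n_patches, hidden)
new = linear(x)
assert torch.allclose(old, new, atol=1e-4)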
@@ -887,6 +889,9 @@ class Glm4vVisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen2_5_vl.py

@@ -26,6 +26,7 @@
 # limitations under the License.
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -56,6 +57,7 @@ from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +100,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)
@@ -532,18 +538,15 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -950,6 +953,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen2_vl.py

@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -53,7 +54,11 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
     dispatch_rotary_emb_function,
@@ -100,7 +105,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)
@@ -561,18 +570,15 @@ class Qwen2VisionPatchEmbed(nn.Module):
         self.embed_dim = embed_dim

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x)
         return x
@@ -835,6 +841,9 @@ class Qwen2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen3_omni_moe_thinker.py

@@ -22,6 +22,7 @@
 # limitations under the License.
 """Inference-only Qwen3-Omni-Moe model (thinker part)."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Any
@@ -53,7 +54,11 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -98,7 +103,11 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
+from .vision import (
+    conv3d_to_linear_weight,
+    get_llm_pos_ids_for_vision,
+    get_vit_attn_backend,
+)

 try:
     import flash_attn
@@ -131,18 +140,16 @@ class Qwen3_VisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -559,6 +566,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen3_vl.py

@@ -24,6 +24,7 @@
 # limitations under the License.
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from itertools import islice
@@ -56,7 +57,11 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -107,7 +112,11 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)
@@ -129,18 +138,15 @@ class Qwen3_VisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -576,6 +582,9 @@ class Qwen3_VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/vision.py

@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision(
     llm_pos_ids_list.append(_llm_pos_ids + start_idx)
     llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
     return llm_pos_ids
+
+
+# Due to a performance regression with Conv3D in PyTorch 2.9, we reshape
+# Conv3D weights to Linear weights for better performance.
+# See: https://github.com/vllm-project/vllm/issues/27406
+# and https://github.com/pytorch/pytorch/issues/166122
+# FIXME(Isotr0py): Revert the PR that introduced this workaround
+# (https://github.com/vllm-project/vllm/pull/27418)
+# once the performance issue is resolved in PyTorch.
+def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
+    """
+    Reshape a Conv3D weight to a Linear weight. Only valid when kernel_size == stride.
+    """
+    out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
+    linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
+    return linear_weight