[MM][Bugfix] Replace PatchEmbed's conv3d to linear layer (#27418)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <hey@rogerw.io>
Isotr0py 2025-10-24 15:32:47 +08:00 committed by GitHub
parent 88d3141ec6
commit 42efe609ba
6 changed files with 97 additions and 42 deletions


@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +99,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
@@ -478,18 +483,15 @@ class Glm4vVisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -887,6 +889,9 @@ class Glm4vVisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue


@@ -26,6 +26,7 @@
 # limitations under the License.
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -56,6 +57,7 @@ from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +100,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
@@ -532,18 +538,15 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -950,6 +953,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue


@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -53,7 +54,11 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
     dispatch_rotary_emb_function,
@@ -100,7 +105,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
@@ -561,18 +570,15 @@ class Qwen2VisionPatchEmbed(nn.Module):
         self.embed_dim = embed_dim
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x)
         return x
@@ -835,6 +841,9 @@ class Qwen2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue


@@ -22,6 +22,7 @@
 # limitations under the License.
 """Inference-only Qwen3-Omni-Moe model (thinker part)."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Any
@@ -53,7 +54,11 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -98,7 +103,11 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
+from .vision import (
+    conv3d_to_linear_weight,
+    get_llm_pos_ids_for_vision,
+    get_vit_attn_backend,
+)
 
 try:
     import flash_attn
@@ -131,18 +140,16 @@ class Qwen3_VisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -559,6 +566,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue


@@ -24,6 +24,7 @@
 # limitations under the License.
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from itertools import islice
@@ -56,7 +57,11 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -107,7 +112,11 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
@@ -129,18 +138,15 @@ class Qwen3_VisionPatchEmbed(nn.Module):
         self.hidden_size = hidden_size
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x
@@ -576,6 +582,9 @@ class Qwen3_VisionTransformer(nn.Module):
        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue


@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision(
         llm_pos_ids_list.append(_llm_pos_ids + start_idx)
     llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
     return llm_pos_ids
+
+
+# Due to a performance regression with Conv3D in PyTorch 2.9, we reshape
+# Conv3D weights to Linear weights for better performance.
+# See: https://github.com/vllm-project/vllm/issues/27406
+# and https://github.com/pytorch/pytorch/issues/166122
+# FIXME(Isotr0py): Revert the PR that introduced this workaround
+# (https://github.com/vllm-project/vllm/pull/27418),
+# once the performance issue is resolved in PyTorch.
+def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
+    """
+    Reshape a Conv3D weight into a Linear weight. Only valid when
+    kernel_size == stride.
+    """
+    out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
+    linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
+    return linear_weight
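
For context on why this rewrite is safe: a Conv3d whose stride equals its kernel size applies an independent linear map to each non-overlapping patch, and the old forward's view(L, -1, temporal_patch_size, patch_size, patch_size) shows the patch embed input is already the flattened patch, so multiplying by the reshaped weight yields the same output. Below is a minimal standalone sketch of that equivalence; the tensor names and sizes are hypothetical and not taken from this PR.

import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration (not from any model config).
in_channels, hidden_size = 3, 16
t, p = 2, 14  # temporal_patch_size, patch_size
num_patches = 5

conv_weight = torch.randn(hidden_size, in_channels, t, p, p, dtype=torch.float64)
patches = torch.randn(num_patches, in_channels, t, p, p, dtype=torch.float64)

# Conv3d with kernel_size == stride: one output value per patch and channel.
conv_out = F.conv3d(patches, conv_weight, stride=(t, p, p))
conv_out = conv_out.view(num_patches, hidden_size)

# Same computation as a single matmul on flattened patches, using the reshape
# performed by conv3d_to_linear_weight above.
linear_weight = conv_weight.reshape(hidden_size, in_channels * t * p * p)
linear_out = patches.reshape(num_patches, -1) @ linear_weight.t()

torch.testing.assert_close(conv_out, linear_out)

Because both flattenings follow the same (channels, temporal, height, width) order, the reshaped weight can be loaded directly into the new ReplicatedLinear, which is what the patch_embed.proj.weight hook in each load_weights above does.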