[Model][MM] Extract conv layer as CustomOp (#28455)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Shanshan Shen 2025-11-14 19:16:13 +08:00 committed by GitHub
parent 360bd8762f
commit 41b92f7d38
8 changed files with 277 additions and 66 deletions

View File

@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Conv Layer Class."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import is_torch_equal


class ConvLayerBase(CustomOp):
    """Conv layer base class."""

    num_dim: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, ...],
        stride: int | tuple[int, ...] = 1,
        padding: int | tuple[int, ...] = 0,
        dilation: int | tuple[int, ...] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        *,
        params_dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        kernel_size = (
            (kernel_size,) * self.num_dim
            if isinstance(kernel_size, int)
            else kernel_size
        )
        stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
        padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
        dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.enable_linear = (
            (self.kernel_size == self.stride)
            and not any(self.padding)
            and self.groups == 1
        )
        self.input_size = in_channels * math.prod(self.kernel_size)

        self.weight = nn.Parameter(
            torch.empty(
                out_channels,
                in_channels // groups,
                *kernel_size,
                dtype=params_dtype,
            ),
        )
        if bias:
            self.bias = nn.Parameter(
                torch.empty(self.out_channels, dtype=params_dtype)
            )
        else:
            self.register_parameter("bias", None)

    def extra_repr(self) -> str:
        s = f"in_channels={self.in_channels}, "
        s += f"out_channels={self.out_channels}, "
        s += f"kernel_size={self.kernel_size}, "
        s += f"stride={self.stride}, "
        s += f"padding={self.padding}, "
        s += f"bias={self.bias is not None}"
        return s
@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
    """Conv layer with Conv2d."""

    num_dim = 2

    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 4
        B, C, H, W = x.shape
        K1, K2 = self.kernel_size
        H, W = H // K1, W // K2

        x = x.unfold(2, K1, K1).unfold(3, K2, K2)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
        x = F.linear(
            x,
            self.weight.view(self.out_channels, self.input_size),
            self.bias,
        )
        x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
        return x

    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 4
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return x

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Expected input shape: (batch_size, in_channels, height, width)"""
        assert x.dim() == 4
        if self.enable_linear:
            return self._forward_mulmat(x)
        else:
            return self._forward_conv(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # By default, use cuDNN's optimized convolution ops.
        return self._forward_conv(x)
class CausalConv2dLayer(Conv2dLayer):
    """
    A causal version of nn.Conv2d in which each location in the 2D output
    has no access to locations to its right or below.

    All arguments are the same as nn.Conv2d, except padding, which must be
    left as None.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding: int | None = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        *,
        params_dtype: torch.dtype | None = None,
    ) -> None:
        if padding is not None:
            raise ValueError(
                "Argument padding should be set to None for CausalConv2dLayer."
            )
        self._left_padding: int = kernel_size - 1
        self._right_padding: int = stride - 1

        padding = 0
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            params_dtype=params_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
        x = super().forward(x)
        return x
@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
    """Conv layer with Conv3d."""

    num_dim = 3

    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 5
        B, C, T, H, W = x.shape
        K1, K2, K3 = self.kernel_size
        T, H, W = T // K1, H // K2, W // K3

        x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)
        x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)
        x = F.linear(
            x,
            self.weight.view(self.out_channels, self.input_size),
            self.bias,
        )
        x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3)
        return x

    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 5
        x = F.conv3d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return x

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Expected input shape: (batch_size, in_channels, time, height, width)"""
        if self.enable_linear:
            return self._forward_mulmat(x)
        else:
            return self._forward_conv(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # By default, use cuDNN's optimized convolution ops. However,
        # PyTorch 2.9.0 disabled cuDNN's Conv3d, which caused a
        # significant performance regression.
        # See: https://github.com/vllm-project/vllm/issues/27406
        # and https://github.com/pytorch/pytorch/issues/166122
        if self.enable_linear and is_torch_equal("2.9.0"):
            return self._forward_mulmat(x)
        return self._forward_conv(x)
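
Editor's note: the `enable_linear` fast path above relies on the fact that a convolution whose kernel_size equals its stride, with no padding and groups == 1, visits each input patch exactly once, so it reduces to flattening non-overlapping patches and applying a single F.linear. A minimal standalone sketch of that equivalence for the 2-D case (shapes are illustrative, not taken from the diff):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, C_out, K = 2, 3, 8, 4           # batch, channels, patch size
Hp, Wp = 4, 6                            # number of patches per spatial dim
x = torch.randn(B, C_in, Hp * K, Wp * K)
weight = torch.randn(C_out, C_in, K, K)
bias = torch.randn(C_out)

# Reference: an ordinary convolution with kernel_size == stride and no padding.
ref = F.conv2d(x, weight, bias, stride=K)

# Matmul path: gather non-overlapping K x K patches, flatten, apply a linear layer.
patches = x.unfold(2, K, K).unfold(3, K, K)                  # (B, C, Hp, Wp, K, K)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, C_in * K * K)
out = F.linear(patches, weight.view(C_out, -1), bias)
out = out.view(B, Hp, Wp, C_out).permute(0, 3, 1, 2)

print(torch.allclose(ref, out, atol=1e-4))                   # True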

View File

@ -20,6 +20,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@ -315,7 +316,7 @@ class CLIPVisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@ -56,12 +56,12 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -103,7 +103,6 @@ from .utils import (
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
@ -486,15 +485,18 @@ class Glm4vVisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
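
Editor's note: the patch-embed forwards in this and the following model files now reshape the flattened (L, C) patch tensor into a 5-D block and feed it through Conv3dLayer, instead of multiplying by a pre-flattened linear weight. A small pure-PyTorch sketch of the shape flow; the sizes below are illustrative, not read from any model config:

import torch
import torch.nn.functional as F

L, C_in, T, P, H = 16, 3, 2, 14, 1536    # hypothetical patch/hidden sizes
weight = torch.randn(H, C_in, T, P, P)
bias = torch.randn(H)

x = torch.randn(L, C_in * T * P * P)                 # flattened patches, shape (L, C)
x = x.view(L, C_in, T, P, P)                         # one temporal-spatial block per patch
out = F.conv3d(x, weight, bias, stride=(T, P, P))    # (L, H, 1, 1, 1) since kernel == stride
print(out.view(L, H).shape)                          # torch.Size([16, 1536])
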
@ -893,9 +895,6 @@ class Glm4vVisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -26,7 +26,6 @@
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias
@ -56,12 +55,12 @@ from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -110,7 +109,6 @@ from .utils import (
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
@ -525,15 +523,18 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
@ -957,9 +958,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -25,7 +25,6 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias
@ -54,9 +53,9 @@ from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -107,7 +106,6 @@ from .utils import (
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
@ -566,15 +564,18 @@ class Qwen2VisionPatchEmbed(nn.Module):
self.embed_dim = embed_dim
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.embed_dim)
return x
@ -844,9 +845,6 @@ class Qwen2VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -22,7 +22,6 @@
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Any
@ -54,9 +53,9 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -102,7 +101,6 @@ from .utils import (
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
@ -138,16 +136,18 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x = self.proj(x)
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
@ -566,9 +566,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -24,7 +24,6 @@
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from itertools import islice
@ -57,9 +56,9 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -114,7 +113,6 @@ from .utils import (
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
@ -139,15 +137,18 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
@ -579,9 +580,6 @@ class Qwen3_VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -550,19 +550,3 @@ def get_llm_pos_ids_for_vision(
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
# Due to a performance regression with Conv3D in PyTorch2.9, we reshape
# Conv3D weights to Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR introduces this workaround
# (https://github.com/vllm-project/vllm/pull/27418),
# once the performance issue is resolved in PyTorch.
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
"""
Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride.
"""
out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
return linear_weight
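
Editor's note: the removed helper above reshaped checkpoint Conv3d weights into Linear weights at load time; that same flattening now happens at runtime inside Conv3dLayer._forward_mulmat via self.weight.view(...), which is why the loader-side workaround and its call sites can be dropped. A short sketch of that identity, with made-up sizes:

import math
import torch

out_channels, in_channels, kt, kh, kw = 8, 3, 2, 4, 4       # illustrative only
conv3d_weight = torch.randn(out_channels, in_channels, kt, kh, kw)

# What conv3d_to_linear_weight() used to produce at load time.
old_linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)

# What Conv3dLayer._forward_mulmat now views the stored weight as at runtime.
input_size = in_channels * math.prod((kt, kh, kw))
runtime_view = conv3d_weight.view(out_channels, input_size)

print(torch.equal(old_linear_weight, runtime_view))         # True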