[Models] Replace all nn.Conv2d with vLLM's Conv2dLayer (#28842)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by Isotr0py on 2025-11-19 02:56:04 +08:00; committed by GitHub
parent c64c0b78de
commit e4bb2684bc
20 changed files with 83 additions and 45 deletions

View File

@@ -3,6 +3,7 @@
"""Conv Layer Class."""
import math
from typing import Literal
import torch
import torch.nn as nn
@@ -23,11 +24,11 @@ class ConvLayerBase(CustomOp):
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] = 0,
padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
@@ -36,6 +37,22 @@ class ConvLayerBase(CustomOp):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
valid_padding_strings = {"same", "valid"}
if isinstance(padding, str) and padding not in valid_padding_strings:
raise ValueError(
f"Invalid padding string '{padding}'. "
f"Expected one of {valid_padding_strings}."
)
if padding == "same":
padding = (
kernel_size // 2
if isinstance(kernel_size, int)
else tuple(k // 2 for k in kernel_size)
)
elif padding == "valid":
padding = 0
kernel_size = (
(kernel_size,) * self.num_dim
if isinstance(kernel_size, int)
@@ -45,6 +62,9 @@ class ConvLayerBase(CustomOp):
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
if padding == "same" and any(s != 1 for s in stride):
raise ValueError("padding='same' is not supported for strided convolutions")
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
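
The hunks above add string padding ("same"/"valid") and a padding_mode annotation to the shared conv base class. A minimal sketch of the resulting behaviour, assuming Conv2dLayer can be instantiated standalone and follows standard nn.Conv2d shape semantics; the import path, constructor arguments, and error message are taken from this diff, everything else is illustrative:

import torch

from vllm.model_executor.layers.conv import Conv2dLayer

x = torch.randn(1, 3, 32, 32)

# padding="same" is resolved to kernel_size // 2 per dimension, so a 3x3
# kernel with stride=1 keeps the spatial size.
conv_same = Conv2dLayer(3, 8, kernel_size=3, padding="same")
print(conv_same(x).shape)   # torch.Size([1, 8, 32, 32])

# padding="valid" is resolved to 0, i.e. no padding.
conv_valid = Conv2dLayer(3, 8, kernel_size=3, padding="valid")
print(conv_valid(x).shape)  # torch.Size([1, 8, 30, 30])

# Any other string is rejected by the new validation.
try:
    Conv2dLayer(3, 8, kernel_size=3, padding="full")
except ValueError as err:
    print(err)  # "Invalid padding string 'full'. Expected one of ..." (set order may vary)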

View File

@@ -12,6 +12,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -58,7 +59,7 @@ class AIMv2SwiGLUFFN(nn.Module):
class AIMv2PatchEmbed(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
config.num_channels,
config.hidden_size,
kernel_size=(config.patch_size, config.patch_size),
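
Across the model files in the rest of this commit the change is mechanical, as in the AIMv2 hunk above: Conv2dLayer accepts the same constructor arguments as nn.Conv2d, so each call site only swaps the class name. A minimal before/after sketch with a made-up patch-embedding module (none of these names come from the diff):

import torch.nn as nn

from vllm.model_executor.layers.conv import Conv2dLayer


class PatchEmbedSketch(nn.Module):
    """Illustrative only; the real modules appear in the hunks below."""

    def __init__(self, num_channels: int, hidden_size: int, patch_size: int):
        super().__init__()
        # Before this commit:
        #   self.proj = nn.Conv2d(num_channels, hidden_size,
        #                         kernel_size=patch_size, stride=patch_size)
        # After: same arguments, but the layer is vLLM's CustomOp-based conv,
        # which also accepts params_dtype and string padding (see first hunk).
        self.proj = Conv2dLayer(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values):
        # (B, C, H, W) -> (B, hidden_size, H // patch_size, W // patch_size)
        return self.proj(pixel_values)

The hunks contain no accompanying weight-loading changes, which suggests the parameter layout (weight, bias) is unchanged by the swap.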

View File

@@ -12,6 +12,7 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -47,7 +48,7 @@ class BlipVisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@@ -22,6 +22,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -549,7 +550,7 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
self.conv = Conv2dLayer(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
@@ -577,23 +578,23 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
self.norm1 = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.conv1 = torch.nn.Conv2d(
self.conv1 = Conv2dLayer(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = torch.nn.GroupNorm(
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
)
self.dropout = torch.nn.Dropout(config.dropout)
self.conv2 = torch.nn.Conv2d(
self.conv2 = Conv2dLayer(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
self.conv_shortcut = Conv2dLayer(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
self.nin_shortcut = Conv2dLayer(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
@@ -626,16 +627,16 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
self.norm = torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.q = torch.nn.Conv2d(
self.q = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
self.k = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
self.v = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
self.proj_out = Conv2dLayer(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
@@ -681,7 +682,7 @@ class ChameleonVQVAEEncoder(nn.Module):
latent_channels = config.latent_channels
channel_multiplier = config.channel_multiplier
self.conv_in = torch.nn.Conv2d(
self.conv_in = Conv2dLayer(
in_channels, base_channels, kernel_size=3, stride=1, padding=1
)
@@ -738,7 +739,7 @@ class ChameleonVQVAEEncoder(nn.Module):
self.norm_out = torch.nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = torch.nn.Conv2d(
self.conv_out = Conv2dLayer(
block_in,
2 * latent_channels if double_latent else latent_channels,
kernel_size=3,
@@ -779,10 +780,8 @@ class ChameleonVQVAE(nn.Module):
super().__init__()
self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(
config.embed_dim, config.latent_channels, 1
)
self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
def encode(

View File

@@ -19,6 +19,7 @@ import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -133,14 +134,14 @@ class ImageEncoderViT(nn.Module):
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
Conv2dLayer(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
Conv2dLayer(
out_chans,
out_chans,
kernel_size=3,
@@ -150,8 +151,10 @@ class ImageEncoderViT(nn.Module):
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
self.net_2 = Conv2dLayer(
256, 512, kernel_size=3, stride=2, padding=1, bias=False
)
self.net_3 = Conv2dLayer(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)
@@ -500,7 +503,7 @@ class PatchEmbed(nn.Module):
"""
super().__init__()
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)

View File

@@ -22,6 +22,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -471,7 +472,7 @@ class DotsPatchEmbed(nn.Module):
self.temporal_patch_size = config.temporal_patch_size
self.embed_dim = config.embed_dim
self.config = config
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
config.num_channels,
config.embed_dim,
kernel_size=(config.patch_size, config.patch_size),

View File

@@ -56,7 +56,7 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -734,7 +734,7 @@ class Glm4vVisionTransformer(nn.Module):
self.post_conv_layernorm = RMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
self.downsample = nn.Conv2d(
self.downsample = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=vision_config.out_hidden_size,
kernel_size=vision_config.spatial_merge_size,
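
This file already imports Conv3dLayer; the commit only adds Conv2dLayer next to it for the downsample path. Both presumably share the ConvLayerBase shown in the first hunk (same module, same constructor), so params_dtype and the new string padding apply to either. A rough standalone sketch with illustrative channel sizes rather than the real GLM-4V config values, assuming standard conv shape semantics:

import torch

from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer

# 2-D downsample where kernel_size == stride (spatial merge); sizes made up.
downsample = Conv2dLayer(1024, 2048, kernel_size=2, stride=2)
print(downsample(torch.randn(1, 1024, 32, 32)).shape)   # torch.Size([1, 2048, 16, 16])

# 3-D patch embedding over (T, H, W); kernel/stride given per dimension.
patch_embed = Conv3dLayer(3, 1024, kernel_size=(2, 14, 14), stride=(2, 14, 14))
print(patch_embed(torch.randn(1, 3, 2, 28, 28)).shape)  # torch.Size([1, 1024, 1, 2, 2])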

View File

@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -78,7 +79,7 @@ class GLMVImagePixelInputs(TensorSchema):
class EVA2CLIPPatchEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
config.in_channels,
config.hidden_size,
kernel_size=config.patch_size,
@@ -333,7 +334,7 @@ class EVA2CLIPModel(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.linear_proj",
)
self.conv = nn.Conv2d(
self.conv = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,

View File

@@ -30,6 +30,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -60,7 +61,7 @@ class Idefics2VisionEmbeddings(nn.Module):
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@@ -24,6 +24,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -51,7 +52,7 @@ class InternVisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@@ -16,6 +16,7 @@ from transformers.utils import torch_int
from vllm.attention.layer import MultiHeadAttention
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -43,7 +44,7 @@ class InternS1VisionPatchEmbeddings(nn.Module):
self.num_patches = num_patches
self.patch_shape = patch_shape
self.projection = nn.Conv2d(
self.projection = Conv2dLayer(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)

View File

@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -204,7 +205,7 @@ class KeyeVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@@ -39,6 +39,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -120,7 +121,7 @@ class AudioPatchEmbed(nn.Module):
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
in_chans,
embed_dim,
kernel_size=self.patch_size,

View File

@@ -53,6 +53,7 @@ from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
@@ -244,7 +245,7 @@ class MoonVisionPatchEmbed(nn.Module):
)
self.patch_size = patch_size
self.proj = nn.Conv2d(
self.proj = Conv2dLayer(
in_dim, out_dim, kernel_size=patch_size, stride=patch_size
)

View File

@@ -45,6 +45,7 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -419,7 +420,7 @@ class SiglipVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@@ -31,6 +31,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -747,7 +748,7 @@ class VisionTransformer(nn.Module):
def __init__(self, args: VisionEncoderArgs):
super().__init__()
self.args = args
self.patch_conv = nn.Conv2d(
self.patch_conv = Conv2dLayer(
in_channels=args.num_channels,
out_channels=args.hidden_size,
kernel_size=args.patch_size,
@@ -1212,7 +1213,7 @@ class PixtralHFVisionModel(nn.Module):
self.config = config
self.patch_conv = nn.Conv2d(
self.patch_conv = Conv2dLayer(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=config.patch_size,

View File

@@ -25,6 +25,7 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
@@ -333,7 +334,7 @@ class VisionTransformer(nn.Module):
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
self.conv1 = Conv2dLayer(
in_channels=3,
out_channels=width,
kernel_size=patch_size,

View File

@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -286,7 +287,7 @@ class SiglipVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,

View File

@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
@@ -67,7 +68,7 @@ class Siglip2VisionEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
else:
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
@@ -99,7 +100,7 @@ class Siglip2VisionEmbeddings(nn.Module):
target_dtype = self.patch_embedding.weight.dtype
if isinstance(self.patch_embedding, LinearBase):
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
elif isinstance(self.patch_embedding, nn.Conv2d):
elif isinstance(self.patch_embedding, Conv2dLayer):
pixel_values = pixel_values.view(
-1,
self.config.num_channels * self.config.temporal_patch_size,
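
Besides the constructor swap, the second hunk above updates a type check: once the patch embedding is a Conv2dLayer, an isinstance(..., nn.Conv2d) branch would silently stop matching. A hedged sketch of that dispatch pattern, using placeholder names (embed, pixel_values) rather than the module's real attributes:

import torch.nn as nn

from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import LinearBase


def apply_patch_embedding(embed: nn.Module, pixel_values):
    # Cast the input to the embedding's parameter dtype, as in the hunk above.
    pixel_values = pixel_values.to(dtype=embed.weight.dtype)

    if isinstance(embed, LinearBase):
        # Linear patch embedding over flattened patches (vLLM linear layers
        # may also return an (output, bias) tuple; unpacking is omitted here).
        return embed(pixel_values)
    if isinstance(embed, Conv2dLayer):  # was nn.Conv2d before this commit
        # Convolutional path; reshaping to NCHW is model-specific and omitted.
        return embed(pixel_values)
    raise TypeError(f"Unsupported patch embedding: {type(embed).__name__}")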

View File

@@ -20,6 +20,7 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
@@ -667,7 +668,7 @@ class Step3VisionEmbeddings(nn.Module):
self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
self.patch_embedding = Conv2dLayer(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
@@ -950,13 +951,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
self.vit_downsampler = nn.Conv2d(
self.vit_downsampler = Conv2dLayer(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride,
)
self.vit_downsampler2 = nn.Conv2d(
self.vit_downsampler2 = Conv2dLayer(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,