From e4bb2684bcea12f72a36a6c48292f79534af849a Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 19 Nov 2025 02:56:04 +0800 Subject: [PATCH] [Models] Replace all `nn.Conv2d` with vLLM's Conv2dLayer (#28842) Signed-off-by: Isotr0py --- vllm/model_executor/layers/conv.py | 24 +++++++++++++-- vllm/model_executor/models/aimv2.py | 3 +- vllm/model_executor/models/blip.py | 3 +- vllm/model_executor/models/chameleon.py | 29 +++++++++---------- vllm/model_executor/models/deepencoder.py | 13 +++++---- vllm/model_executor/models/dots_ocr.py | 3 +- vllm/model_executor/models/glm4_1v.py | 4 +-- vllm/model_executor/models/glm4v.py | 5 ++-- .../models/idefics2_vision_model.py | 3 +- vllm/model_executor/models/intern_vit.py | 3 +- vllm/model_executor/models/interns1_vit.py | 3 +- vllm/model_executor/models/keye.py | 3 +- vllm/model_executor/models/midashenglm.py | 3 +- vllm/model_executor/models/moonvit.py | 3 +- vllm/model_executor/models/paddleocr_vl.py | 3 +- vllm/model_executor/models/pixtral.py | 5 ++-- vllm/model_executor/models/qwen_vl.py | 3 +- vllm/model_executor/models/siglip.py | 3 +- vllm/model_executor/models/siglip2navit.py | 5 ++-- vllm/model_executor/models/step3_vl.py | 7 +++-- 20 files changed, 83 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py index e6f2d2990c241..8d51e5bd9920a 100644 --- a/vllm/model_executor/layers/conv.py +++ b/vllm/model_executor/layers/conv.py @@ -3,6 +3,7 @@ """Conv Layer Class.""" import math +from typing import Literal import torch import torch.nn as nn @@ -23,11 +24,11 @@ class ConvLayerBase(CustomOp): out_channels: int, kernel_size: int | tuple[int, ...], stride: int | tuple[int, ...] = 1, - padding: int | tuple[int, ...] = 0, + padding: int | tuple[int, ...] | Literal["same", "valid"] = 0, dilation: int | tuple[int, ...] = 1, groups: int = 1, bias: bool = True, - padding_mode: str = "zeros", + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", *, params_dtype: torch.dtype | None = None, ) -> None: @@ -36,6 +37,22 @@ class ConvLayerBase(CustomOp): if params_dtype is None: params_dtype = torch.get_default_dtype() + valid_padding_strings = {"same", "valid"} + if isinstance(padding, str) and padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string '{padding}'. " + f"Expected one of {valid_padding_strings}." 
+            )
+
+        if padding == "same":
+            strides = (stride,) if isinstance(stride, int) else stride
+            if any(s != 1 for s in strides):
+                raise ValueError(
+                    "padding='same' is not supported for strided convolutions"
+                )
+            padding = (
+                kernel_size // 2
+                if isinstance(kernel_size, int)
+                else tuple(k // 2 for k in kernel_size)
+            )
+        elif padding == "valid":
+            padding = 0
+
         kernel_size = (
             (kernel_size,) * self.num_dim
             if isinstance(kernel_size, int)
diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py
index 5872e8196eada..3d000f3ac3ab5 100644
--- a/vllm/model_executor/models/aimv2.py
+++ b/vllm/model_executor/models/aimv2.py
@@ -12,6 +12,7 @@ from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -58,7 +59,7 @@ class AIMv2SwiGLUFFN(nn.Module):
 class AIMv2PatchEmbed(nn.Module):
     def __init__(self, config: AIMv2Config):
         super().__init__()
-        self.proj = nn.Conv2d(
+        self.proj = Conv2dLayer(
             config.num_channels,
             config.hidden_size,
             kernel_size=(config.patch_size, config.patch_size),
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index 2e4f73312efa3..f31f99c0592b2 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -12,6 +12,7 @@ from transformers import Blip2VisionConfig, BlipVisionConfig
 from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -47,7 +48,7 @@ class BlipVisionEmbeddings(nn.Module):

         self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

-        self.patch_embedding = nn.Conv2d(
+        self.patch_embedding = Conv2dLayer(
             in_channels=3,
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index fb7476c45fcdb..3c87bbfefab3d 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -22,6 +22,7 @@ from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -549,7 +550,7 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
 class ChameleonVQVAEEncoderConvDownsample(nn.Module):
     def __init__(self, in_channels: int):
         super().__init__()
-        self.conv = nn.Conv2d(
+        self.conv = Conv2dLayer(
             in_channels, in_channels, kernel_size=3, stride=2, padding=0
         )

@@ -577,23 +578,23 @@ class 
ChameleonVQVAEEncoderResnetBlock(nn.Module): self.norm1 = torch.nn.GroupNorm( num_groups=32, num_channels=in_channels, eps=1e-6, affine=True ) - self.conv1 = torch.nn.Conv2d( + self.conv1 = Conv2dLayer( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) self.norm2 = torch.nn.GroupNorm( num_groups=32, num_channels=out_channels, eps=1e-6, affine=True ) self.dropout = torch.nn.Dropout(config.dropout) - self.conv2 = torch.nn.Conv2d( + self.conv2 = Conv2dLayer( out_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( + self.conv_shortcut = Conv2dLayer( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: - self.nin_shortcut = torch.nn.Conv2d( + self.nin_shortcut = Conv2dLayer( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) @@ -626,16 +627,16 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module): self.norm = torch.nn.GroupNorm( num_groups=32, num_channels=in_channels, eps=1e-6, affine=True ) - self.q = torch.nn.Conv2d( + self.q = Conv2dLayer( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) - self.k = torch.nn.Conv2d( + self.k = Conv2dLayer( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) - self.v = torch.nn.Conv2d( + self.v = Conv2dLayer( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) - self.proj_out = torch.nn.Conv2d( + self.proj_out = Conv2dLayer( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) @@ -681,7 +682,7 @@ class ChameleonVQVAEEncoder(nn.Module): latent_channels = config.latent_channels channel_multiplier = config.channel_multiplier - self.conv_in = torch.nn.Conv2d( + self.conv_in = Conv2dLayer( in_channels, base_channels, kernel_size=3, stride=1, padding=1 ) @@ -738,7 +739,7 @@ class ChameleonVQVAEEncoder(nn.Module): self.norm_out = torch.nn.GroupNorm( num_groups=32, num_channels=block_in, eps=1e-6, affine=True ) - self.conv_out = torch.nn.Conv2d( + self.conv_out = Conv2dLayer( block_in, 2 * latent_channels if double_latent else latent_channels, kernel_size=3, @@ -779,10 +780,8 @@ class ChameleonVQVAE(nn.Module): super().__init__() self.encoder = ChameleonVQVAEEncoder(config) self.quantize = ChameleonVQVAEVectorQuantizer(config) - self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d( - config.embed_dim, config.latent_channels, 1 - ) + self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1) self.eval() # Chameleon's VQ model is frozen def encode( diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py index e62a57eccc953..8f1660891fcbf 100644 --- a/vllm/model_executor/models/deepencoder.py +++ b/vllm/model_executor/models/deepencoder.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from transformers import CLIPVisionConfig from vllm.attention.layer import MultiHeadAttention +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -133,14 +134,14 @@ class ImageEncoderViT(nn.Module): self.blocks.append(block) self.neck = nn.Sequential( - nn.Conv2d( + Conv2dLayer( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), - nn.Conv2d( + Conv2dLayer( out_chans, out_chans, kernel_size=3, @@ 
-150,8 +151,10 @@ class ImageEncoderViT(nn.Module): LayerNorm2d(out_chans), ) - self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) - self.net_3 = nn.Conv2d( + self.net_2 = Conv2dLayer( + 256, 512, kernel_size=3, stride=2, padding=1, bias=False + ) + self.net_3 = Conv2dLayer( 512, 1024, kernel_size=3, stride=2, padding=1, bias=False ) @@ -500,7 +503,7 @@ class PatchEmbed(nn.Module): """ super().__init__() - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index f46caaa095c6a..2d2251e83b5b1 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -22,6 +22,7 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -471,7 +472,7 @@ class DotsPatchEmbed(nn.Module): self.temporal_patch_size = config.temporal_patch_size self.embed_dim = config.embed_dim self.config = config - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( config.num_channels, config.embed_dim, kernel_size=(config.patch_size, config.patch_size), diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 65c3fc2d9e975..2c2f45c2453ee 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -56,7 +56,7 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor.layers.conv import Conv3dLayer +from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -734,7 +734,7 @@ class Glm4vVisionTransformer(nn.Module): self.post_conv_layernorm = RMSNorm( vision_config.hidden_size, eps=vision_config.rms_norm_eps ) - self.downsample = nn.Conv2d( + self.downsample = Conv2dLayer( in_channels=vision_config.hidden_size, out_channels=vision_config.out_hidden_size, kernel_size=vision_config.spatial_merge_size, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 1c18ea0745f2b..514082cf60ce2 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -24,6 +24,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, @@ -78,7 +79,7 @@ class GLMVImagePixelInputs(TensorSchema): class EVA2CLIPPatchEmbedding(nn.Module): def __init__(self, config): super().__init__() - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( config.in_channels, config.hidden_size, kernel_size=config.patch_size, @@ -333,7 +334,7 @@ class EVA2CLIPModel(nn.Module): quant_config=quant_config, prefix=f"{prefix}.linear_proj", ) - self.conv = nn.Conv2d( + 
self.conv = Conv2dLayer( in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 727c8ec0397ca..06b8468e18db9 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -30,6 +30,7 @@ from transformers.models.idefics2.configuration_idefics2 import ( from vllm.attention.layer import MultiHeadAttention from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -60,7 +61,7 @@ class Idefics2VisionEmbeddings(nn.Module): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 03918127c6ae1..61aeafc2ab436 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -24,6 +24,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -51,7 +52,7 @@ class InternVisionEmbeddings(nn.Module): self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index 507503d75046d..cb0414bbc95a8 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -16,6 +16,7 @@ from transformers.utils import torch_int from vllm.attention.layer import MultiHeadAttention from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -43,7 +44,7 @@ class InternS1VisionPatchEmbeddings(nn.Module): self.num_patches = num_patches self.patch_shape = patch_shape - self.projection = nn.Conv2d( + self.projection = Conv2dLayer( num_channels, hidden_size, kernel_size=patch_size, stride=patch_size ) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 1eb0eccc0411c..8fc3db296aa79 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -24,6 +24,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -204,7 +205,7 @@ class KeyeVisionEmbeddings(nn.Module): 
self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index a84c99059cd9c..d9b23811730d4 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -39,6 +39,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -120,7 +121,7 @@ class AudioPatchEmbed(nn.Module): self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( in_chans, embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index 8017c947bf9ad..2e3e6dc166ad8 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -53,6 +53,7 @@ from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel from transformers.utils import is_flash_attn_2_available +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.models.utils import maybe_prefix from vllm.transformers_utils.configs.moonvit import MoonViTConfig @@ -244,7 +245,7 @@ class MoonVisionPatchEmbed(nn.Module): ) self.patch_size = patch_size - self.proj = nn.Conv2d( + self.proj = Conv2dLayer( in_dim, out_dim, kernel_size=patch_size, stride=patch_size ) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 3ef6470070d18..dee0c16ab0f63 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -45,6 +45,7 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -419,7 +420,7 @@ class SiglipVisionEmbeddings(nn.Module): self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 8cb7d6a889da4..8a034fd72b02a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -31,6 +31,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -747,7 +748,7 @@ class VisionTransformer(nn.Module): 
def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args - self.patch_conv = nn.Conv2d( + self.patch_conv = Conv2dLayer( in_channels=args.num_channels, out_channels=args.hidden_size, kernel_size=args.patch_size, @@ -1212,7 +1213,7 @@ class PixtralHFVisionModel(nn.Module): self.config = config - self.patch_conv = nn.Conv2d( + self.patch_conv = Conv2dLayer( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=config.patch_size, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 6a259cade9cf1..4906cf441f6fb 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -25,6 +25,7 @@ from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, @@ -333,7 +334,7 @@ class VisionTransformer(nn.Module): patch_height, patch_width = self.patch_size = (patch_size, patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) self.output_dim = output_dim - self.conv1 = nn.Conv2d( + self.conv1 = Conv2dLayer( in_channels=3, out_channels=width, kernel_size=patch_size, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 42d906d089f90..ce5847bf79a5e 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -24,6 +24,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -286,7 +287,7 @@ class SiglipVisionEmbeddings(nn.Module): self.image_size = config.image_size self.patch_size = config.patch_size - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 29dd164ad37fd..46f5e67d659ef 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -16,6 +16,7 @@ from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearBase, @@ -67,7 +68,7 @@ class Siglip2VisionEmbeddings(nn.Module): self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) else: - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, @@ -99,7 +100,7 @@ class Siglip2VisionEmbeddings(nn.Module): target_dtype = self.patch_embedding.weight.dtype if isinstance(self.patch_embedding, LinearBase): patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) - elif 
isinstance(self.patch_embedding, nn.Conv2d): + elif isinstance(self.patch_embedding, Conv2dLayer): pixel_values = pixel_values.view( -1, self.config.num_channels * self.config.temporal_patch_size, diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5d16be1eb3128..1c60cb4148121 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -20,6 +20,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -667,7 +668,7 @@ class Step3VisionEmbeddings(nn.Module): self.class_embedding = nn.Parameter(torch.randn(1, self.embed_dim)) - self.patch_embedding = nn.Conv2d( + self.patch_embedding = Conv2dLayer( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, @@ -950,13 +951,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) prefix=maybe_prefix(prefix, "vision_model"), use_data_parallel=self.use_data_parallel, ) - self.vit_downsampler = nn.Conv2d( + self.vit_downsampler = Conv2dLayer( config.vision_config.hidden_size, config.vision_config.output_hidden_size, kernel_size=2, stride=config.understand_projector_stride, ) - self.vit_downsampler2 = nn.Conv2d( + self.vit_downsampler2 = Conv2dLayer( config.vision_config.output_hidden_size, config.vision_config.output_hidden_size * 2, kernel_size=3,
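
For reviewers, a minimal usage sketch of the layer this patch standardizes on. It is not part of the patch: the layer name, channel sizes, and input shape below are illustrative, and it assumes Conv2dLayer keeps nn.Conv2d's constructor and call convention (as the drop-in replacements above imply) and can be instantiated outside of an engine/config context. It exercises the string-padding handling added to ConvLayerBase:

import torch

from vllm.model_executor.layers.conv import Conv2dLayer

# padding="same" is resolved to an explicit kernel_size // 2 padding at
# construction time and is only accepted for stride-1 convolutions;
# padding="valid" means no padding, mirroring torch.nn.Conv2d.
patch_embed = Conv2dLayer(
    in_channels=3,  # illustrative values, not taken from any model above
    out_channels=768,
    kernel_size=3,
    stride=1,
    padding="same",
    bias=False,
)

x = torch.randn(1, 3, 224, 224)
out = patch_embed(x)
print(out.shape)  # "same" padding preserves the spatial size: (1, 768, 224, 224)

The forward pass is expected to match torch.nn.Conv2d for the same arguments, which is what makes the mechanical nn.Conv2d to Conv2dLayer swaps in the model files above behavior-preserving.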