mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 15:36:29 +08:00
[Model][MM] Extract conv layer as CustomOp (#28455)
Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
360bd8762f
commit
41b92f7d38
236
vllm/model_executor/layers/conv.py
Normal file
236
vllm/model_executor/layers/conv.py
Normal file
@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Conv Layer Class."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils.torch_utils import is_torch_equal
|
||||
|
||||
|
||||
class ConvLayerBase(CustomOp):
|
||||
"""Conv layer base class."""
|
||||
|
||||
num_dim: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] = 0,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
*,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
kernel_size = (
|
||||
(kernel_size,) * self.num_dim
|
||||
if isinstance(kernel_size, int)
|
||||
else kernel_size
|
||||
)
|
||||
stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
|
||||
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
|
||||
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
self.enable_linear = (
|
||||
(self.kernel_size == self.stride)
|
||||
and not any(self.padding)
|
||||
and self.groups == 1
|
||||
)
|
||||
self.input_size = in_channels * math.prod(self.kernel_size)
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(
|
||||
out_channels,
|
||||
in_channels // groups,
|
||||
*kernel_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"in_channels={self.in_channels}, "
|
||||
s += f"out_channels={self.out_channels}, "
|
||||
s += f"kernel_size={self.kernel_size}, "
|
||||
s += f"stride={self.stride}, "
|
||||
s += f"padding={self.padding}, "
|
||||
s += f"bias={self.bias is not None}"
|
||||
return s
|
||||
|
||||
|
||||
@CustomOp.register("conv2d")
|
||||
class Conv2dLayer(ConvLayerBase):
|
||||
"""Conv layer with Conv2d."""
|
||||
|
||||
num_dim = 2
|
||||
|
||||
def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dim() == 4
|
||||
B, C, H, W = x.shape
|
||||
K1, K2 = self.kernel_size
|
||||
H, W = H // K1, W // K2
|
||||
x = x.unfold(2, K1, K1).unfold(3, K2, K2)
|
||||
x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
|
||||
x = F.linear(
|
||||
x,
|
||||
self.weight.view(self.out_channels, self.input_size),
|
||||
self.bias,
|
||||
)
|
||||
x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dim() == 4
|
||||
x = F.conv2d(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Expected input shape: (batch_size, in_channels, height, width)"""
|
||||
assert x.dim() == 4
|
||||
if self.enable_linear:
|
||||
return self._forward_mulmat(x)
|
||||
else:
|
||||
return self._forward_conv(x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# By default, we use CUDNN's convolution ops with optimization.
|
||||
return self._forward_conv(x)
|
||||
|
||||
|
||||
class CausalConv2dLayer(Conv2dLayer):
|
||||
"""
|
||||
A causal version of nn.Conv2d where each location in the 2D matrix would
|
||||
have no access to locations on its right or down
|
||||
All arguments are the same as nn.Conv2d except padding which should be
|
||||
set as None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int,
|
||||
padding: int = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
*,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
if padding is not None:
|
||||
raise ValueError(
|
||||
"Argument padding should be set to None for CausalConv2dLayer."
|
||||
)
|
||||
self._left_padding: int = kernel_size - 1
|
||||
self._right_padding: int = stride - 1
|
||||
padding = 0
|
||||
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode,
|
||||
params_dtype=params_dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
|
||||
x = super().forward(x)
|
||||
return x
|
||||
|
||||
|
||||
@CustomOp.register("conv3d")
|
||||
class Conv3dLayer(ConvLayerBase):
|
||||
"""Conv layer with Conv3d."""
|
||||
|
||||
num_dim = 3
|
||||
|
||||
def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dim() == 5
|
||||
B, C, T, H, W = x.shape
|
||||
K1, K2, K3 = self.kernel_size
|
||||
T, H, W = T // K1, H // K2, W // K3
|
||||
x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)
|
||||
x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)
|
||||
x = F.linear(
|
||||
x,
|
||||
self.weight.view(self.out_channels, self.input_size),
|
||||
self.bias,
|
||||
)
|
||||
x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3)
|
||||
return x
|
||||
|
||||
def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert x.dim() == 5
|
||||
x = F.conv3d(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Expected input shape: (batch_size, in_channels, time, height, width)"""
|
||||
if self.enable_linear:
|
||||
return self._forward_mulmat(x)
|
||||
else:
|
||||
return self._forward_conv(x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# PyTorch2.9.0 disabled CUDNN's Conv3D, which caused a
|
||||
# significant performance regression.
|
||||
# See: https://github.com/vllm-project/vllm/issues/27406
|
||||
# and https://github.com/pytorch/pytorch/issues/166122
|
||||
# By default, we use CUDNN's convolution ops with optimization.
|
||||
if self.enable_linear and is_torch_equal("2.9.0"):
|
||||
return self._forward_mulmat(x)
|
||||
return self._forward_conv(x)
|
||||
@ -20,6 +20,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -315,7 +316,7 @@ class CLIPVisionEmbeddings(nn.Module):
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
self.patch_embedding = Conv2dLayer(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
|
||||
@ -56,12 +56,12 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -103,7 +103,6 @@ from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
conv3d_to_linear_weight,
|
||||
get_vit_attn_backend,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
)
|
||||
@ -486,15 +485,18 @@ class Glm4vVisionPatchEmbed(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = ReplicatedLinear(
|
||||
in_channels * math.prod(kernel_size),
|
||||
self.proj = Conv3dLayer(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=True,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
L, C = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.hidden_size)
|
||||
return x
|
||||
|
||||
|
||||
@ -893,9 +895,6 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith("patch_embed.proj.weight"):
|
||||
loaded_weight = conv3d_to_linear_weight(loaded_weight)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@ -26,7 +26,6 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import lru_cache, partial
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
@ -56,12 +55,12 @@ from vllm.distributed import utils as dist_utils
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -110,7 +109,6 @@ from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
conv3d_to_linear_weight,
|
||||
get_vit_attn_backend,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
)
|
||||
@ -525,15 +523,18 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = ReplicatedLinear(
|
||||
in_channels * math.prod(kernel_size),
|
||||
self.proj = Conv3dLayer(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
L, C = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.hidden_size)
|
||||
return x
|
||||
|
||||
|
||||
@ -957,9 +958,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith("patch_embed.proj.weight"):
|
||||
loaded_weight = conv3d_to_linear_weight(loaded_weight)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@ -25,7 +25,6 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
@ -54,9 +53,9 @@ from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -107,7 +106,6 @@ from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
conv3d_to_linear_weight,
|
||||
get_vit_attn_backend,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
)
|
||||
@ -566,15 +564,18 @@ class Qwen2VisionPatchEmbed(nn.Module):
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = ReplicatedLinear(
|
||||
in_channels * math.prod(kernel_size),
|
||||
self.proj = Conv3dLayer(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
L, C = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.embed_dim)
|
||||
return x
|
||||
|
||||
|
||||
@ -844,9 +845,6 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith("patch_embed.proj.weight"):
|
||||
loaded_weight = conv3d_to_linear_weight(loaded_weight)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@ -22,7 +22,6 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
@ -54,9 +53,9 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -102,7 +101,6 @@ from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
conv3d_to_linear_weight,
|
||||
get_llm_pos_ids_for_vision,
|
||||
get_vit_attn_backend,
|
||||
)
|
||||
@ -138,16 +136,18 @@ class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = ReplicatedLinear(
|
||||
in_channels * math.prod(kernel_size),
|
||||
self.proj = Conv3dLayer(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=True,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
L, C = x.shape
|
||||
x = self.proj(x)
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.hidden_size)
|
||||
return x
|
||||
|
||||
|
||||
@ -566,9 +566,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith("patch_embed.proj.weight"):
|
||||
loaded_weight = conv3d_to_linear_weight(loaded_weight)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@ -24,7 +24,6 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
@ -57,9 +56,9 @@ from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||
from vllm.model_executor.layers.conv import Conv3dLayer
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -114,7 +113,6 @@ from .utils import (
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
conv3d_to_linear_weight,
|
||||
get_vit_attn_backend,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
)
|
||||
@ -139,15 +137,18 @@ class Qwen3_VisionPatchEmbed(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = ReplicatedLinear(
|
||||
in_channels * math.prod(kernel_size),
|
||||
self.proj = Conv3dLayer(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=True,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
L, C = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.hidden_size)
|
||||
return x
|
||||
|
||||
|
||||
@ -579,9 +580,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith("patch_embed.proj.weight"):
|
||||
loaded_weight = conv3d_to_linear_weight(loaded_weight)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@ -550,19 +550,3 @@ def get_llm_pos_ids_for_vision(
|
||||
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
|
||||
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
|
||||
return llm_pos_ids
|
||||
|
||||
|
||||
# Due to a performance regression with Conv3D in PyTorch2.9, we reshape
|
||||
# Conv3D weights to Linear weights for better performance.
|
||||
# See: https://github.com/vllm-project/vllm/issues/27406
|
||||
# and https://github.com/pytorch/pytorch/issues/166122
|
||||
# FIXME(Isotr0py): Revert the PR introduces this workaround
|
||||
# (https://github.com/vllm-project/vllm/pull/27418),
|
||||
# once the performance issue is resolved in PyTorch.
|
||||
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride.
|
||||
"""
|
||||
out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
|
||||
linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
|
||||
return linear_weight
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user