# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar

import torch
from transformers import PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

_C = TypeVar("_C", bound=PretrainedConfig)


class _RootConfig(Protocol[_C]):
    vision_config: _C


class VisionEncoderInfo(ABC, Generic[_C]):
    def __init__(self, hf_config: _RootConfig[_C]) -> None:
        super().__init__()

        self.hf_config = hf_config
        self.vision_config = hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        raise NotImplementedError


class VisionLanguageConfig(Protocol):
    vision_config: Final[PretrainedConfig]


def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    if isinstance(hf_config.vision_config, CLIPVisionConfig):
        return CLIPEncoderInfo(hf_config)
    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFEncoderInfo(hf_config)
    if isinstance(hf_config.vision_config, SiglipVisionConfig):
        return SiglipEncoderInfo(hf_config)

    msg = f"Unsupported vision config: {type(hf_config.vision_config)}"
    raise NotImplementedError(msg)


def get_vit_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    *,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> AttentionBackendEnum:
    """
    Get the available attention backend for Vision Transformer.
    """
    if attn_backend_override is not None:
        return attn_backend_override

    # Lazy import to avoid circular dependency
    from vllm.attention.selector import get_env_variable_attn_backend

    selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend()
    if selected_backend is not None:
        return selected_backend

    return current_platform.get_vit_attn_backend(head_size, dtype)
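# Illustrative usage sketch (not part of the module): the ViT attention backend is
# resolved in priority order -- an explicit per-model override, then the value
# returned by get_env_variable_attn_backend() (the attention-backend environment
# variable), then the platform default. The head size and dtype below are
# hypothetical example values.
#
#   backend = get_vit_attn_backend(
#       head_size=64,
#       dtype=torch.float16,
#       attn_backend_override=None,
#   )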


def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
    """Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
    return vllm_config.compilation_config.compile_mm_encoder


VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]

VisionFeatureSelectStrategy: TypeAlias = (
    VisionFeatureSelectStrategyStr | Callable[[torch.Tensor], torch.Tensor]
)


def _get_vision_feature_selector(
    strategy: VisionFeatureSelectStrategy | str,
) -> Callable[[torch.Tensor], torch.Tensor]:
    if callable(strategy):
        return strategy

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":
        return lambda feats: feats[:, :1, :]

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":
        return lambda feats: feats[:, 1:, :]

    if strategy == "full":
        return lambda feats: feats

    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")


def get_num_selected_vision_tokens(
    num_vision_tokens: int,
    strategy: VisionFeatureSelectStrategy | str,
) -> int:
    if callable(strategy):
        dummy_features = torch.empty(1, num_vision_tokens, 64)  # [B, L, D]
        dummy_selected_features = strategy(dummy_features)
        return dummy_selected_features.shape[1]

    if strategy == "class":
        return 1

    if strategy == "default":
        return num_vision_tokens - 1

    if strategy == "full":
        return num_vision_tokens

    raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
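# Illustrative example (assumed encoder, not from this module): a CLIP ViT-L/14
# encoder at 336x336 produces 577 hidden states (576 patch tokens + 1 class token):
#
#   get_num_selected_vision_tokens(577, "class")    -> 1
#   get_num_selected_vision_tokens(577, "default")  -> 576  (class token dropped)
#   get_num_selected_vision_tokens(577, "full")     -> 577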


def resolve_visual_encoder_outputs(
    encoder_outputs: torch.Tensor | list[torch.Tensor],
    post_layer_norm: torch.nn.LayerNorm | None,
    *,
    select_layers: list[int] | None = None,
    max_possible_layers: int | None = None,
    feature_select_strategy: VisionFeatureSelectStrategy | None = None,
) -> torch.Tensor:
    """Given the outputs of a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        post_layer_norm: Post norm to apply to the output of the encoder.
        select_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        max_possible_layers: Total layers in the fully loaded visual encoder.
        feature_select_strategy: Defines how to select the hidden states
            from each layer.
    """
    if select_layers is None:
        if not isinstance(encoder_outputs, torch.Tensor):
            raise ValueError(
                "Expected only a single encoder output when "
                "`select_layers` is not provided"
            )

        if feature_select_strategy is not None:
            select_features = _get_vision_feature_selector(feature_select_strategy)
            encoder_outputs = select_features(encoder_outputs)

        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)

        return encoder_outputs

    if max_possible_layers is None:
        raise ValueError(
            "`max_possible_layers` must be provided alongside `select_layers`"
        )

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0
        else encoder_outputs[layer_idx + offset]
        for layer_idx in select_layers
    ]

    if feature_select_strategy is not None:
        select_features = _get_vision_feature_selector(feature_select_strategy)
        hs_pool = [select_features(hs) for hs in hs_pool]

    # Apply post-norm on the final hidden state if we are using it
    uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
    if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(hs_pool[-1])

    return torch.cat(hs_pool, dim=-1)
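# Worked example of the negative-index offset above (illustrative numbers): for a
# 24-layer encoder loaded with only 23 layers (e.g. the last layer is skipped when
# only the penultimate hidden state is needed), `encoder_outputs` has 24 entries
# (embeddings + 23 layer outputs), so offset = 24 - 23 = 1 and select_layers=[-2]
# picks encoder_outputs[-2 + 1], i.e. the hidden state that index -2 would refer to
# in the fully loaded encoder.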


def run_dp_sharded_vision_model(
    image_input: torch.Tensor, vision_model: torch.nn.Module
) -> torch.Tensor:
    """Run a vision model with data parallelism (DP) sharding. The function
    will shard the input image tensor on the first dimension and run the
    vision model.

    Args:
        image_input (torch.Tensor): Image input tensor.
        vision_model (torch.nn.Module): Vision model.
    Returns:
        torch.Tensor: Output image embeddings
    """

    num_chunks = image_input.shape[0]
    mp_world_size = get_tensor_model_parallel_world_size()
    num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
    num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
    pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
    image_input_padded = torch.nn.functional.pad(image_input, pad)
    rank = get_tensor_model_parallel_rank()
    image_input_per_rank = image_input_padded[
        rank * num_chunks_per_rank : (rank + 1) * num_chunks_per_rank, ...
    ]

    vision_embeddings = vision_model(image_input_per_rank)
    # Ensure tensor is contiguous before all_gather
    vision_embeddings = vision_embeddings.contiguous()
    vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0)
    vision_embeddings = vision_embeddings[:num_chunks, ...]
    return vision_embeddings
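# Illustrative sharding arithmetic (hypothetical sizes): with num_chunks = 5 images
# and a tensor-parallel world size of 2, num_chunks_per_rank = ceil(5 / 2) = 3, so
# 1 padded (all-zero) chunk is appended; rank 0 processes chunks [0:3], rank 1
# processes [3:6], and after the all-gather the result is truncated back to the
# original 5 chunks.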


def get_load_balance_assignment(
    sizes: list[int],
    num_gpus: int = 2,
) -> tuple[list[int], list[int], list[int]]:
    """
    Generate load balancing assignment and metadata
    for distributing data across GPUs.
    The load is determined by the total image sizes,
    not the number of images.

    Args:
        sizes: The size of each image
        num_gpus: Number of GPUs to balance across

    Returns:
        shuffle_indices:
            Indices to reorder data for balanced loading
        gpu_sample_counts:
            Number of samples assigned to each GPU
        grouped_sizes_per_gpu:
            Total size assigned to each GPU

    Example:
        ```
        sizes = [1000, 100, 200, 50]
        num_gpus = 2
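        # Expected assignment for these inputs (mirrors the walkthrough
        # comments in the function body below):
        # shuffle_indices = [0, 2, 1, 3]
        # gpu_sample_counts = [1, 3]
        # grouped_sizes_per_gpu = [1000, 350]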
        ```

    """

    n_samples = len(sizes)

    # Handle edge cases
    if n_samples == 0:
        return [], [0] * num_gpus, [0] * num_gpus

    # Use greedy algorithm - balance by total size, not sample count
    gpu_assignments = [list[int]() for _ in range(num_gpus)]
    gpu_loads = [0] * num_gpus  # This tracks total SIZE, not sample count

    # Sort indices by size (largest first for better load balancing)
    # sizes = [1000, 100, 200, 50]
    # large_to_small_indices = [0, 2, 1, 3]
    large_to_small_indices = sorted(
        range(n_samples), key=lambda i: sizes[i], reverse=True
    )

    for idx in large_to_small_indices:
        # Find GPU with minimum current load (by total size)
        min_gpu = min(range(num_gpus), key=lambda i: gpu_loads[i])
        gpu_assignments[min_gpu].append(idx)
        gpu_loads[min_gpu] += sizes[idx]

    # Create shuffle indices and counts
    shuffle_indices = list[int]()
    gpu_sample_counts = list[int]()
    for gpu_id in range(num_gpus):
        # GPU_0 = [1000] = [0]
        # GPU_1 = [200, 100, 50] = [2, 1, 3]
        # shuffle_indices = [0, 2, 1, 3]
        shuffle_indices.extend(gpu_assignments[gpu_id])
        # GPU_0 = [1]
        # GPU_1 = [3]
        # gpu_sample_counts = [1, 3]
        gpu_sample_counts.append(len(gpu_assignments[gpu_id]))

    return (shuffle_indices, gpu_sample_counts, gpu_loads)


def run_dp_sharded_mrope_vision_model(
    vision_model: torch.nn.Module,
    pixel_values: torch.Tensor,
    grid_thw_list: list[list[int]],
    *,
    rope_type: Literal["rope_3d", "rope_2d"],
) -> tuple[torch.Tensor, ...]:
    """Run a vision model with data parallelism (DP) sharding.
    The function will shard the input image tensor on the
    first dimension and run the vision model.
    This function is used to run the vision model with mrope.

    Args:
        vision_model (torch.nn.Module): Vision model.
        pixel_values (torch.Tensor): Image/Video input tensor.
        grid_thw_list: List of grid dimensions for each image
        rope_type: Type of rope used in the vision model.
            Different rope types require different handling of the
            ViT inputs and outputs:
            "rope_3d" for 3D rope (e.g., Qwen2.5-VL)
            "rope_2d" for 2D rope (e.g., Kimi-VL)
    Returns:
        tuple[torch.Tensor, ...]: Output image embeddings, one tensor per
            input image, in the original input order.

    Example:
        ```
        vision_model.out_hidden_size = 64
        vision_model.spatial_merge_size = 2
        pixel_values.shape = (1350, channel)
        grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 5, 10]]
        tp_size = 2
        ```

    """
    tp_size = get_tensor_model_parallel_world_size()

    # GPU_0 tp_rank_local = 0
    # GPU_1 tp_rank_local = 1
    tp_rank_local = get_tensor_model_parallel_rank()

    # patches_per_image = [1000, 100, 200, 50]
    patches_per_image = [math.prod(grid_thw) for grid_thw in grid_thw_list]
    # cum_patches_per_image = [0, 1000, 1100, 1300, 1350]
    cum_patches_per_image = [0, *itertools.accumulate(patches_per_image)]

    # Get load balancing assignment with all metadata
    # image_to_tp_rank = [0, 2, 1, 3]
    # gpu_sample_counts = [1, 3]
    # grouped_pixel_values_len = [1000, 350]
    (image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = (
        get_load_balance_assignment(patches_per_image, tp_size)
    )

    # cum_gpu_sample_counts = [0, 1, 4]
    cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)]

    # GPU_0 image_idxs_local = [0]
    # GPU_1 image_idxs_local = [2, 1, 3]
    image_idxs_local = image_to_tp_rank[
        cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1]
    ]

    # Get the pixel values for the local images based on the image_idxs_local
    if len(image_idxs_local) > 0:
        pixel_values_local = torch.cat(
            [
                pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]]
                for i in image_idxs_local
            ]
        )
    else:
        # Handle case where this rank has no images
        pixel_values_local = torch.empty(
            (0, pixel_values.shape[1]),
            device=pixel_values.device,
            dtype=pixel_values.dtype,
        )
    # embed_dim_reduction_factor = 2 * 2
    if rope_type == "rope_2d":
        embed_dim_reduction_factor = (
            vision_model.merge_kernel_size[0] * vision_model.merge_kernel_size[1]
        )
    else:
        embed_dim_reduction_factor = (
            vision_model.spatial_merge_size * vision_model.spatial_merge_size
        )

    # Find the max length across all ranks
    # The output embedding of every DP rank has to be
    # padded to this length for tensor_model_parallel_all_gather
    # to work
    max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor
    local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local]

    # Run the vision model on the local pixel_values_local
    if rope_type == "rope_2d":
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(
                pixel_values_local, torch.tensor(local_grid_thw_list)
            )
            if isinstance(image_embeds_local, list):
                image_embeds_local = torch.cat(image_embeds_local, dim=0)
        else:
            out_dim = getattr(vision_model.config, "hidden_size", None)
            image_embeds_local = torch.empty(
                (0, embed_dim_reduction_factor, out_dim),
                device=pixel_values.device,
                dtype=pixel_values.dtype,
            )
    else:
        if pixel_values_local.shape[0] > 0:
            image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list)
        else:
            # Handle empty case
            image_embeds_local = torch.empty(
                (0, vision_model.out_hidden_size),
                device=pixel_values.device,
                dtype=pixel_values.dtype,
            )

    # Pad the output based on max_len_per_rank
    # for tensor_model_parallel_all_gather to work
    current_len = image_embeds_local.shape[0]
    if current_len < max_len_per_rank:
        padding_size = max_len_per_rank - current_len
        if rope_type == "rope_2d":
            padding = torch.empty(
                (
                    padding_size,
                    image_embeds_local.shape[1],
                    image_embeds_local.shape[2],
                ),
                dtype=image_embeds_local.dtype,
                device=image_embeds_local.device,
            )
        else:
            padding = torch.empty(
                (padding_size, image_embeds_local.shape[1]),
                dtype=image_embeds_local.dtype,
                device=image_embeds_local.device,
            )
        image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0)
    else:
        image_embeds_local_padded = image_embeds_local

    # Do all_gather to collect embeddings from all ranks
    gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0)

    # Remove padding and reconstruct per-rank embeddings
    rank_embeddings = list[torch.Tensor]()
    for rank in range(tp_size):
        start_idx = rank * max_len_per_rank
        end_idx = start_idx + (
            grouped_pixel_values_len[rank] // embed_dim_reduction_factor
        )
        rank_embeddings.append(gathered_embeds[start_idx:end_idx])

    patches_per_output_image = [
        (patch_size // embed_dim_reduction_factor) for patch_size in patches_per_image
    ]

    # Reconstruct embeddings in the original order
    original_order_embeddings = [None] * len(grid_thw_list)
    current_idx = 0
    for rank in range(tp_size):
        count = gpu_sample_counts[rank]
        if count > 0:
            # Get images assigned to this rank in shuffled order
            # GPU_0 = image_idxs_local [0]
            # GPU_1 = image_idxs_local [2, 1, 3]
            rank_images = image_to_tp_rank[current_idx : current_idx + count]

            rank_embed = rank_embeddings[rank]
            # Split rank embeddings back to individual images
            embed_start = 0
            for img_idx in rank_images:
                img_patches = patches_per_output_image[img_idx]
                original_order_embeddings[img_idx] = rank_embed[
                    embed_start : embed_start + img_patches
                ]
                embed_start += img_patches
        current_idx += count
    out_embeddings = tuple(
        embed for embed in original_order_embeddings if embed is not None
    )
    assert len(out_embeddings) == len(original_order_embeddings), (
        "Found unassigned embeddings"
    )
    return out_embeddings
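# Illustrative output shapes for the docstring example above (hypothetical model
# with out_hidden_size = 64 and spatial_merge_size = 2, i.e. a reduction factor of
# 4): the first image has 1 * 10 * 100 = 1000 patches and yields an embedding of
# shape (250, 64); each returned tensor corresponds to one input image, in the
# original input order.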


def get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
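    """Build the (t, h, w) LLM position ids for one image/video segment.

    Returns a tensor of shape [3, len(t_index) * llm_grid_h * llm_grid_w],
    where llm_grid_h and llm_grid_w are the grid height/width of item
    `vision_idx` divided by `spatial_merge_size`, and every position id is
    offset by `start_idx`.
    """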
    llm_pos_ids_list = []
    llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
    llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
    h_index = (
        torch.arange(llm_grid_h)
        .view(1, -1, 1)
        .expand(len(t_index), -1, llm_grid_w)
        .flatten()
    )
    w_index = (
        torch.arange(llm_grid_w)
        .view(1, 1, -1)
        .expand(len(t_index), llm_grid_h, -1)
        .flatten()
    )
    t_index_tensor = (
        torch.Tensor(t_index)
        .to(llm_grid_h.device)
        .view(-1, 1)
        .expand(-1, llm_grid_h * llm_grid_w)
        .long()
        .flatten()
    )
    _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
    llm_pos_ids_list.append(_llm_pos_ids + start_idx)
    llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
    return llm_pos_ids


# Due to a performance regression with Conv3D in PyTorch 2.9, we reshape
# Conv3D weights to Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR that introduces this workaround
# (https://github.com/vllm-project/vllm/pull/27418),
# once the performance issue is resolved in PyTorch.
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
    """
    Reshape a Conv3D weight into a Linear weight. Only works when
    kernel_size == stride.
    """
    out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
    linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
    return linear_weight
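# Illustrative usage sketch (assumed shapes, not part of this module): a patch-embed
# Conv3D with in_channels=3 and kernel_size=stride=(2, 14, 14) has a weight of shape
# (out_channels, 3, 2, 14, 14); after the reshape it becomes
# (out_channels, 3 * 2 * 14 * 14) = (out_channels, 1176) and can be applied with
# torch.nn.functional.linear to patches flattened in that same (c, t, h, w) order.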