# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
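"""Swin Transformer vision backbone for vLLM.

Reimplements the attention and MLP blocks of the Hugging Face Swin model on
top of vLLM's tensor-parallel linear layers, while reusing the original
SwinEmbeddings, SwinPatchMerging and window-handling SwinLayer logic from
`transformers`.
"""
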
from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import SwinConfig
from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging
from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
from transformers.pytorch_utils import meshgrid

from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
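

# Window-based multi-head self-attention with a learned relative position
# bias. The query/key/value projections are fused into a single tensor-parallel
# QKVParallelLinear, and attention is computed with torch's
# scaled_dot_product_attention.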
class SwinSelfAttention(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
num_heads: int,
window_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of "
f"attention heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = (
window_size
if isinstance(window_size, Iterable)
else (window_size, window_size)
)
self.scale = self.attention_head_size**-0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.relative_position_index = nn.Parameter(
relative_position_index, requires_grad=False
)
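        # Fused query/key/value projection, sharded across tensor-parallel
        # ranks; SwinModel.load_weights maps the separate HF q/k/v weights
        # onto this single parameter.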
self.qkv = QKVParallelLinear(
hidden_size=dim,
head_size=self.attention_head_size,
total_num_heads=self.num_attention_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
return relative_position_bias.unsqueeze(0)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
qkv_output, _ = self.qkv(hidden_states)
query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1)
key_layer = self.transpose_for_scores(key_layer)
value_layer = self.transpose_for_scores(value_layer)
query_layer = self.transpose_for_scores(query_layer)
attention_scores = self._get_rel_pos_bias()
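        # Start from the relative position bias; the shifted-window attention
        # mask (if any) is folded in below, and the combined tensor is passed
        # to SDPA as an additive `attn_mask`.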
if attention_mask is not None:
mask_shape = attention_mask.shape[0]
attention_mask_expanded = attention_mask.view(
1, mask_shape, 1, dim, dim
).expand(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask_expanded.unsqueeze(
1
).unsqueeze(0)
attention_scores = attention_scores.view(
-1, self.num_attention_heads, dim, dim
)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_scores,
dropout_p=0.0,
)
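        # SDPA does not expose attention probabilities, so None is returned
        # even when output_attentions is requested.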
attention_probs = None
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
return outputs
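

# Output projection of the windowed self-attention. Unlike the HF
# implementation, no dropout is applied (vLLM runs inference only).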
class SwinSelfOutput(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = RowParallelLinear(
input_size=dim,
output_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
return hidden_states
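

# Self-attention block: SwinSelfAttention followed by its output projection,
# keeping the HF submodule names (`self`, `output`) so checkpoint weights map
# directly.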
class SwinAttention(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
num_heads: int,
window_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.self = SwinSelfAttention(
config,
dim,
num_heads,
window_size,
quant_config=quant_config,
prefix=f"{prefix}.self",
)
self.output = SwinSelfOutput(
config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
)
self.pruned_heads = set()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
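

# First MLP layer: expands `dim` to `mlp_ratio * dim` with a column-parallel
# projection followed by the configured activation function.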
class SwinIntermediate(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = ColumnParallelLinear(
dim,
int(config.mlp_ratio * dim),
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
self.intermediate_act_fn = get_act_fn(config.hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
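

# Second MLP layer: projects `mlp_ratio * dim` back down to `dim` with a
# row-parallel projection, completing the tensor-parallel MLP.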
class SwinOutput(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.dense = RowParallelLinear(
int(config.mlp_ratio * dim),
dim,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.dense(hidden_states)
return hidden_states
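

# Swin transformer block. Inherits the window partitioning, cyclic shift and
# residual logic from the HF SwinLayer, but swaps in the tensor-parallel
# attention and MLP modules defined above.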
class SwinLayer(HFSwinLayer):
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: int,
num_heads: int,
drop_path_rate: float = 0.0,
shift_size: int = 0,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
drop_path_rate=drop_path_rate,
shift_size=shift_size,
)
self.attention = SwinAttention(
config,
dim,
num_heads,
window_size=self.window_size,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.intermediate = SwinIntermediate(
config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate"
)
self.output = SwinOutput(
config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
)
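

# One Swin stage: `depth` SwinLayer blocks alternating between regular and
# shifted windows, optionally followed by a SwinPatchMerging downsampling
# layer.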
class SwinStage(nn.Module):
def __init__(
self,
config: SwinConfig,
dim: int,
input_resolution: int,
depth: int,
num_heads: int,
drop_path: list[float],
downsample: SwinPatchMerging | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
SwinLayer(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
drop_path_rate=drop_path[layer_idx],
shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, norm_layer=nn.LayerNorm
)
else:
self.downsample = None
self.pointing = False
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(
hidden_states_before_downsampling, input_dimensions
)
else:
output_dimensions = (height, width, height, width)
stage_outputs = (
hidden_states,
hidden_states_before_downsampling,
output_dimensions,
)
if output_attentions:
stage_outputs += layer_outputs[1:]
return stage_outputs
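

# Stack of Swin stages. Channel width doubles and spatial resolution halves at
# each stage boundary; per-block stochastic depth rates are spread linearly
# across all blocks via `dpr`.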
class SwinEncoder(nn.Module):
def __init__(
self,
config: SwinConfig,
grid_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [
x.item()
for x in torch.linspace(
0, config.drop_path_rate, sum(config.depths), device="cpu"
)
]
self.layers = nn.ModuleList(
[
SwinStage(
config=config,
dim=int(config.embed_dim * 2**layer_idx),
input_resolution=(
grid_size[0] // (2**layer_idx),
grid_size[1] // (2**layer_idx),
),
depth=config.depths[layer_idx],
num_heads=config.num_heads[layer_idx],
drop_path=dpr[
sum(config.depths[:layer_idx]) : sum(
config.depths[: layer_idx + 1]
)
],
downsample=SwinPatchMerging
if (layer_idx < self.num_layers - 1)
else None,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(self.num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
hidden_states = layer_outputs[0]
output_dimensions = layer_outputs[2]
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
return hidden_states
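

# Full Swin backbone: HF patch embeddings followed by the tensor-parallel
# encoder. Returns the final hidden states directly rather than an HF-style
# model output object.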
class SwinModel(nn.Module):
    config_class = SwinConfig
def __init__(
self,
config: SwinConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(
config,
self.embeddings.patch_grid,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
)
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
)
return encoder_outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
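        # HF checkpoints store separate query/key/value projections; they are
        # remapped below onto the fused `qkv` parameter using the shard ids
        # expected by QKVParallelLinear's weight loader.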
stacked_params_mapping = [
("qkv", "query", "q"),
("qkv", "key", "k"),
("qkv", "value", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
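

# Minimal usage sketch (comments only, not part of the module): inside vLLM
# this backbone is normally constructed by a multimodal model through the
# model loader, but conceptually the flow is:
#
#     config = SwinConfig()                      # or a checkpoint's vision config
#     model = SwinModel(config, quant_config=None, prefix="vision_model")
#     hidden = model(pixel_values)               # (batch, num_tokens, num_features)
#     model.load_weights(hf_state_dict.items())  # remaps q/k/v -> fused qkv
#
# Note that the tensor-parallel linear layers require vLLM's distributed state
# to be initialized before the module can be instantiated.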