Matthew Bonanni 430dd4d9eb
[Attention] Remove imports from vllm/attention/__init__.py (#29342)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2025-11-26 10:53:15 -07:00

509 lines
18 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Ouro model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Any
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .utils import (
AutoWeightsLoader,
extract_layer_index,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
class OuroMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class OuroAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: dict[str, Any] | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.dual_chunk_attention_config = dual_chunk_attention_config
# Get total_ut_steps from config, default to 4 if not specified
total_ut_steps = getattr(config, "total_ut_steps", 4)
# Use total number of hidden layers instead of hardcoded 24
total_layers = config.num_hidden_layers
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
rope_parameters=config.rope_parameters,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = nn.ModuleList()
for ut_step in range(total_ut_steps):
base_layer_idx = extract_layer_index(prefix)
unique_layer_idx = ut_step * total_layers + base_layer_idx
unique_prefix = prefix.replace(
f"layers.{base_layer_idx}", f"layers.{unique_layer_idx}"
)
self.attn.append(
Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{unique_prefix}.attn",
**{
"layer_idx": unique_layer_idx,
"dual_chunk_attention_config": dual_chunk_attention_config,
}
if dual_chunk_attention_config
else {},
)
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
current_ut: int,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn[current_ut](q, k, v)
output, _ = self.o_proj(attn_output)
return output
class OuroDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = OuroAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = OuroMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm_2 = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
current_ut: int,
residual: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions, hidden_states=hidden_states, current_ut=current_ut
)
hidden_states = self.input_layernorm_2(hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_attention_layernorm_2(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
)
class OuroModel(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = OuroDecoderLayer,
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
if cache_config.sliding_window is not None and hasattr(
config, "max_window_layers"
):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
"is less than `num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
)
)
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
# Use the provided decoder layer type or default to OuroDecoderLayer
decoder_layer_type = decoder_layer_type or OuroDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer_type(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.early_exit_gate = RowParallelLinear(config.hidden_size, 1, bias=True)
self.total_ut_steps = getattr(self.config, "total_ut_steps", 4)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
for current_ut in range(self.total_ut_steps):
residual = None
for layer in self.layers[self.start_layer : self.end_layer]:
hidden_states, residual = layer(
positions, hidden_states, current_ut, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class OuroForCausalLM(nn.Module, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OuroModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)