mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 15:59:43 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as
recommended to the project by the Linux Foundation. These headers provide
a concise way, both human and machine readable, of communicating license
information for each source file. They help avoid any ambiguity about the
license of the code and can also be easily used by tools to help manage
license compliance.
The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that tool do
its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
371 lines
16 KiB
Python
371 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
|
|
# All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# This code is based off the following work:
|
|
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
|
|
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
|
|
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM)
|
|
model compatible with HuggingFace weights."""
|
|
from typing import Iterable, List, Optional, Set, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import StableLmConfig
|
|
|
|
from vllm.attention import Attention, AttentionMetadata
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .interfaces import SupportsPP
|
|
from .utils import (is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory, make_layers,
|
|
maybe_prefix)
|
|
|
|
|
|
class StablelmMLP(nn.Module):
    """Feed-forward block of a StableLM decoder layer.

    The input goes through a merged ``gate_up_proj`` linear layer, the
    ``SiluAndMul`` activation, and finally ``down_proj``, which maps the
    result back to ``config.hidden_size`` features.

    Args:
        config: Model configuration; ``hidden_size`` and
            ``intermediate_size`` are read from it.
        quant_config: Forwarded unchanged to both linear layers.
        prefix: Dotted name of this module, used to derive the names
            passed to the two linear layers.
    """

    def __init__(self,
                 config: StableLmConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        self.config = config
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # One merged projection with two output partitions of equal width.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the merged projection, activation and down projection.

        Both linear layers return a 2-tuple; only the first element is
        used here.
        """
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class StablelmAttention(nn.Module):
    """Self-attention block of a StableLM decoder layer.

    Builds a fused QKV projection, a rotary embedding that is applied to
    the query and key tensors, the ``Attention`` layer, and the output
    projection. Head counts from the config are divided by the
    tensor-parallel world size to obtain the per-rank counts stored in
    ``num_heads`` and ``num_key_value_heads``.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``num_attention_heads``, ``num_key_value_heads``,
            ``max_position_embeddings`` and ``rope_theta``, plus the
            optional ``rope_pct`` / ``partial_rotary_factor`` (default 1)
            and ``use_qkv_bias`` (default ``False``).
        cache_config: Forwarded unchanged to ``Attention``.
        quant_config: Forwarded to the linear layers and to ``Attention``.
        prefix: Dotted name of this module, used to derive the names
            passed to the sub-layers.

    Raises:
        AssertionError: If the number of KV heads and the tensor-parallel
            size are not multiples of one another.
        ValueError: If ``head_dim * num_heads * tp_size`` does not
            reproduce ``hidden_size``.
    """

    def __init__(self,
                 config: StableLmConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Per-rank number of query heads.
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        # At least one KV head per rank, even when heads are replicated.
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        # Older configs name the rotary fraction "rope_pct"; fall back to
        # "partial_rotary_factor", then to 1.
        rope_pct = getattr(config, "rope_pct",
                           getattr(config, "partial_rotary_factor", 1))
        self.rotary_ndims = int(self.head_dim * rope_pct)
        self.scaling = self.head_dim**-0.5
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            # Both floor divisions above must have been exact. Report the
            # config-level head count and the TP size: `self.num_heads` is
            # the per-rank value and would be misleading here.
            raise ValueError(
                "hidden_size must be divisible by num_heads, and num_heads "
                "by the tensor parallel size "
                f"(got `hidden_size`: {self.hidden_size}, "
                f"`num_heads`: {self.total_num_heads} and "
                f"`tp_size`: {tp_size}).")

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_ndims,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        The fused QKV output is split along the last dimension into
        chunks of ``q_size``, ``kv_size`` and ``kv_size``; the rotary
        embedding is applied to the query and key chunks before the
        attention call, and the result goes through ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class StablelmDecoderLayer(nn.Module):
    """One StableLM decoder layer.

    Applies ``LayerNorm`` then self-attention with a residual add,
    followed by ``LayerNorm`` then the MLP with a second residual add.

    Args:
        config: Model configuration; also supplies the LayerNorm epsilon
            via ``norm_eps`` or ``layer_norm_eps`` (default ``1e-05``).
        cache_config: Forwarded to ``StablelmAttention``.
        quant_config: Forwarded to the attention and MLP sub-modules.
        prefix: Dotted name of this layer, used to derive sub-module
            names.
    """

    def __init__(
        self,
        config: StableLmConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # "norm_eps" wins when present; otherwise "layer_norm_eps";
        # otherwise the literal default.
        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        norm_eps = getattr(config, "norm_eps", fallback_eps)
        width = config.hidden_size
        self.input_layernorm = nn.LayerNorm(width, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Returns:
            A pair ``(hidden_states, residual)``: the layer output, and
            the tensor that was fed into the MLP branch (the output of
            the attention residual add).
        """
        # Attention branch with residual connection.
        normed_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out

        # MLP branch with residual connection.
        residual = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = residual + mlp_out

        return hidden_states, residual
|
|
|
|
|
|
class StableLMEpochModel(nn.Module):
    """StableLM backbone: token embedding, decoder layers, final norm.

    Args:
        vllm_config: Supplies the HF model config, the cache config and
            the quantization config.
        prefix: Dotted name of this module, used to derive sub-module
            names.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        # NOTE(review): the parameter stays named `prefix` in case
        # make_layers passes it by keyword — confirm against make_layers.
        def build_layer(prefix: str) -> StablelmDecoderLayer:
            return StablelmDecoderLayer(config,
                                        cache_config,
                                        quant_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        # "norm_eps" wins when present; otherwise "layer_norm_eps";
        # otherwise the literal default.
        fallback_eps = getattr(config, "layer_norm_eps", 1e-05)
        norm_eps = getattr(config, "norm_eps", fallback_eps)
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up ``input_ids`` in the token embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline rank.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``; on other ranks
        it is ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` when this is
            not the last pipeline rank; otherwise the hidden states after
            the final ``LayerNorm``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)

        # kv_caches is indexed relative to this rank's first layer.
        layer_ids = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_ids):
            decoder_layer = self.layers[layer_idx]
            hidden_states, _ = decoder_layer(
                positions,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.norm(hidden_states)
|
|
|
|
|
|
class StablelmForCausalLM(nn.Module, SupportsPP):
    """StableLM backbone plus language-model head, logits and sampling.

    Args:
        vllm_config: Supplies the HF model config and the quantization
            config; also forwarded whole to ``StableLMEpochModel``.
        prefix: Dotted name of this module, used to derive sub-module
            names.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        # Build the name with maybe_prefix, consistent with `model` above;
        # plain f-string formatting would yield ".lm_head" for the default
        # empty prefix.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by passing ``lm_head`` and the hidden states to
        the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` are
        redirected to the fused ``qkv_proj`` parameter, and
        ``gate_proj``/``up_proj`` to ``gate_up_proj``, with the shard id
        passed to the parameter's ``weight_loader``. All other tensors
        are loaded with the parameter's ``weight_loader`` if it has one,
        else with ``default_weight_loader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names (after any renaming) that were
            loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # (`continue` here moves on to the next mapping entry; the
                # renamed tensor is then rejected again in the else branch.)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|