mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:45:00 +08:00
[Model] Add support for Qwen2MoeModel (#3346)
This commit is contained in:
parent
14ccd94c89
commit
d6ea427f04
@ -89,6 +89,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
|||||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
||||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||||
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
|
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
|
||||||
|
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
||||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
||||||
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
||||||
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
|
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
|
||||||
|
|||||||
@ -119,6 +119,10 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
- Qwen2
|
- Qwen2
|
||||||
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
|
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
|
||||||
- ✅︎
|
- ✅︎
|
||||||
|
* - :code:`Qwen2MoeForCausalLM`
|
||||||
|
- Qwen2MoE
|
||||||
|
- :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.
|
||||||
|
-
|
||||||
* - :code:`StableLmForCausalLM`
|
* - :code:`StableLmForCausalLM`
|
||||||
- StableLM
|
- StableLM
|
||||||
- :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
|
- :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
|
||||||
|
|||||||
@ -47,6 +47,7 @@ _MODELS = {
|
|||||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||||
|
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
||||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
|
|||||||
457
vllm/model_executor/models/qwen2_moe.py
Normal file
457
vllm/model_executor/models/qwen2_moe.py
Normal file
@ -0,0 +1,457 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
|
||||||
|
# Copyright 2024 The Qwen team.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
|
tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||||
|
hf_model_weights_iterator)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
reduce_results: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size, [intermediate_size] * 2,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method)
|
||||||
|
self.down_proj = RowParallelLinear(intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method,
|
||||||
|
reduce_results=reduce_results)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.rank = get_tensor_model_parallel_rank()
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.n_routed_experts = config.num_experts
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
if self.tp_size > self.n_routed_experts:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tensor parallel size {self.tp_size} is greater than "
|
||||||
|
f"the number of experts {self.n_routed_experts}.")
|
||||||
|
|
||||||
|
self.experts = nn.ModuleList([
|
||||||
|
Qwen2MoeMLP(hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.moe_intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
linear_method=linear_method,
|
||||||
|
reduce_results=False)
|
||||||
|
for idx in range(self.n_routed_experts)
|
||||||
|
])
|
||||||
|
self.pack_params()
|
||||||
|
|
||||||
|
self.gate = ReplicatedLinear(config.hidden_size,
|
||||||
|
self.n_routed_experts,
|
||||||
|
bias=False,
|
||||||
|
linear_method=None)
|
||||||
|
if config.shared_expert_intermediate_size > 0:
|
||||||
|
self.shared_expert = Qwen2MoeMLP(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.shared_expert_intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
linear_method=linear_method,
|
||||||
|
reduce_results=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.shared_expert = None
|
||||||
|
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
|
||||||
|
1,
|
||||||
|
bias=False)
|
||||||
|
|
||||||
|
def pack_params(self):
|
||||||
|
w1 = []
|
||||||
|
w2 = []
|
||||||
|
for expert in self.experts:
|
||||||
|
w1.append(expert.gate_up_proj.weight)
|
||||||
|
w2.append(expert.down_proj.weight)
|
||||||
|
self.w1 = torch._utils._flatten_dense_tensors(w1)
|
||||||
|
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
|
||||||
|
for data, param in zip(w1s, w1):
|
||||||
|
param.data = data
|
||||||
|
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
|
||||||
|
|
||||||
|
self.w2 = torch._utils._flatten_dense_tensors(w2)
|
||||||
|
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
|
||||||
|
for data, param in zip(w2s, w2):
|
||||||
|
param.data = data
|
||||||
|
|
||||||
|
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
shared_output = None
|
||||||
|
if self.shared_expert is not None:
|
||||||
|
shared_output = self.shared_expert(hidden_states)
|
||||||
|
if self.shared_expert_gate is not None:
|
||||||
|
shared_output = F.sigmoid(
|
||||||
|
self.shared_expert_gate(hidden_states)) * shared_output
|
||||||
|
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
final_hidden_states = fused_moe(hidden_states,
|
||||||
|
self.w1,
|
||||||
|
self.w2,
|
||||||
|
router_logits,
|
||||||
|
self.top_k,
|
||||||
|
renormalize=self.config.norm_topk_prob,
|
||||||
|
inplace=True)
|
||||||
|
|
||||||
|
if shared_output is not None:
|
||||||
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
|
final_hidden_states)
|
||||||
|
|
||||||
|
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=True,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
layer_idx: int,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
self.self_attn = Qwen2MoeAttention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
if (config.num_experts is not None
|
||||||
|
and (layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||||
|
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
|
||||||
|
linear_method=linear_method)
|
||||||
|
else:
|
||||||
|
self.mlp = Qwen2MoeMLP(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
linear_method=linear_method,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
Qwen2MoeDecoderLayer(config,
|
||||||
|
layer_idx,
|
||||||
|
linear_method=linear_method)
|
||||||
|
for layer_idx in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
residual = None
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states, residual = layer(positions, hidden_states,
|
||||||
|
kv_caches[i], attn_metadata,
|
||||||
|
residual)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MoeForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.linear_method = linear_method
|
||||||
|
self.model = Qwen2MoeModel(config, linear_method)
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
|
attn_metadata)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: Optional[torch.Tensor],
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
revision: Optional[str] = None):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path,
|
||||||
|
cache_dir,
|
||||||
|
load_format,
|
||||||
|
revision,
|
||||||
|
fall_back_to_pt=False):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Skip experts that are not assigned to this worker.
|
||||||
|
if (("mlp.experts." in name or "mlp.shared_expert." in name)
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Skip experts that are not assigned to this worker.
|
||||||
|
if (("mlp.experts." in name or "mlp.shared_expert." in name)
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
Loading…
x
Reference in New Issue
Block a user