mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-08 06:31:49 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as
recommended to the project by the Linux Foundation. These headers provide a
concise, human- and machine-readable way to communicate the license of each
source file. They remove any ambiguity about the license of the code and can
easily be used by tools that help manage license compliance.
The Linux Foundation runs license scans against the codebase to help ensure
that we comply with the licenses of the code we use, including dependencies.
Having these headers in place helps that tool do its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
633 lines · 26 KiB · Python
# SPDX-License-Identifier: Apache-2.0
|
|
"""Inference-only Jamba model."""
|
|
from typing import Iterable, List, Optional, Set, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import JambaConfig
|
|
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.attention.layer import Attention
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
from vllm.distributed.parallel_state import get_pp_group
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
|
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
|
MambaCacheParams)
|
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
|
from vllm.utils import LayerBlockType
|
|
|
|
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
|
from .utils import (is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory, make_layers,
|
|
maybe_prefix)
|
|
|
|
# Type alias for one cache entry made of two tensors (presumably the key and
# value caches of an attention layer — confirm). It annotates the
# ``kv_caches`` argument of ``JambaForCausalLM.forward``.
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
|
class JambaMoE(nn.Module):
    """Mixture-of-experts feed-forward block used by Jamba decoder layers.

    The flattened hidden states and a set of router logits are handed to a
    ``FusedMoE`` configured with ``top_k=self.top_k``, ``renormalize=False``
    and ``reduce_results=True``. With more than one expert, the logits come
    from a replicated, non-quantized linear router. With a single expert,
    the router is omitted and every token gets a constant logit of 1. This
    is the path :class:`JambaMLP` uses for dense layers.

    Args:
        config: Hugging Face Jamba config. Supplies the expert count, top-k,
            hidden size and intermediate size.
        num_experts: Overrides ``config.num_experts`` when truthy.
        top_k: Overrides ``config.num_experts_per_tok`` when truthy.
        params_dtype: Parameter dtype for both the router and the experts.
        tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
        quant_config: Quantization config, applied to the experts only.
    """

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Falsy overrides (None or 0) fall back to the config values.
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        needs_router = self.num_total_experts > 1
        if needs_router:
            # The router is always kept un-quantized.
            self.router = ReplicatedLinear(self.hidden_size,
                                           self.num_total_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        moe_options = dict(
            tp_size=tp_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=False,
            use_grouped_topk=False,
            quant_config=quant_config,
        )
        self.experts = FusedMoE(self.num_total_experts, self.top_k,
                                self.hidden_size, self.intermediate_size,
                                **moe_options)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the experts token-wise and restore the input's shape.

        ``hidden_states`` is viewed as ``(-1, hidden_size)`` before routing,
        so any leading dimensions are accepted and restored on output.
        """
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)
        if self.num_total_experts <= 1:
            # No router: the single expert receives a constant logit.
            router_logits = torch.ones((tokens.shape[0], 1),
                                       device=tokens.device,
                                       dtype=tokens.dtype)
        else:
            # router_logits: (batch * sequence_length, n_experts)
            router_logits, _ = self.router(tokens)
        mixed = self.experts(tokens, router_logits)
        return mixed.view(input_shape)
|
|
|
|
|
|
class JambaMLP(JambaMoE):
    """Dense feed-forward block, expressed as a one-expert ``JambaMoE``.

    With ``num_experts=1`` and ``top_k=1`` the parent builds no router and
    routes every token to its single expert.
    """

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        single_expert = dict(num_experts=1, top_k=1)
        super().__init__(config,
                         params_dtype=params_dtype,
                         tp_size=tp_size,
                         quant_config=quant_config,
                         **single_expert)
|
|
|
|
|
|
class JambaMambaDecoderLayer(nn.Module):
    """Jamba decoder layer whose token mixer is a ``MambaMixer``.

    Layout: input RMSNorm -> Mamba mixer -> pre-FFN RMSNorm -> feed-forward.
    The feed-forward is a ``JambaMoE`` when
    ``config.layers_num_experts[layer_idx]`` is greater than 1, and a
    ``JambaMLP`` otherwise. ``cache_config`` and extra keyword arguments are
    accepted but not used here. Presumably this lets both decoder layer
    types be constructed with the same call.
    """

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 is_lora_enabled: Optional[bool] = False,
                 **kwargs) -> None:
        super().__init__()
        self.config = config
        self.is_lora_enabled = is_lora_enabled

        d_model = config.hidden_size
        self.mamba = MambaMixer(
            hidden_size=d_model,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=config.mamba_expand * d_model,
            time_step_rank=config.mamba_dt_rank,
            use_conv_bias=config.mamba_conv_bias,
            use_bias=config.mamba_proj_bias,
            use_rms_norm=True,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.hidden_act,
            is_lora_enabled=self.is_lora_enabled,
        )

        uses_moe = config.layers_num_experts[layer_idx] > 1
        ffn_cls = JambaMoE if uses_moe else JambaMLP
        self.feed_forward = ffn_cls(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(d_model, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(d_model, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        mamba_cache_params: MambaCacheParams,
        **kwargs,
    ):
        """Run the Mamba mixer and the feed-forward sub-block.

        When ``residual`` is None (the first layer), the un-normalized input
        becomes the residual. Otherwise the two-argument RMSNorm form is
        used, and it returns the updated ``(hidden_states, residual)`` pair.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        mixed = self.mamba(hidden_states, attn_metadata, mamba_cache_params)

        # Fully connected sub-block.
        ffn_input, residual = self.pre_ff_layernorm(mixed, residual)
        return self.feed_forward(ffn_input), residual
|
|
|
|
|
|
class JambaAttentionDecoderLayer(nn.Module):
    """Jamba decoder layer whose token mixer is multi-head self-attention.

    Layout: input RMSNorm -> fused QKV projection -> attention -> output
    projection -> pre-FFN RMSNorm -> feed-forward. The feed-forward is a
    ``JambaMoE`` when ``config.layers_num_experts[layer_idx]`` is greater
    than 1, and a ``JambaMLP`` otherwise. No positional transform is applied
    to q/k in this layer. ``positions`` is accepted but unused.
    """

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 **kwargs) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        n_heads = config.num_attention_heads
        self.total_num_heads = n_heads
        assert n_heads % tp_size == 0
        self.num_heads = n_heads // tp_size

        n_kv_heads = config.num_key_value_heads
        self.total_num_kv_heads = n_kv_heads
        if n_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: each KV head is replicated
            # across several tensor-parallel ranks.
            assert tp_size % n_kv_heads == 0
        else:
            # At least one KV head per TP rank: KV heads are partitioned
            # across the tensor-parallel ranks.
            assert n_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, n_kv_heads // tp_size)

        self.head_dim = config.hidden_size // n_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            n_heads,
            n_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(n_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

        uses_moe = config.layers_num_experts[layer_idx] > 1
        ffn_cls = JambaMoE if uses_moe else JambaMLP
        self.feed_forward = ffn_cls(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Project to q/k/v, attend, and project back to ``hidden_size``."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        attended = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _bias = self.o_proj(attended)
        return projected

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Run self-attention and the feed-forward sub-block.

        When ``residual`` is None (the first layer), the un-normalized input
        becomes the residual. Otherwise the two-argument RMSNorm form
        returns the updated ``(hidden_states, residual)`` pair.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully connected sub-block.
        ffn_input, residual = self.pre_ff_layernorm(attn_out, residual)
        return self.feed_forward(ffn_input), residual
|
|
|
|
|
|
# Maps each value found in ``config.layers_block_type`` to the decoder layer
# class that implements that block type.
ALL_DECODER_LAYER_TYPES = dict(
    attention=JambaAttentionDecoderLayer,
    mamba=JambaMambaDecoderLayer,
)
|
|
|
|
|
|
class JambaModel(nn.Module):
    """Jamba decoder stack: token embeddings, hybrid layers, final norm.

    Each layer's type is read from ``config.layers_block_type`` and mapped
    to a class through ``ALL_DECODER_LAYER_TYPES``. Only the layers in
    ``[start_layer, end_layer)`` returned by ``make_layers`` are executed
    on this pipeline-parallel rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id

        # LoRA can add extra vocabulary entries. Reserve room for each
        # adapter slot (at least one) on top of the base vocabulary.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
        )

        layer_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}

        # ``make_layers`` calls this with ``prefix=...`` as a keyword, so the
        # parameter name must stay ``prefix``. The prefix ends in ".<index>".
        def build_layer(prefix: str):
            idx = int(prefix.rsplit(".", 1)[1])
            layer_cls = ALL_DECODER_LAYER_TYPES[
                hf_config.layers_block_type[idx]]
            return layer_cls(hf_config,
                             idx,
                             cache_config,
                             quant_config=quant_config,
                             prefix=prefix,
                             **layer_kwargs)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

        self.final_layernorm = RMSNorm(hf_config.hidden_size,
                                       eps=hf_config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Returns the final-normalized hidden states on the last pipeline
        rank. On every other rank it returns an ``IntermediateTensors``
        holding ``hidden_states`` and ``residual`` for the next rank.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        # Attention and Mamba layers draw from separate cache sequences.
        # Each layer takes the next slot among layers of its own type, in
        # layer order, starting from this rank's first layer.
        next_kv_slot = 0
        next_mamba_slot = 0
        for layer_idx in range(self.start_layer, self.end_layer):
            layer = self.layers[layer_idx]
            layer_kv_cache = None
            layer_mamba_params = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                layer_kv_cache = kv_caches[next_kv_slot]
                next_kv_slot += 1
            elif isinstance(layer, JambaMambaDecoderLayer):
                layer_mamba_params = mamba_cache_params.at_layer_idx(
                    next_mamba_slot)
                next_mamba_slot += 1

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=layer_kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                mamba_cache_params=layer_mamba_params)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                       IsHybrid):
    """Jamba causal LM: ``JambaModel`` backbone, LM head, logits and sampler.

    This class also owns the Mamba state cache (``self.mamba_cache``). The
    cache stays ``None`` until the first ``forward`` call allocates it.
    ``load_weights`` maps checkpoint weight names onto the model's fused
    parameters.
    """
    # Fused module name -> checkpoint sub-module names packed into it.
    # NOTE(review): presumably consumed by the LoRA machinery to handle fused
    # layers — confirm against SupportsLoRA.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "in_proj": ["in_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, logits processor and sampler.

        Raises:
            AssertionError: if prefix caching is enabled (unless Python
                assertions are disabled).
        """
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Jamba currently does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        # LoRA may add extra vocabulary entries on top of the base vocab.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Mamba state cache manager. It is allocated lazily on the first
        # forward() call and then kept on the instance for later steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        # Number of sequences the Mamba cache is sized for. If a scheduler
        # config exists and eager mode is not enforced, use max_num_seqs
        # (passed through pad_for_cudagraph) capped at max_capture_size.
        if self.scheduler_config is not None and \
            not self.model_config.enforce_eager:
            if self.scheduler_config.max_num_seqs > \
                vllm_config.compilation_config.max_capture_size:
                self.max_batch_size = \
                    vllm_config.compilation_config.max_capture_size
            else:
                self.max_batch_size = vllm_config.pad_for_cudagraph(
                    self.scheduler_config.max_num_seqs)
        else:
            # NOTE(review): fallback size. The reason for the "+ 2" is not
            # documented here — confirm with the MambaCacheManager owners.
            self.max_batch_size = 8192 + 2

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone, allocating the Mamba cache on first use.

        Extra ``kwargs`` are forwarded to
        ``MambaCacheManager.current_run_tensors``. Returns whatever
        ``JambaModel.forward`` returns: hidden states, or
        ``IntermediateTensors`` on non-last pipeline ranks.
        """
        if self.mamba_cache is None:
            # The cache uses the LM head's weight dtype and is sized by the
            # number of Mamba layers this rank holds and max_batch_size.
            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                self.vllm_config.parallel_config, LayerBlockType.mamba)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers,
                self.max_batch_size, *self._get_mamba_cache_shape())
        (
            mamba_cache_tensors,
            state_indices_tensor,
        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
                                                 **kwargs)
        # mamba_cache_tensors[0] / [1] follow the order returned by
        # _get_mamba_cache_shape: conv state, then temporal (SSM) state.
        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
                                              mamba_cache_tensors[1],
                                              state_indices_tensor)
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, mamba_cache_params,
                                   intermediate_tensors, inputs_embeds)
        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the Mamba cache manager.

        ``self.mamba_cache`` is None until ``forward`` has run once, so
        calling this earlier fails with an AttributeError on None.
        """
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the Mamba cache manager (requires a prior forward)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return per-layer ``(conv_state_shape, temporal_state_shape)``.

        Both shapes shard the Mamba inner dimension
        (``mamba_expand * hidden_size``) across the tensor-parallel ranks.
        """
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        # Convolution state keeps mamba_d_conv - 1 past positions per channel.
        conv_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_conv - 1,
        )
        temporal_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, renaming them onto this model's params.

        Each checkpoint name goes through these steps:
          * ``rotary_emb.inv_freq`` entries are skipped.
          * ``A_log`` is renamed to ``A``. Any numeric conversion presumably
            happens in the parameter's weight loader (confirm in
            MambaMixer).
          * ``.self_attn`` is dropped, so attention projections live
            directly on the decoder layer.
          * ``feed_forward`` weights of dense (non-MoE) layers are remapped
            onto ``feed_forward.experts.0``.
        The renamed weight is then loaded as a q/k/v shard of ``qkv_proj``,
        as an expert shard of a ``FusedMoE``, or directly.

        Returns:
            The set of (renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "A_log" in name:
                name = name.replace("A_log", "A")

            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            if "feed_forward" in name and not _is_moe_layer(name):
                # Map dense MLP layers onto the single expert with ID 0.
                name = name.replace("feed_forward", "feed_forward.experts.0")

            # NOTE: the skip ``continue``s below only continue THIS inner
            # loop, after ``name`` has already been rewritten. Once the
            # mapping is exhausted, control falls into the ``else:`` branch,
            # which repeats the bias / pipeline-parallel checks on the
            # rewritten name and skips it there.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if 'experts' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.

                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: try the FusedMoE expert mapping.
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    # Checked against the name before the expert rename.
                    if is_pp_missing_parameter(name, self):
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Plain parameter: load it directly.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
def _is_moe_layer(name: str):
|
|
return any(
|
|
[experts_name in name for experts_name in [
|
|
"experts",
|
|
"router",
|
|
]])
|
|
|
|
|
|
class JambaForSequenceClassification(JambaForCausalLM):
    """Jamba with a linear scoring head for sequence classification/reward.

    The head maps ``hidden_size`` to ``config.num_labels`` scores. Pooling is
    built with ``Pooler.from_config_with_defaults`` using LAST pooling,
    ``normalize=False`` and ``softmax=False`` as defaults. These can
    presumably be overridden by the model's pooler config (confirm in
    Pooler).
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        hf_config = vllm_config.model_config.hf_config
        # The score head has a bias only if the config asks for one.
        self.score = nn.Linear(hf_config.hidden_size,
                               hf_config.num_labels,
                               bias=getattr(hf_config, 'score_bias', False))

        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=False)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Score every position in float32, then pool the scores."""
        scores = self.score(hidden_states.float())
        return self._pooler(scores, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights via the causal-LM loader, then upcast the head."""
        # TODO: The reward weights themselves have float32 accuracy data, we
        # would like to load them in fp32 to get that extra precision.
        super().load_weights(weights)
        # pooler() feeds float32 inputs, so keep the head in float32 too.
        self.score = self.score.float()
|