mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 21:26:11 +08:00
[Model] Support Solar Model (#8386)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent
d65798f78c
commit
e18749ff09
@ -179,6 +179,10 @@ Decoder-only Language Models
|
|||||||
- Starcoder2
|
- Starcoder2
|
||||||
- :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc.
|
- :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc.
|
||||||
-
|
-
|
||||||
|
* - :code:`SolarForCausalLM`
|
||||||
|
- EXAONE-3
|
||||||
|
- :code:`upstage/solar-pro-preview-instruct`, etc.
|
||||||
|
-
|
||||||
* - :code:`XverseForCausalLM`
|
* - :code:`XverseForCausalLM`
|
||||||
- Xverse
|
- Xverse
|
||||||
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
|
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
|
||||||
|
|||||||
@ -60,6 +60,7 @@ _GENERATION_MODELS = {
|
|||||||
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
|
||||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||||
|
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||||
|
|||||||
580
vllm/model_executor/models/solar.py
Normal file
580
vllm/model_executor/models/solar.py
Normal file
@ -0,0 +1,580 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Solar model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||||
|
get_compressed_tensors_cache_scale)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsLoRA
|
||||||
|
from vllm.model_executor.models.utils import (PPMissingLayer,
|
||||||
|
is_pp_missing_parameter,
|
||||||
|
make_layers)
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
|
|
||||||
|
class SolarMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
bias: bool = False,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
input_size=hidden_size,
|
||||||
|
output_sizes=[intermediate_size] * 2,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.gate_up_proj",
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
input_size=intermediate_size,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SolarAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
bias: bool = False,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||||
|
self.head_dim = getattr(config, "head_dim",
|
||||||
|
self.hidden_size // self.total_num_heads)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.total_num_heads,
|
||||||
|
total_num_kv_heads=self.total_num_kv_heads,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
input_size=self.total_num_heads * self.head_dim,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class SolarDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
|
||||||
|
if rope_scaling is not None and getattr(
|
||||||
|
config, "original_max_position_embeddings", None):
|
||||||
|
rope_scaling["original_max_position_embeddings"] \
|
||||||
|
= config.original_max_position_embeddings
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||||
|
8192)
|
||||||
|
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||||
|
# Support internlm/internlm-7b with bias
|
||||||
|
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||||
|
config, "bias", False)
|
||||||
|
self.self_attn = SolarAttention(
|
||||||
|
config=config,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||||
|
config.num_attention_heads),
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
quant_config=quant_config,
|
||||||
|
bias=attention_bias,
|
||||||
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
)
|
||||||
|
self.mlp = SolarMLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
bias=getattr(config, "mlp_bias", False),
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
residual: Optional[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Self Attention
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
|
hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class SolarModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||||
|
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
|
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||||
|
and get_pp_group().is_last_rank):
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embed_tokens = PPMissingLayer()
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
lambda prefix: SolarDecoderLayer(
|
||||||
|
config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
|
),
|
||||||
|
prefix=f"{prefix}.layers",
|
||||||
|
)
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
else:
|
||||||
|
self.norm = PPMissingLayer()
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors],
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states = self.get_input_embeddings(input_ids)
|
||||||
|
residual = None
|
||||||
|
else:
|
||||||
|
assert intermediate_tensors is not None
|
||||||
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
bskcn_h_1 = None
|
||||||
|
bskcn_h_2 = None
|
||||||
|
bskcn_r_1 = None
|
||||||
|
bskcn_r_2 = None
|
||||||
|
bskcn_tv = (self.config.bskcn_tv[0]
|
||||||
|
if self.training else self.config.bskcn_tv[1])
|
||||||
|
|
||||||
|
for i in range(self.start_layer, self.end_layer):
|
||||||
|
if i in self.config.bskcn_1:
|
||||||
|
bskcn_h_1 = hidden_states.clone()
|
||||||
|
bskcn_r_1 = residual.clone()
|
||||||
|
if i in self.config.bskcn_2:
|
||||||
|
bskcn_h_2 = hidden_states.clone()
|
||||||
|
bskcn_r_2 = residual.clone()
|
||||||
|
if i in self.config.bskcn_3:
|
||||||
|
hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * (
|
||||||
|
1 - bskcn_tv)
|
||||||
|
residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv)
|
||||||
|
if i in self.config.bskcn_4:
|
||||||
|
hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * (
|
||||||
|
1 - bskcn_tv)
|
||||||
|
residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv)
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i - self.start_layer],
|
||||||
|
attn_metadata,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"residual": residual
|
||||||
|
})
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SolarForCausalLM(nn.Module, SupportsLoRA):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_up_proj",
|
||||||
|
"down_proj",
|
||||||
|
"embed_tokens",
|
||||||
|
"lm_head",
|
||||||
|
]
|
||||||
|
embedding_modules = {
|
||||||
|
"embed_tokens": "input_embeddings",
|
||||||
|
"lm_head": "output_embeddings",
|
||||||
|
}
|
||||||
|
embedding_padding_modules = ["lm_head"]
|
||||||
|
bitsandbytes_stacked_params_mapping = {
|
||||||
|
# shard_name, weight_name, index
|
||||||
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
"k_proj": ("qkv_proj", 1),
|
||||||
|
"v_proj": ("qkv_proj", 2),
|
||||||
|
"gate_proj": ("gate_up_proj", 0),
|
||||||
|
"up_proj": ("gate_up_proj", 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
|
||||||
|
self.model = SolarModel(
|
||||||
|
config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
lora_config=lora_config,
|
||||||
|
prefix="model",
|
||||||
|
)
|
||||||
|
if get_pp_group().is_last_rank:
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
if lora_config:
|
||||||
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||||
|
# We need bigger padding if using lora for kernel
|
||||||
|
# compatibility
|
||||||
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head.weight = self.model.embed_tokens.weight
|
||||||
|
|
||||||
|
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size,
|
||||||
|
logit_scale)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
else:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
model_output = self.model(input_ids, positions, kv_caches,
|
||||||
|
attn_metadata, intermediate_tensors)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def make_empty_intermediate_tensors(
|
||||||
|
self, batch_size: int, dtype: torch.dtype,
|
||||||
|
device: torch.device) -> IntermediateTensors:
|
||||||
|
return IntermediateTensors({
|
||||||
|
"hidden_states":
|
||||||
|
torch.zeros(
|
||||||
|
(batch_size, self.config.hidden_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
"residual":
|
||||||
|
torch.zeros(
|
||||||
|
(batch_size, self.config.hidden_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
(".qkv_proj", ".q_proj", "q"),
|
||||||
|
(".qkv_proj", ".k_proj", "k"),
|
||||||
|
(".qkv_proj", ".v_proj", "v"),
|
||||||
|
(".gate_up_proj", ".gate_proj", 0),
|
||||||
|
(".gate_up_proj", ".up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if ("rotary_emb.cos_cached" in name
|
||||||
|
or "rotary_emb.sin_cached" in name):
|
||||||
|
# Models trained using ColossalAI may include these tensors in
|
||||||
|
# the checkpoint. Skip them.
|
||||||
|
continue
|
||||||
|
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||||
|
# Loading kv cache scales for compressed-tensors quantization
|
||||||
|
param = params_dict[scale_name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
loaded_weight = loaded_weight[0]
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
continue
|
||||||
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
# If this function is called, it should always initialize KV cache scale
|
||||||
|
# factors (or else raise an exception). Thus, handled exceptions should
|
||||||
|
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||||
|
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||||
|
quantization_param_path,
|
||||||
|
tp_rank,
|
||||||
|
tp_size,
|
||||||
|
self.config.num_hidden_layers,
|
||||||
|
self.config.__class__.model_type,
|
||||||
|
):
|
||||||
|
if not isinstance(self.model.layers[layer_idx], nn.Identity):
|
||||||
|
layer_self_attn = self.model.layers[layer_idx].self_attn
|
||||||
|
|
||||||
|
if is_hip():
|
||||||
|
# The scaling factor convention we are assuming is
|
||||||
|
# quantized_value * scaling_factor ~= true_value
|
||||||
|
# which is consistent with the practice of setting
|
||||||
|
# scaling_factor = tensor_amax / FPtype_max
|
||||||
|
scaling_factor *= 2
|
||||||
|
if hasattr(layer_self_attn, "kv_scale"):
|
||||||
|
layer_self_attn.attn._kv_scale = scaling_factor
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Self attention has no KV cache scaling "
|
||||||
|
"factor attribute!")
|
||||||
@ -24,7 +24,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
|||||||
JAISConfig, MedusaConfig,
|
JAISConfig, MedusaConfig,
|
||||||
MLPSpeculatorConfig, MPTConfig,
|
MLPSpeculatorConfig, MPTConfig,
|
||||||
NemotronConfig, RWConfig,
|
NemotronConfig, RWConfig,
|
||||||
UltravoxConfig)
|
SolarConfig, UltravoxConfig)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
|
|
||||||
@ -50,6 +50,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||||||
"exaone": ExaoneConfig,
|
"exaone": ExaoneConfig,
|
||||||
"internvl_chat": InternVLChatConfig,
|
"internvl_chat": InternVLChatConfig,
|
||||||
"nemotron": NemotronConfig,
|
"nemotron": NemotronConfig,
|
||||||
|
"solar": SolarConfig,
|
||||||
"ultravox": UltravoxConfig,
|
"ultravox": UltravoxConfig,
|
||||||
# Granite can be removed from here once we have upgraded to
|
# Granite can be removed from here once we have upgraded to
|
||||||
# transformers 4.45+
|
# transformers 4.45+
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from vllm.transformers_utils.configs.medusa import MedusaConfig
|
|||||||
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
||||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||||
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
from vllm.transformers_utils.configs.nemotron import NemotronConfig
|
||||||
|
from vllm.transformers_utils.configs.solar import SolarConfig
|
||||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"ExaoneConfig",
|
"ExaoneConfig",
|
||||||
"MLPSpeculatorConfig",
|
"MLPSpeculatorConfig",
|
||||||
"NemotronConfig",
|
"NemotronConfig",
|
||||||
|
"SolarConfig",
|
||||||
"UltravoxConfig",
|
"UltravoxConfig",
|
||||||
# Granite can be removed from here once we have upgraded to
|
# Granite can be removed from here once we have upgraded to
|
||||||
# transformers 4.45+
|
# transformers 4.45+
|
||||||
|
|||||||
245
vllm/transformers_utils/configs/solar.py
Normal file
245
vllm/transformers_utils/configs/solar.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Solar model configuration"""
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SolarConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store
|
||||||
|
the configuration of a [`SolarModel`].
|
||||||
|
It is used to instantiate an LLaMA model
|
||||||
|
according to the specified arguments,
|
||||||
|
defining the model architecture.
|
||||||
|
Instantiating a configuration with the
|
||||||
|
defaults will yield a similar
|
||||||
|
configuration to that of the LLaMA-7B.
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`]
|
||||||
|
and can be used to control the model outputs.
|
||||||
|
Read the documentation from [`PretrainedConfig`] for more information.
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 32000):
|
||||||
|
Vocabulary size of the LLaMA model.
|
||||||
|
Defines the number of different tokens
|
||||||
|
that can be represented by the `inputs_ids`
|
||||||
|
passed when calling [`SolarModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer
|
||||||
|
in the Transformer decoder.
|
||||||
|
num_key_value_heads (`int`, *optional*):
|
||||||
|
This is the number of key_value heads that
|
||||||
|
should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`,
|
||||||
|
the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model
|
||||||
|
will use Multi Query Attention (MQA)
|
||||||
|
otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint,
|
||||||
|
each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group.
|
||||||
|
For more details checkout [this paper]
|
||||||
|
(https://arxiv.org/pdf/2305.13245.pdf).
|
||||||
|
If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string)
|
||||||
|
in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
Solar 1 supports up to 2048 tokens,
|
||||||
|
Solar 2 up to 4096, CodeSolar up to 16384.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of
|
||||||
|
the truncated_normal_initializer for initializing
|
||||||
|
all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return
|
||||||
|
the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
Padding token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
Beginning of stream token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
End of stream token id.
|
||||||
|
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||||
|
Experimental feature. Tensor parallelism rank
|
||||||
|
used during pretraining.
|
||||||
|
Please refer to [this
|
||||||
|
document](https://huggingface.co/docs/
|
||||||
|
transformers/main/
|
||||||
|
perf_train_gpu_many#tensor-parallelism)
|
||||||
|
to understand more about it. This value is
|
||||||
|
necessary to ensure exact reproducibility
|
||||||
|
of the pretraining results.
|
||||||
|
Please refer to [this
|
||||||
|
issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for
|
||||||
|
the RoPE embeddings.
|
||||||
|
Currently supports two scaling
|
||||||
|
strategies: linear and dynamic.
|
||||||
|
Their scaling factor must be a float greater than 1.
|
||||||
|
The expected format is
|
||||||
|
`{"type": strategy name, "factor": scaling factor}`.
|
||||||
|
When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum.
|
||||||
|
See the following thread for more information on how
|
||||||
|
these scaling strategies behave:
|
||||||
|
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/
|
||||||
|
dynamically_scaled_rope_further_increases/. This is an
|
||||||
|
experimental feature, subject to breaking
|
||||||
|
API changes in future versions.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value
|
||||||
|
and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in up_proj, down_proj and gate_proj
|
||||||
|
layers in the MLP layers.
|
||||||
|
sliding_window (`int`, *optional*, defaults to 2047):
|
||||||
|
Sliding window attention window size. If not specified,
|
||||||
|
will default to `2047`.
|
||||||
|
```python
|
||||||
|
>>> from transformers import SolarModel, SolarConfig
|
||||||
|
>>> # Initializing a Solar-pro style configuration
|
||||||
|
>>> configuration = SolarConfig()
|
||||||
|
>>> # Initializing a model from the Solar-pro style configuration
|
||||||
|
>>> model = SolarModel(configuration)
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "solar"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
mlp_bias=False,
|
||||||
|
sliding_window=2047,
|
||||||
|
bskcn_1=None,
|
||||||
|
bskcn_2=None,
|
||||||
|
bskcn_3=None,
|
||||||
|
bskcn_4=None,
|
||||||
|
bskcn_tv=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self._rope_scaling_validation()
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.mlp_bias = mlp_bias
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.bskcn_1 = bskcn_1 if bskcn_1 is not None else [12, 20, 32, 44]
|
||||||
|
self.bskcn_2 = bskcn_2 if bskcn_2 is not None else [20, 32]
|
||||||
|
self.bskcn_3 = bskcn_3 if bskcn_3 is not None else [16, 24, 36, 48]
|
||||||
|
self.bskcn_4 = bskcn_4 if bskcn_4 is not None else [28, 40]
|
||||||
|
self.bskcn_tv = bskcn_tv if bskcn_tv is not None else [0.9, 0.8]
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _rope_scaling_validation(self):
|
||||||
|
"""
|
||||||
|
Validate the `rope_scaling` configuration.
|
||||||
|
"""
|
||||||
|
if self.rope_scaling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (not isinstance(self.rope_scaling, dict)
|
||||||
|
or len(self.rope_scaling) != 2):
|
||||||
|
raise ValueError(
|
||||||
|
"`rope_scaling` must be a dictionary with two fields,"
|
||||||
|
" `type` and `factor`, "
|
||||||
|
f"got {self.rope_scaling}")
|
||||||
|
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||||
|
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||||
|
if rope_scaling_type is None or rope_scaling_type not in [
|
||||||
|
"linear",
|
||||||
|
"dynamic",
|
||||||
|
]:
|
||||||
|
raise ValueError(f"`rope_scaling`'s type field must be one of "
|
||||||
|
f"['linear', 'dynamic'], got {rope_scaling_type}")
|
||||||
|
if (rope_scaling_factor is None
|
||||||
|
or not isinstance(rope_scaling_factor, float)
|
||||||
|
or rope_scaling_factor <= 1.0):
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s factor field must be a float > 1,"
|
||||||
|
f" got {rope_scaling_factor}")
|
||||||
Loading…
x
Reference in New Issue
Block a user