mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 16:07:07 +08:00
[New Model]: Snowflake Arctic Embed (Family) (#16649)
This commit is contained in:
parent
686623c5e7
commit
3d3ab3689f
@ -3,24 +3,17 @@
|
||||
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import EmbeddingResponse
|
||||
|
||||
from ...models.embedding.utils import EmbedModelInfo
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
class ModelInfo(NamedTuple):
|
||||
name: str
|
||||
is_matryoshka: bool
|
||||
|
||||
|
||||
MODELS = [
|
||||
ModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
|
||||
ModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
|
||||
EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
|
||||
EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
|
||||
]
|
||||
|
||||
input_texts = [
|
||||
@ -30,7 +23,7 @@ input_texts = [
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
async def test_validating_dimensions(model: ModelInfo):
|
||||
async def test_validating_dimensions(model: EmbedModelInfo):
|
||||
args = [
|
||||
"--task",
|
||||
"embed",
|
||||
|
||||
101
tests/models/embedding/language/test_snowflake_arctic_embed.py
Normal file
101
tests/models/embedding/language/test_snowflake_arctic_embed.py
Normal file
@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Compare the embedding outputs of HF and vLLM models.
|
||||
|
||||
Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from tests.models.embedding.utils import EmbedModelInfo
|
||||
|
||||
from ..utils import check_embeddings_close
|
||||
|
||||
EMBEDDING_PROMPTS = [
|
||||
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
|
||||
'Mexico City of Course!'
|
||||
]
|
||||
|
||||
MODELS = [
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-s",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long",
|
||||
is_matryoshka=False,
|
||||
architecture="NomicBertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l",
|
||||
is_matryoshka=False,
|
||||
architecture="BertModel",
|
||||
enable_test=False),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
|
||||
is_matryoshka=True,
|
||||
architecture="BertModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0",
|
||||
is_matryoshka=True,
|
||||
architecture="XLMRobertaModel",
|
||||
enable_test=True),
|
||||
EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
||||
is_matryoshka=True,
|
||||
architecture="GteModel",
|
||||
enable_test=True),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model_info: EmbedModelInfo,
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
example_prompts = example_prompts + EMBEDDING_PROMPTS
|
||||
|
||||
vllm_extra_kwargs = {
|
||||
"hf_overrides": {
|
||||
"is_matryoshka": model_info.is_matryoshka
|
||||
}
|
||||
}
|
||||
|
||||
with hf_runner(model_info.name, dtype=dtype,
|
||||
is_sentence_transformer=True) as hf_model:
|
||||
hf_outputs = hf_model.encode(example_prompts)
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
task="embed",
|
||||
dtype=dtype,
|
||||
max_model_len=None,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
|
||||
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
|
||||
model_info.is_matryoshka)
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture
|
||||
in vllm_model.model.llm_engine.model_config.architectures)
|
||||
|
||||
vllm_outputs = vllm_model.encode(example_prompts)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -37,3 +38,10 @@ def matryoshka_fy(tensor, dimensions):
|
||||
tensor = tensor[..., :dimensions]
|
||||
tensor = F.normalize(tensor, p=2, dim=1)
|
||||
return tensor
|
||||
|
||||
|
||||
class EmbedModelInfo(NamedTuple):
|
||||
name: str
|
||||
is_matryoshka: bool
|
||||
architecture: str = ""
|
||||
enable_test: bool = True
|
||||
|
||||
@ -247,11 +247,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
||||
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
||||
trust_remote_code=True),
|
||||
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
|
||||
trust_remote_code=True),
|
||||
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
|
||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
||||
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
|
||||
|
||||
@ -354,6 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
|
||||
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
|
||||
"gelu": lambda: GeluAndMul(),
|
||||
"silu": lambda: SiluAndMul(),
|
||||
"gelu_and_mul": lambda: GeluAndMul(),
|
||||
})
|
||||
|
||||
|
||||
|
||||
@ -11,8 +11,10 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||
get_act_fn)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
|
||||
@ -108,6 +110,7 @@ class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -118,6 +121,7 @@ class BertEncoder(nn.Module):
|
||||
BertLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
@ -139,6 +143,7 @@ class BertLayer(nn.Module):
|
||||
config: BertConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -149,19 +154,31 @@ class BertLayer(nn.Module):
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
self.intermediate = BertIntermediate(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
if config.hidden_act in ["silu", "gelu_and_mul"]:
|
||||
self.intermediate = BertGatedIntermediate(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
else:
|
||||
self.intermediate = BertIntermediate(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.intermediate")
|
||||
|
||||
self.output = BertOutput(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
@ -181,6 +198,7 @@ class BertAttention(nn.Module):
|
||||
layer_norm_eps: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@ -190,11 +208,13 @@ class BertAttention(nn.Module):
|
||||
num_attention_heads=num_attention_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
self.output = BertSelfOutput(hidden_size=hidden_size,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output")
|
||||
|
||||
@ -215,6 +235,7 @@ class BertSelfAttention(nn.Module):
|
||||
num_attention_heads: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@ -240,7 +261,7 @@ class BertSelfAttention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=True,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
@ -278,12 +299,13 @@ class BertSelfOutput(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
layer_norm_eps: float,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.dense = RowParallelLinear(input_size=hidden_size,
|
||||
output_size=hidden_size,
|
||||
bias=True,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
@ -301,12 +323,13 @@ class BertIntermediate(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.dense = ColumnParallelLinear(input_size=hidden_size,
|
||||
output_size=intermediate_size,
|
||||
bias=True,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
self.intermediate_act_fn = get_act_fn(hidden_act)
|
||||
@ -317,19 +340,46 @@ class BertIntermediate(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertGatedIntermediate(nn.Module):
|
||||
# for NomciBert and GteModel
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.act_fn = get_act_and_mul_fn(hidden_act)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
gate_up, _ = self.gate_up_proj(hidden_states)
|
||||
hidden_states = self.act_fn(gate_up)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
layer_norm_eps: float,
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.dense = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=True,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.dense")
|
||||
|
||||
@ -343,19 +393,32 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["query", "key", "value"],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding,
|
||||
bias: bool = True,
|
||||
rotary_kwargs: Optional[dict] = None,
|
||||
add_pooling_layer: bool = False):
|
||||
super().__init__()
|
||||
"""
|
||||
For BertModel, all linear layers have bias.
|
||||
For NomicBertModel, all linear layers do not have bias.
|
||||
"""
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
bias=bias,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
@ -387,6 +450,8 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
("qkv_proj", "query", "q"),
|
||||
("qkv_proj", "key", "k"),
|
||||
("qkv_proj", "value", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
@ -546,3 +611,115 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
token_type_ids=token_type_ids)
|
||||
|
||||
|
||||
class NomicBertEmbeddingModel(BertEmbeddingModel):
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"emb_ln": "embeddings.LayerNorm",
|
||||
"layers": "layer",
|
||||
"attn.Wqkv": "attention.self.qkv_proj",
|
||||
"attn.out_proj": "attention.output.dense",
|
||||
'norm1': "attention.output.LayerNorm",
|
||||
'mlp.fc11': "intermediate.up_proj",
|
||||
'mlp.fc12': "intermediate.gate_proj",
|
||||
'mlp.fc2': "output.dense",
|
||||
'norm2': "output.LayerNorm",
|
||||
})
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "NomicBertConfig"
|
||||
assert config.activation_function == "swiglu"
|
||||
|
||||
# Assume NomicBertModel all linear layers do not have bias
|
||||
assert not config.mlp_fc1_bias
|
||||
assert not config.mlp_fc2_bias
|
||||
assert not config.qkv_proj_bias
|
||||
|
||||
config.layer_norm_eps = config.layer_norm_epsilon
|
||||
config.position_embedding_type = "rotary"
|
||||
config.intermediate_size = config.n_inner
|
||||
config.hidden_act = "silu"
|
||||
config.hidden_size = config.n_embd
|
||||
config.num_hidden_layers = config.n_layer
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_trained_positions,
|
||||
"base": config.rotary_emb_base,
|
||||
"rope_scaling": {
|
||||
"rope_type": "dynamic",
|
||||
"factor": config.rotary_scaling_factor
|
||||
}
|
||||
}
|
||||
|
||||
return BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
bias=False,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
|
||||
class GteEmbeddingModel(BertEmbeddingModel):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
"attention.qkv_proj": "attention.self.qkv_proj",
|
||||
"attention.o_proj": "attention.output.dense",
|
||||
'attn_ln': "attention.output.LayerNorm",
|
||||
'mlp.down_proj': "output.dense",
|
||||
'mlp_ln': "output.LayerNorm",
|
||||
})
|
||||
|
||||
def _build_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "") -> BertModel:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "GteConfig"
|
||||
assert config.position_embedding_type == "rope"
|
||||
assert config.hidden_act == "gelu"
|
||||
|
||||
config.position_embedding_type = "rotary"
|
||||
config.hidden_act = "gelu_and_mul"
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
rotary_kwargs = {
|
||||
"head_size": head_dim,
|
||||
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
|
||||
"max_position": config.max_position_embeddings,
|
||||
"base": config.rope_theta,
|
||||
}
|
||||
|
||||
model = BertModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
rotary_kwargs=rotary_kwargs,
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
# GteModel only gate_up_proj does not have bias.
|
||||
# Hack method learned from vllm/model_executor/models/glm.py
|
||||
for layer in model.encoder.layer:
|
||||
layer.intermediate.gate_up_proj.bias = None
|
||||
layer.intermediate.skip_bias_add = True
|
||||
return model
|
||||
|
||||
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
n = "mlp.up_gate_proj"
|
||||
for name, weight in weights:
|
||||
if n in name:
|
||||
up, gate = weight.chunk(2, dim=0)
|
||||
yield name.replace(n, "intermediate.up_proj"), up
|
||||
yield name.replace(n, "intermediate.gate_proj"), gate
|
||||
else:
|
||||
yield name, weight
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
weights = self.split_up_gate_proj(weights)
|
||||
self.model.load_weights(weights)
|
||||
|
||||
@ -122,13 +122,11 @@ _TEXT_GENERATION_MODELS = {
|
||||
_EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"GritLM": ("gritlm", "GritLM"),
|
||||
"GteModel": ("bert", "GteEmbeddingModel"),
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
@ -138,12 +136,16 @@ _EMBEDDING_MODELS = {
|
||||
if arch == "LlamaForCausalLM"
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaForCausalLM"),
|
||||
"NomicBertModel": ("bert", "NomicBertEmbeddingModel"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
|
||||
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user