[New Model]: google/embeddinggemma-300m (#24318)

Signed-off-by: wang.yuqi <noooop@126.com>

parent 53b19ccdd5
commit 6d6c6b05d3
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
 |--------------|--------|-------------------|----------------------|---------------------------|---------------------|
 | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
 | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
 | `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
 | `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
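
For reference, a minimal usage sketch of the `LLM.embed` API that this table documents, applied to the newly added model. The `task="embed"` argument, the bfloat16 dtype choice, and the output access pattern are illustrative assumptions, not taken from this commit:

    # Minimal sketch, assuming the model is served with task="embed";
    # bfloat16 matches the float16 restriction registered later in this diff.
    from vllm import LLM

    llm = LLM(model="google/embeddinggemma-300m", task="embed", dtype="bfloat16")
    (output,) = llm.embed(["The chef prepared a delicious meal."])
    print(len(output.outputs.embedding))  # embedding dimensionality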
@@ -10,7 +10,8 @@ import numpy as np
 import pytest
 import requests
 
-from tests.models.utils import EmbedModelInfo, RerankModelInfo
+from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
+                                check_embeddings_close)
 
 # Most embedding models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
                           model_info: EmbedModelInfo,
                           vllm_extra_kwargs=None,
                           hf_model_callback=None,
-                          atol=MTEB_RERANK_TOL):
+                          atol=MTEB_EMBED_TOL):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
+    example_prompts = ["The chef prepared a delicious meal."]
+
     vllm_extra_kwargs = vllm_extra_kwargs or {}
     vllm_extra_kwargs["dtype"] = model_info.dtype
 
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
+        vllm_outputs = vllm_model.embed(example_prompts)
 
     if model_info.mteb_score is None:
         with hf_runner(model_info.name,
@@ -202,6 +206,16 @@ def mteb_test_embed_models(hf_runner,
 
             st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
             st_dtype = next(hf_model.model.parameters()).dtype
+
+            # Test embed_dims and whether to use normalize
+            hf_outputs = hf_model.encode(example_prompts)
+            check_embeddings_close(
+                embeddings_0_lst=hf_outputs,
+                embeddings_1_lst=vllm_outputs,
+                name_0="hf",
+                name_1="vllm",
+                tol=1e-2,
+            )
     else:
         st_main_score = model_info.mteb_score
         st_dtype = "Constant"
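
The new block compares vLLM embeddings against the Sentence-Transformers output to catch wrong projector dimensions or missing normalization. A rough standalone equivalent of such a closeness check (the cosine-similarity criterion is an assumption about `check_embeddings_close`, not lifted from it):

    import torch
    import torch.nn.functional as F

    def embeddings_close(emb_0, emb_1, tol: float = 1e-2) -> bool:
        # Same number of prompts, same embed_dims, near-parallel vectors.
        if len(emb_0) != len(emb_1):
            return False
        for e0, e1 in zip(emb_0, emb_1):
            if len(e0) != len(e1):  # embedding dimensionality must agree
                return False
            sim = F.cosine_similarity(torch.tensor(e0), torch.tensor(e1), dim=0)
            if sim < 1 - tol:  # normalization or value mismatch
                return False
        return True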
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 
-from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
+                      LASTPoolingEmbedModelInfo)
 from .mteb_utils import mteb_test_embed_models
 
 # ST models with projector (Dense) layers
@@ -13,6 +14,10 @@ ST_PROJECTOR_MODELS = [
         mteb_score=0.688611955,
         enable_test=True,
     ),
+    LASTPoolingEmbedModelInfo("google/embeddinggemma-300m",
+                              architecture="Gemma3TextModel",
+                              mteb_score=0.7473819294684156,
+                              enable_test=True)
 ]
 
 
@@ -352,6 +352,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),  # noqa: E501
+    "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -2750,6 +2750,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
 _FLOAT16_NOT_SUPPORTED_MODELS = {
     "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "gemma3_text":
+    "Numerical instability. Please use bfloat16 or float32 instead.",
     "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
     "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
 }
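
A sketch of how a table like this is typically consulted at dtype-resolution time; the guard function below is hypothetical, only the dictionary entry comes from the diff:

    _FLOAT16_NOT_SUPPORTED_MODELS = {
        "gemma3_text":
        "Numerical instability. Please use bfloat16 or float32 instead.",
    }

    def _reject_unsupported_float16(model_type: str, dtype: str) -> None:
        # Hypothetical guard: refuse float16 for models listed above.
        if dtype == "float16" and model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
            raise ValueError(_FLOAT16_NOT_SUPPORTED_MODELS[model_type])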
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
         if not dense_modules:
             return None
 
-        module = dense_modules[0]
-        folder = module.get("path", "")
+        layers = []
+        for module in dense_modules:
+            folder = module.get("path", "")
 
-        config_path = f"{folder}/config.json" if folder else "config.json"
-        layer_config = get_hf_file_to_dict(config_path, model_config.model,
-                                           model_config.revision)
-        if not layer_config:
-            return None
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model_config.model,
+                                               model_config.revision)
+            if not layer_config:
+                continue
 
-        linear = nn.Linear(layer_config.get("in_features", 768),
-                           layer_config.get("out_features", 768),
-                           bias=layer_config.get("bias", True),
-                           dtype=torch.float32)
+            linear = nn.Linear(layer_config.get("in_features", 768),
+                               layer_config.get("out_features", 768),
+                               bias=layer_config.get("bias", True),
+                               dtype=torch.float32)
 
-        if _load_dense_weights(linear, folder, model_config):
-            layers = [linear]
+            if not _load_dense_weights(linear, folder, model_config):
+                continue
+
+            layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
-            return nn.Sequential(*layers).to(dtype=torch.float32)
+
+        return nn.Sequential(*layers).to(dtype=torch.float32)
 
     except Exception:
         logger.exception("ST projector loading failed")
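
The refactor above switches `_load_st_projector` from loading only the first Dense module to chaining all of them, which EmbeddingGemma needs for its stacked projection layers. A self-contained sketch of the resulting behaviour (the 768 -> 3072 -> 768 dimensions and bias settings are assumptions, not read from the checkpoint):

    import torch
    from torch import nn

    def build_projector(layer_configs: list[dict]) -> nn.Module:
        # Chain every Dense module rather than stopping at the first one.
        layers: list[nn.Module] = []
        for cfg in layer_configs:
            layers.append(
                nn.Linear(cfg.get("in_features", 768),
                          cfg.get("out_features", 768),
                          bias=cfg.get("bias", True),
                          dtype=torch.float32))
        return nn.Sequential(*layers).to(dtype=torch.float32)

    # EmbeddingGemma-style two-stage projection (values assumed).
    proj = build_projector([
        {"in_features": 768, "out_features": 3072, "bias": False},
        {"in_features": 3072, "out_features": 768, "bias": False},
    ])
    print(proj(torch.randn(2, 768)).shape)  # torch.Size([2, 768])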
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
         raise NotImplementedError
 
 
+class Gemma3TextModelConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        hf_config = vllm_config.model_config.hf_config
+        hf_config.is_causal = not hf_config.use_bidirectional_attention
+
+
 class GteNewModelConfig(VerifyAndUpdateConfig):
 
     @staticmethod
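
The new `Gemma3TextModelConfig` hook translates the checkpoint's `use_bidirectional_attention` flag into the `is_causal` attribute read by the attention layers further down in this diff. A trivial standalone illustration (the flag value is assumed):

    from types import SimpleNamespace

    hf_config = SimpleNamespace(use_bidirectional_attention=True)
    hf_config.is_causal = not hf_config.use_bidirectional_attention
    assert hf_config.is_causal is False  # bidirectional -> encoder-only attention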
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "GteNewForSequenceClassification": GteNewModelConfig,
+    "Gemma3TextModel": Gemma3TextModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
     "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import Gemma3TextConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
             rope_scaling=self.rope_scaling,
         )
 
-        # Initialize the attention.
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=sliding_window,
-                              prefix=f"{prefix}.attn")
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             attn_type=attn_type,
+                             logits_soft_cap=attn_logits_soft_cap,
+                             per_layer_sliding_window=sliding_window,
+                             prefix=f"{prefix}.attn")
 
     def forward(
         self,
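
The constructor change above dispatches on `is_causal`: decoder-style causal attention for generation checkpoints, encoder-only bidirectional attention for embedding checkpoints such as EmbeddingGemma. A standalone sketch of the dispatch (the enum values are stand-ins for `vllm.attention.AttentionType`):

    from enum import Enum

    class AttentionType(str, Enum):  # stand-in, values assumed
        DECODER = "decoder"
        ENCODER_ONLY = "encoder_only"

    def pick_attn_type(is_causal: bool) -> AttentionType:
        # Causal masking for generation; full bidirectional attention
        # over the prompt for embedding models.
        return AttentionType.DECODER if is_causal else AttentionType.ENCODER_ONLY

    assert pick_attn_type(True) is AttentionType.DECODER
    assert pick_attn_type(False) is AttentionType.ENCODER_ONLY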
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
     "BertModel": ("bert", "BertEmbeddingModel"),
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
+    "Gemma3TextModel": ("gemma3", "Gemma3Model"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
     "GritLM": ("gritlm", "GritLM"),