diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9db6f8036a73..bdb29aac333c 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
| `BertModel`C | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
| `Gemma2Model`C | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Gemma3TextModel`C | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
| `GteModel`C | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
| `GteNewModel`C | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
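
With the new table row in place, EmbeddingGemma is served like the other pooling models. A minimal sketch, assuming HF access to `google/embeddinggemma-300m` and a vLLM version where `task="embed"` selects the pooling runner (the flag spelling varies across releases):

```python
# Minimal sketch: embed one prompt with EmbeddingGemma via LLM.embed.
from vllm import LLM

llm = LLM(model="google/embeddinggemma-300m", task="embed", dtype="bfloat16")
(output, ) = llm.embed(["The chef prepared a delicious meal."])
print(len(output.outputs.embedding))  # embedding width after the ST projector
```
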
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index 7be1bba2ff69..68b1cc80303a 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -10,7 +10,8 @@ import numpy as np
import pytest
import requests
-from tests.models.utils import EmbedModelInfo, RerankModelInfo
+from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
+ check_embeddings_close)
# Most embedding models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
@@ -163,12 +164,14 @@ def mteb_test_embed_models(hf_runner,
model_info: EmbedModelInfo,
vllm_extra_kwargs=None,
hf_model_callback=None,
- atol=MTEB_RERANK_TOL):
+ atol=MTEB_EMBED_TOL):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
+ example_prompts = ["The chef prepared a delicious meal."]
+
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
@@ -191,6 +194,7 @@ def mteb_test_embed_models(hf_runner,
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
+ vllm_outputs = vllm_model.embed(example_prompts)
if model_info.mteb_score is None:
with hf_runner(model_info.name,
@@ -202,6 +206,16 @@ def mteb_test_embed_models(hf_runner,
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
st_dtype = next(hf_model.model.parameters()).dtype
+
+            # Check the embedding dim and whether normalization is applied
+ hf_outputs = hf_model.encode(example_prompts)
+ check_embeddings_close(
+ embeddings_0_lst=hf_outputs,
+ embeddings_1_lst=vllm_outputs,
+ name_0="hf",
+ name_1="vllm",
+ tol=1e-2,
+ )
else:
st_main_score = model_info.mteb_score
st_dtype = "Constant"
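
The new `check_embeddings_close` call guards the projector and normalization path per prompt, not just the aggregate MTEB score. Roughly, the helper compares HF and vLLM vectors by cosine similarity; a simplified stand-in (not the helper's exact implementation in `tests/models/utils.py`):

```python
# Simplified stand-in for check_embeddings_close: flag any prompt whose
# HF and vLLM embeddings drift apart in cosine terms.
import torch
import torch.nn.functional as F

def embeddings_close(hf_vecs, vllm_vecs, tol: float = 1e-2) -> bool:
    for a, b in zip(hf_vecs, vllm_vecs):
        sim = F.cosine_similarity(torch.as_tensor(a, dtype=torch.float32),
                                  torch.as_tensor(b, dtype=torch.float32),
                                  dim=0)
        if (1.0 - sim).abs() > tol:
            return False
    return True
```
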
diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py
index bafeb4060d80..9301e705c433 100644
--- a/tests/models/language/pooling/test_st_projector.py
+++ b/tests/models/language/pooling/test_st_projector.py
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
-from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
+ LASTPoolingEmbedModelInfo)
from .mteb_utils import mteb_test_embed_models
# ST models with projector (Dense) layers
@@ -13,6 +14,10 @@ ST_PROJECTOR_MODELS = [
mteb_score=0.688611955,
enable_test=True,
),
+ LASTPoolingEmbedModelInfo("google/embeddinggemma-300m",
+ architecture="Gemma3TextModel",
+ mteb_score=0.7473819294684156,
+                              enable_test=True),
]
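
`LASTPoolingEmbedModelInfo` marks the model as last-token pooled, unlike the CLS-pooled BERT-style entry above it. For reference, last-token pooling takes the hidden state of each sequence's final non-padding token; a toy illustration (hypothetical helper, not vLLM's pooler API):

```python
# Toy last-token pooling over a padded batch (hypothetical helper).
import torch

def last_token_pool(hidden: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden: [batch, max_len, dim]; seq_lens: [batch] true lengths (int64).
    idx = (seq_lens - 1).view(-1, 1, 1).expand(-1, 1, hidden.size(-1))
    return hidden.gather(dim=1, index=idx).squeeze(1)  # [batch, dim]
```
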
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 38efb01341eb..c6ff50b5426e 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -352,6 +352,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
+ "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 8bdc22acf380..c4434c37f4c7 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -2750,6 +2750,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
_FLOAT16_NOT_SUPPORTED_MODELS = {
"gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
"gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
+ "gemma3_text":
+ "Numerical instability. Please use bfloat16 or float32 instead.",
"plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
"glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
}
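
Adding `gemma3_text` here makes the engine reject `--dtype float16` for EmbeddingGemma checkpoints up front instead of producing NaNs at runtime. A rough sketch of the guard this table feeds (names illustrative, not vLLM's exact internals):

```python
# Illustrative guard driven by a table like _FLOAT16_NOT_SUPPORTED_MODELS.
_FLOAT16_NOT_SUPPORTED = {
    "gemma3_text": "Numerical instability. Please use bfloat16 or float32 instead.",
}

def reject_unsupported_float16(model_type: str, dtype: str) -> None:
    reason = _FLOAT16_NOT_SUPPORTED.get(model_type)
    if dtype == "float16" and reason is not None:
        raise ValueError(f"float16 is unsupported for {model_type}: {reason}")

reject_unsupported_float16("gemma3_text", "bfloat16")  # passes silently
```
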
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 50c2cd97f3d0..bb96bc559200 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
if not dense_modules:
return None
- module = dense_modules[0]
- folder = module.get("path", "")
+ layers = []
+ for module in dense_modules:
+ folder = module.get("path", "")
- config_path = f"{folder}/config.json" if folder else "config.json"
- layer_config = get_hf_file_to_dict(config_path, model_config.model,
- model_config.revision)
- if not layer_config:
- return None
+ config_path = f"{folder}/config.json" if folder else "config.json"
+ layer_config = get_hf_file_to_dict(config_path, model_config.model,
+ model_config.revision)
+ if not layer_config:
+ continue
- linear = nn.Linear(layer_config.get("in_features", 768),
- layer_config.get("out_features", 768),
- bias=layer_config.get("bias", True),
- dtype=torch.float32)
+ linear = nn.Linear(layer_config.get("in_features", 768),
+ layer_config.get("out_features", 768),
+ bias=layer_config.get("bias", True),
+ dtype=torch.float32)
- if _load_dense_weights(linear, folder, model_config):
- layers = [linear]
+ if not _load_dense_weights(linear, folder, model_config):
+ continue
+
+ layers.append(linear)
if act_name := layer_config.get("activation_function"):
layers.append(get_act_fn(act_name))
- return nn.Sequential(*layers).to(dtype=torch.float32)
-
+ return nn.Sequential(*layers).to(dtype=torch.float32)
except Exception:
logger.exception("ST projector loading failed")
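
Previously `_load_st_projector` honored only the first `Dense` module from a Sentence Transformers `modules.json`; the loop above chains every module that loads successfully, which EmbeddingGemma-style checkpoints with more than one Dense block require. A self-contained sketch of the resulting structure (the in/out sizes below are made up for the example):

```python
# Chaining several ST Dense blocks into one projector, as the loop now does.
import torch
from torch import nn

dense_configs = [  # made-up sizes; real values come from each Dense config.json
    {"in_features": 768, "out_features": 3072, "bias": True},
    {"in_features": 3072, "out_features": 768, "bias": True},
]
layers: list[nn.Module] = []
for cfg in dense_configs:
    layers.append(nn.Linear(cfg["in_features"], cfg["out_features"],
                            bias=cfg["bias"], dtype=torch.float32))
projector = nn.Sequential(*layers).to(dtype=torch.float32)
print(projector(torch.randn(2, 768)).shape)  # torch.Size([2, 768])
```
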
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 8b76a54332f8..f38e7fc20220 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
raise NotImplementedError
+class Gemma3TextModelConfig(VerifyAndUpdateConfig):
+
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ hf_config = vllm_config.model_config.hf_config
+ hf_config.is_causal = not hf_config.use_bidirectional_attention
+
+
class GteNewModelConfig(VerifyAndUpdateConfig):
@staticmethod
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
+ "Gemma3TextModel": Gemma3TextModelConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 410c715d5241..1263e3049a14 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -24,7 +24,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
rope_scaling=self.rope_scaling,
)
- # Initialize the attention.
- self.attn = Attention(self.num_heads,
- self.head_dim,
- self.scaling,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- quant_config=quant_config,
- logits_soft_cap=attn_logits_soft_cap,
- per_layer_sliding_window=sliding_window,
- prefix=f"{prefix}.attn")
+ if getattr(config, "is_causal", True):
+ attn_type = AttentionType.DECODER
+ else:
+ attn_type = AttentionType.ENCODER_ONLY
+
+ attn_cls = (EncoderOnlyAttention
+ if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+ self.attn = attn_cls(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ attn_type=attn_type,
+ logits_soft_cap=attn_logits_soft_cap,
+ per_layer_sliding_window=sliding_window,
+ prefix=f"{prefix}.attn")
def forward(
self,
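
The dispatch above is the core of the model change: decoder-style causal masking for generation checkpoints, bidirectional `EncoderOnlyAttention` for embedding checkpoints. A standalone sketch of the selection logic (`AttnType` stands in for `vllm.attention.AttentionType`; details may differ):

```python
# Standalone sketch of the attention-type dispatch in Gemma3Attention.
from enum import Enum

class AttnType(str, Enum):
    DECODER = "decoder"            # causal mask: token i sees tokens <= i
    ENCODER_ONLY = "encoder_only"  # bidirectional full self-attention

def pick_attention_type(is_causal: bool) -> AttnType:
    return AttnType.DECODER if is_causal else AttnType.ENCODER_ONLY

assert pick_attention_type(True) is AttnType.DECODER
assert pick_attention_type(False) is AttnType.ENCODER_ONLY
```
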
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 38d300b03d2c..c522fcab7f33 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
"BertModel": ("bert", "BertEmbeddingModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
+ "Gemma3TextModel": ("gemma3", "Gemma3Model"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"GritLM": ("gritlm", "GritLM"),