diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 1d165fa6f16bd..7908e42387100 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -363,7 +363,7 @@ th {
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
-| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
+| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
@@ -436,17 +436,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
-| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
-| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
-| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
-| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | |
-| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | |
-| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | |
-| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | |
+| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
+| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
+| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
+| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
+| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ |
+| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ |
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | |
+| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
<sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion))
@@ -476,7 +476,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
-| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
+| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
@@ -493,12 +493,12 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
-| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | |
+| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
-| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | |
-| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | |
+| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
+| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
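The documentation changes above mark GritLM and several encoder-style architectures as supported on the V1 engine for pooling tasks. Below is a minimal offline sketch of the embedding path the tables now advertise, using the `LLM.embed` API referenced in the doc text; the `task="embed"` argument and the output layout are assumptions that may differ slightly between vLLM releases.

```python
# Hedged sketch: embed a prompt with GritLM via the offline LLM API.
# The model name and max_model_len come from the test file below;
# the prompt and printed length are purely illustrative.
from vllm import LLM

llm = LLM(model="parasail-ai/GritLM-7B-vllm", task="embed", max_model_len=4000)

outputs = llm.embed(["Generate a representation of this sentence."])
vector = outputs[0].outputs.embedding  # one list[float] per input prompt
print(len(vector))
```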
diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py
index d21987571cbaa..17a55d916b1ff 100644
--- a/tests/models/language/pooling/test_gritlm.py
+++ b/tests/models/language/pooling/test_gritlm.py
@@ -14,6 +14,7 @@ from ....utils import RemoteOpenAIServer
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000
+ATOL = 0.002
def _arr(arr):
@@ -97,16 +98,16 @@ def get_test_data():
def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0])
- assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=0.001)
+ assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=ATOL)
cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1])
- assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=0.001)
+ assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=ATOL)
cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0])
- assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=0.001)
+ assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=ATOL)
cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1])
- assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=0.001)
+ assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=ATOL)
def test_gritlm_offline_embedding(vllm_runner):
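The test change widens the comparison tolerance from 0.001 to a shared `ATOL = 0.002`, presumably to absorb small numeric differences once the model runs on the V1 engine. For reference, here is a self-contained sketch of the check pattern used in `validate_embed_output`; the vectors and expected value are illustrative, not test data.

```python
# Cosine similarity between query/document embeddings, compared against a
# reference value within the shared tolerance. SciPy's `cosine` returns a
# distance, so similarity is 1 - distance.
import pytest
from scipy.spatial.distance import cosine

ATOL = 0.002


def check_similarity(q_vec, d_vec, expected: float) -> None:
    sim = 1 - cosine(q_vec, d_vec)
    assert sim == pytest.approx(expected, abs=ATOL)


check_similarity([1.0, 0.0], [1.0, 0.0], expected=1.0)  # identical vectors
```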
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 9e7490e3c4f07..3f6790269ae62 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from .interfaces import SupportsV0Only
+from .interfaces import default_pooling_type
logger = init_logger(__name__)
@@ -215,7 +215,8 @@ class GritLMPooler(Pooler):
return build_output(pooled_data)
-class GritLM(LlamaForCausalLM, SupportsV0Only):
+@default_pooling_type("MEAN")
+class GritLM(LlamaForCausalLM):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
The class inherits from LlamaForCausalLM and provides a custom pooling
@@ -241,7 +242,6 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
prefix: str = "",
**kwargs,
) -> None:
- # Use full attention for pooling (this is why V1 is not supported yet)
if vllm_config.model_config.runner_type == "pooling":
hf_config = vllm_config.model_config.hf_config
hf_config.is_causal = False
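With `SupportsV0Only` dropped, GritLM declares its pooling behaviour declaratively: `@default_pooling_type("MEAN")` marks the class, and `is_causal = False` switches attention to bidirectional when the model runs under the pooling runner (the removed comment tied full attention to a V0-only limitation that no longer applies). The toy sketch below shows only the MEAN pooling step over per-token hidden states; GritLM's real pooler additionally excludes the embedding-instruction tokens, and all shapes and names here are made up for illustration.

```python
# Toy MEAN pooling over the hidden states of one prompt.
import torch

hidden = torch.randn(10, 4096)                    # [num_tokens, hidden_size]
is_instruction = torch.zeros(10, dtype=torch.bool)
is_instruction[:4] = True                          # pretend the first 4 tokens are the instruction

# Average only the non-instruction tokens to form the embedding.
embedding = hidden[~is_instruction].mean(dim=0)    # [hidden_size]
```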
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index c425488f834b5..9415e67924e74 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -3,7 +3,7 @@
from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
- Union, overload, runtime_checkable)
+ TypeVar, Union, overload, runtime_checkable)
import numpy as np
import torch
@@ -641,11 +641,14 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
-def default_pooling_type(pooling_type: str) -> object:
+_T = TypeVar("_T", bound=type[torch.nn.Module])
+
+
+def default_pooling_type(pooling_type: str):
"""Set default_pooling_type decorator. """
- def func(model: object):
- model.default_pooling_type = pooling_type
+ def func(model: _T) -> _T:
+ model.default_pooling_type = pooling_type # type: ignore
return model
return func
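The decorator's behaviour is unchanged; binding the `TypeVar` to `type[torch.nn.Module]` just lets type checkers see that the decorated class keeps its own type, so `GritLM` is still treated as a `LlamaForCausalLM` subclass after decoration. A self-contained sketch of the pattern, with a toy module standing in for a real model class:

```python
# Typed class decorator: the TypeVar preserves the decorated class's type.
from typing import TypeVar

import torch

_T = TypeVar("_T", bound=type[torch.nn.Module])


def default_pooling_type(pooling_type: str):
    def func(model: _T) -> _T:
        model.default_pooling_type = pooling_type  # type: ignore
        return model

    return func


@default_pooling_type("MEAN")
class ToyPoolingModel(torch.nn.Module):
    pass


assert ToyPoolingModel.default_pooling_type == "MEAN"
```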