mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[Doc] Update V1 status for decoder-only embedding models (#19952)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
1bcd15edc7
commit
5111642a6f
@ -407,15 +407,15 @@ Specified using `--task embed`.
|
||||
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|-----------------------|
|
||||
| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
|
||||
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | |
|
||||
| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
|
||||
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | | |
|
||||
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ | |
|
||||
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ | |
|
||||
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ | |
|
||||
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | |
|
||||
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | |
|
||||
| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | |
|
||||
|
||||
!!! note
|
||||
@ -442,9 +442,10 @@ Specified using `--task reward`.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|---------------------------|-----------------|------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------|
|
||||
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | |
|
||||
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | |
|
||||
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
|
||||
@ -460,7 +461,7 @@ Specified using `--task classify`.
|
||||
| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|----------------------------------|----------|----------------------------------------|------------------------|-----------------------------|-----------------------|
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
|
||||
|
||||
@ -471,7 +472,7 @@ Specified using `--task score`.
|
||||
| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
|
||||
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|-----------------------|
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | |
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
|
||||
|
||||
|
||||
@ -19,24 +19,12 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
|
||||
class ReLU(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, input):
|
||||
input, _ = input
|
||||
return self.activation(input)
|
||||
|
||||
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
SupportsV0Only):
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -65,11 +53,13 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
self.score = nn.Sequential(
|
||||
ColumnParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config),
|
||||
ReLU(),
|
||||
quant_config=quant_config,
|
||||
return_bias=False),
|
||||
nn.ReLU(),
|
||||
RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
quant_config=quant_config),
|
||||
quant_config=quant_config,
|
||||
return_bias=False),
|
||||
)
|
||||
self._pooler: SimplePooler
|
||||
self.make_empty_intermediate_tensors = (
|
||||
@ -87,7 +77,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.score(hidden_states)
|
||||
logits = self.score(hidden_states)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user