[Model][7/N] Improve all pooling tasks | Deprecate as_reward_model; prefer the new multi-vector retrieval API for extracting hidden states (#26686)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
wang.yuqi 2025-12-08 16:10:09 +08:00 committed by GitHub
parent bcb6f5947f
commit 9e77ffca3f
8 changed files with 88 additions and 58 deletions

View File

@@ -33,8 +33,8 @@ shown in the table below.
| Architecture | `--convert` | Supported pooling tasks |
|-------------------------------------------------|-------------|---------------------------------------|
| `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` |
| `*ForRewardModeling`, `*RewardModel` | `embed` | `token_embed`, `embed` |
| `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` |
| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` |

!!! tip
    You can explicitly set `--convert <type>` to specify how to convert the model.
@@ -70,7 +70,6 @@ the pooler assigned to each task has the following attributes by default:
| Task | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `reward` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |
@@ -318,3 +317,10 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
### Remove softmax from PoolingParams

We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
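For example, a minimal sketch of the renamed switch (the model name and exact call shape here are illustrative, not taken from this commit):

```python
from vllm import LLM, PoolingParams

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")

# Before: PoolingParams(softmax=False) or PoolingParams(activation=False)
# After: a single switch that covers both `classify` and `token_classify`
outputs = llm.encode(
    ["vLLM is wonderful!"],
    pooling_params=PoolingParams(use_activation=False),
    pooling_task="classify",
)
print(outputs[0].outputs.data)  # raw logits, since the activation is disabled
```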
### as_reward_model
Pooling models now support all pooling tasks by default; you can use them without any extra settings (see the sketch below).

- To extract hidden states, prefer the `token_embed` task.
- For reward models, prefer the `token_classify` task.
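A minimal sketch of both replacements (the model names are illustrative, and the output shapes depend on the model):

```python
from vllm import LLM

# Extracting hidden states: prefer the `token_embed` task.
llm = LLM(model="intfloat/multilingual-e5-small", runner="pooling")
(output,) = llm.encode(["vLLM is wonderful!"], pooling_task="token_embed")
print(output.outputs.data.shape)  # one hidden-state vector per prompt token

# Reward models: prefer the `token_classify` task.
reward_llm = LLM(model="Qwen/Qwen2.5-Math-PRM-7B", runner="pooling")
(output,) = reward_llm.encode(["vLLM is wonderful!"], pooling_task="token_classify")
print(output.outputs.data.shape)  # per-token scores; exact shape depends on the model's head
```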

View File

@@ -581,16 +581,9 @@ These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward)
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------|
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ |
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ |
| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
\* Feature support is the same as that of the original model.
If your model is not in the above list, we will try to automatically convert the model using
[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
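For reference, a minimal sketch of calling one of the listed reward models through `LLM.reward` (the model choice and the printed field are illustrative):

```python
from vllm import LLM

llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling")
(output,) = llm.reward(["vLLM is wonderful!"])
print(output.outputs.data)  # reward scores for the prompt; shape depends on the model
```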
!!! important
    For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm import LLM
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image

# Initialize model
model = LLM(
    model="jinaai/jina-embeddings-v4-vllm-text-matching",
    runner="pooling",
    max_model_len=1024,
    gpu_memory_utilization=0.8,
)

# Create text prompts
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
text1_prompt = TextPrompt(prompt=f"Query: {text1}")

text2 = "浜辺に沈む美しい夕日"
text2_prompt = TextPrompt(prompt=f"Query: {text2}")

# Create image prompt
image = fetch_image(
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg"  # noqa: E501
)
image_prompt = TextPrompt(
    prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",  # noqa: E501
    multi_modal_data={"image": image},
)

# Encode all prompts
prompts = [text1_prompt, text2_prompt, image_prompt]
outputs = model.encode(prompts, pooling_task="token_embed")


def get_embeddings(outputs):
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # Gather only vision tokens
            img_start_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
            )[0][0]
            img_end_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
            )[0][0]
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # Use all tokens for text-only prompts
            embeddings_tensor = output.outputs.data.detach().clone()

        # Pool and normalize embeddings
        pooled_output = (
            embeddings_tensor.sum(dim=0, dtype=torch.float32)
            / embeddings_tensor.shape[0]
        )
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings


embeddings = get_embeddings(outputs)
for embedding in embeddings:
    print(embedding.shape)
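# Follow-up (not part of the original example file): since the pooled vectors
# above are L2-normalized, pairwise cosine similarity reduces to a dot product.
similarity = torch.stack(embeddings) @ torch.stack(embeddings).T
print(similarity)  # 3x3 matrix of text/text and text/image similarities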

View File

@@ -13,7 +13,6 @@ from vllm.model_executor.models import (
)
from vllm.model_executor.models.adapters import (
    as_embedding_model,
    as_reward_model,
    as_seq_cls_model,
)
from vllm.model_executor.models.registry import (
@@ -46,7 +45,6 @@ def test_registry_imports(model_arch):
    # All vLLM models should be convertible to a pooling model
    assert is_pooling_model(as_seq_cls_model(model_cls))
    assert is_pooling_model(as_embedding_model(model_cls))
    assert is_pooling_model(as_reward_model(model_cls))
    if model_arch in _MULTIMODAL_MODELS:
        assert supports_multimodal(model_cls)

View File

@@ -97,7 +97,7 @@ def test_update_config():
        ("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
        ("openai/whisper-small", "generate", "none", "transcription"),
    ],
)

View File

@@ -516,7 +516,11 @@ class ModelConfig:
        if task == "classify":
            return "classify"
        if task == "reward":
return "reward" logger.warning(
"Pooling models now default support all pooling; "
"you can use it without any settings."
)
return "embed"
if task == "score": if task == "score":
new_task = self._get_default_pooling_task(architectures) new_task = self._get_default_pooling_task(architectures)
return "classify" if new_task == "classify" else "embed" return "classify" if new_task == "classify" else "embed"
@@ -1899,8 +1903,8 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
    ("ForImageClassification", ("pooling", "classify")),
    ("ForVideoClassification", ("pooling", "classify")),
    ("ClassificationModel", ("pooling", "classify")),
    ("ForRewardModeling", ("pooling", "reward")),
    ("ForRewardModeling", ("pooling", "embed")),
    ("RewardModel", ("pooling", "reward")),
    ("RewardModel", ("pooling", "embed")),
    # Let other `*Model`s take priority
    ("Model", ("pooling", "embed")),
]
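# Illustration (not part of this diff): a hypothetical sketch of how these
# suffix defaults could be resolved for an architecture name; the fallback
# value below is an assumption.
def resolve_defaults(architecture: str) -> tuple[str, str]:
    for suffix, (runner, convert) in _SUFFIX_TO_DEFAULTS:
        if architecture.endswith(suffix):
            return runner, convert
    return "generate", "none"  # assumed fallback for non-pooling architectures


print(resolve_defaults("Qwen2ForRewardModel"))  # ('pooling', 'embed')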

View File

@@ -167,7 +167,6 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    from vllm.model_executor.models.adapters import (
        as_embedding_model,
        as_reward_model,
        as_seq_cls_model,
        try_create_mm_pooling_model_cls,
    )
@@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
        model_cls = as_seq_cls_model(model_cls)
    elif convert_type == "reward":
        logger.debug_once("Converting to reward model.")
        model_cls = as_reward_model(model_cls)
    else:
        assert_never(convert_type)

View File

@@ -346,44 +346,6 @@ def as_seq_cls_model(cls: _T) -> _T:
    return ModelForSequenceClassification  # type: ignore


def as_reward_model(cls: _T) -> _T:
    """
    Subclass an existing vLLM model to support reward modeling.

    By default, we return the hidden states of each token directly.

    Note:
        We assume that no extra layers are added to the original model;
        please implement your own model if this is not the case.
    """
    # Avoid modifying existing reward models
    if is_pooling_model(cls):
        return cls

    # Lazy import
    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

    from .interfaces_base import default_pooling_type

    @default_pooling_type("ALL")
    class ModelForReward(_create_pooling_model_cls(cls)):
        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
            pooler_config = vllm_config.model_config.pooler_config
            assert pooler_config is not None

            self.pooler = DispatchPooler(
                {
                    "token_classify": Pooler.for_token_classify(
                        pooler_config=pooler_config
                    )
                }
            )

    ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")

    return ModelForReward  # type: ignore


class SequenceClassificationConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None: