[Model][7/N] Improve all pooling tasks | Deprecate as_reward_model; prefer the new multi-vector retrieval API for extracting hidden states (#26686)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent bcb6f5947f
commit 9e77ffca3f
@@ -33,8 +33,8 @@ shown in the table below.
 | Architecture                                     | `--convert` | Supported pooling tasks                |
 |-------------------------------------------------|-------------|---------------------------------------|
 | `*ForTextEncoding`, `*EmbeddingModel`, `*Model`  | `embed`     | `token_embed`, `embed`                 |
+| `*ForRewardModeling`, `*RewardModel`             | `embed`     | `token_embed`, `embed`                 |
 | `*For*Classification`, `*ClassificationModel`    | `classify`  | `token_classify`, `classify`, `score`  |
-| `*ForRewardModeling`, `*RewardModel`             | `reward`    | `token_classify`                       |
 
 !!! tip
     You can explicitly set `--convert <type>` to specify how to convert the model.
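A minimal sketch of the conversion option in use (an illustration, not part of this commit; it assumes the Python `LLM` entrypoint accepts the same `convert` value as the `--convert` CLI flag, and the output field name follows vLLM's standard classification output). The model name is taken from the test table further down this diff:

from vllm import LLM

# Assumed usage: force conversion to a classifier, mirroring `--convert classify`.
llm = LLM(
    model="jason9693/Qwen2.5-1.5B-apeach",
    runner="pooling",
    convert="classify",
)
(output,) = llm.classify(["vLLM makes pooling models easy to serve."])
print(output.outputs.probs)  # class probabilities (activation applied by default)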
@@ -70,7 +70,6 @@ the pooler assigned to each task has the following attributes by default:
 
 | Task       | Pooling Type | Normalization | Softmax |
 |------------|--------------|---------------|---------|
-| `reward`   | `ALL`        | ❌            | ❌      |
 | `embed`    | `LAST`       | ✅︎            | ❌      |
 | `classify` | `LAST`       | ❌            | ✅︎      |
 
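A small sketch (not from this commit) that checks the `embed` default above: since normalization is enabled for `embed`, each returned vector should have unit L2 norm. The model name is one used in the tests later in this diff, and the output field name assumes vLLM's standard embedding output:

import torch

from vllm import LLM

llm = LLM(model="intfloat/multilingual-e5-small", runner="pooling")
(output,) = llm.embed(["query: a beautiful sunset at the beach"])
vector = torch.tensor(output.outputs.embedding)
print(vector.norm())  # ~1.0, because the `embed` task normalizes by default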
@@ -318,3 +317,10 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
 ### Remove softmax from PoolingParams
 
 We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
+
+### as_reward_model
+
+Pooling models now support all pooling tasks by default; you can use them without any extra settings.
+
+- To extract hidden states, prefer the `token_embed` task.
+- Reward models should prefer the `token_classify` task.
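A minimal migration sketch for the deprecated `as_reward_model` path (model names are only examples; the output layout follows the example script added later in this commit):

from vllm import LLM

# Extracting hidden states: prefer the `token_embed` task.
embed_llm = LLM(model="intfloat/multilingual-e5-small", runner="pooling")
(out,) = embed_llm.encode(["vLLM pooling example"], pooling_task="token_embed")
print(out.outputs.data.shape)  # one hidden-state vector per prompt token

# Reward models: prefer the `token_classify` task.
reward_llm = LLM(
    model="internlm/internlm2-1_8b-reward",
    runner="pooling",
    trust_remote_code=True,
)
(out,) = reward_llm.encode(["The answer is 42."], pooling_task="token_classify")
print(out.outputs.data.shape)  # one score vector per prompt token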
@@ -581,16 +581,9 @@ These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward)
 | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
 |--------------|--------|-------------------|----------------------|---------------------------|
 | `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ |
-| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
+| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
 | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ |
 | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ |
-| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
-
-<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion))
-\* Feature support is the same as that of the original model.
-
-If your model is not in the above list, we will try to automatically convert the model using
-[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly.
 
 !!! important
     For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
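For reference, a short sketch of calling one of the listed reward models through `LLM.reward` (an illustration only; it assumes the method returns the same `PoolingRequestOutput` objects used in the example script below, and that enough GPUs are available for this checkpoint):

from vllm import LLM

# Sketch only: this checkpoint is large, so multiple GPUs are assumed.
llm = LLM(
    model="Qwen/Qwen2.5-Math-RM-72B",
    runner="pooling",
    tensor_parallel_size=4,
)
(output,) = llm.reward(["What is 1 + 1? The answer is 2."])
print(output.outputs.data)  # reward score(s) for the prompt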
examples/pooling/token_embed/jina_embeddings_v4.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm import LLM
from vllm.inputs.data import TextPrompt
from vllm.multimodal.utils import fetch_image

# Initialize model
model = LLM(
    model="jinaai/jina-embeddings-v4-vllm-text-matching",
    runner="pooling",
    max_model_len=1024,
    gpu_memory_utilization=0.8,
)

# Create text prompts
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
text1_prompt = TextPrompt(prompt=f"Query: {text1}")

text2 = "浜辺に沈む美しい夕日"
text2_prompt = TextPrompt(prompt=f"Query: {text2}")

# Create image prompt
image = fetch_image(
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg"  # noqa: E501
)
image_prompt = TextPrompt(
    prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",  # noqa: E501
    multi_modal_data={"image": image},
)

# Encode all prompts
prompts = [text1_prompt, text2_prompt, image_prompt]
outputs = model.encode(prompts, pooling_task="token_embed")


def get_embeddings(outputs):
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # Gather only vision tokens
            img_start_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
            )[0][0]
            img_end_pos = torch.where(
                torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
            )[0][0]
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # Use all tokens for text-only prompts
            embeddings_tensor = output.outputs.data.detach().clone()

        # Pool and normalize embeddings
        pooled_output = (
            embeddings_tensor.sum(dim=0, dtype=torch.float32)
            / embeddings_tensor.shape[0]
        )
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings


embeddings = get_embeddings(outputs)

for embedding in embeddings:
    print(embedding.shape)
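The pooling step in this example is deliberately simple: image prompts average only the hidden states between the vision start and end markers, text prompts average all token embeddings, and both are L2-normalized afterwards so the resulting vectors can be compared with a plain dot product.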
@@ -13,7 +13,6 @@ from vllm.model_executor.models import (
 )
 from vllm.model_executor.models.adapters import (
     as_embedding_model,
-    as_reward_model,
     as_seq_cls_model,
 )
 from vllm.model_executor.models.registry import (
@@ -46,7 +45,6 @@ def test_registry_imports(model_arch):
     # All vLLM models should be convertible to a pooling model
     assert is_pooling_model(as_seq_cls_model(model_cls))
     assert is_pooling_model(as_embedding_model(model_cls))
-    assert is_pooling_model(as_reward_model(model_cls))
 
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
@@ -97,7 +97,7 @@ def test_update_config():
         ("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
         ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
-        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"),
+        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
         ("openai/whisper-small", "generate", "none", "transcription"),
     ],
 )
@@ -516,7 +516,11 @@ class ModelConfig:
         if task == "classify":
             return "classify"
         if task == "reward":
-            return "reward"
+            logger.warning(
+                "Pooling models now support all pooling tasks by default; "
+                "you can use them without any extra settings."
+            )
+            return "embed"
         if task == "score":
             new_task = self._get_default_pooling_task(architectures)
             return "classify" if new_task == "classify" else "embed"
@@ -1899,8 +1903,8 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
     ("ForImageClassification", ("pooling", "classify")),
     ("ForVideoClassification", ("pooling", "classify")),
     ("ClassificationModel", ("pooling", "classify")),
-    ("ForRewardModeling", ("pooling", "reward")),
-    ("RewardModel", ("pooling", "reward")),
+    ("ForRewardModeling", ("pooling", "embed")),
+    ("RewardModel", ("pooling", "embed")),
     # Let other `*Model`s take priority
     ("Model", ("pooling", "embed")),
 ]
@@ -167,7 +167,6 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
 def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
     from vllm.model_executor.models.adapters import (
         as_embedding_model,
-        as_reward_model,
         as_seq_cls_model,
         try_create_mm_pooling_model_cls,
     )
@@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
     elif convert_type == "classify":
         logger.debug_once("Converting to sequence classification model.")
         model_cls = as_seq_cls_model(model_cls)
-    elif convert_type == "reward":
-        logger.debug_once("Converting to reward model.")
-        model_cls = as_reward_model(model_cls)
     else:
         assert_never(convert_type)
 
@@ -346,44 +346,6 @@ def as_seq_cls_model(cls: _T) -> _T:
     return ModelForSequenceClassification  # type: ignore
 
 
-def as_reward_model(cls: _T) -> _T:
-    """
-    Subclass an existing vLLM model to support reward modeling.
-
-    By default, we return the hidden states of each token directly.
-
-    Note:
-        We assume that no extra layers are added to the original model;
-        please implement your own model if this is not the case.
-    """
-    # Avoid modifying existing reward models
-    if is_pooling_model(cls):
-        return cls
-
-    # Lazy import
-    from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
-
-    from .interfaces_base import default_pooling_type
-
-    @default_pooling_type("ALL")
-    class ModelForReward(_create_pooling_model_cls(cls)):
-        def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
-            pooler_config = vllm_config.model_config.pooler_config
-            assert pooler_config is not None
-
-            self.pooler = DispatchPooler(
-                {
-                    "token_classify": Pooler.for_token_classify(
-                        pooler_config=pooler_config
-                    )
-                }
-            )
-
-    ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
-
-    return ModelForReward  # type: ignore
-
-
 class SequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None: