From 9e77ffca3f41e0e73879098f1686a4c82b8619d9 Mon Sep 17 00:00:00 2001
From: "wang.yuqi"
Date: Mon, 8 Dec 2025 16:10:09 +0800
Subject: [PATCH] [Model][7/N] Improve all pooling task | Deprecation as_reward_model. Extract hidden states prefer using new multi-vector retrieval API (#26686)

Signed-off-by: wang.yuqi
---
 docs/models/pooling_models.md | 10 ++-
 docs/models/supported_models.md | 9 +--
 .../pooling/token_embed/jina_embeddings_v4.py | 71 +++++++++++++++++++
 tests/models/test_registry.py | 2 -
 tests/test_config.py | 2 +-
 vllm/config/model.py | 10 ++-
 vllm/model_executor/model_loader/utils.py | 4 --
 vllm/model_executor/models/adapters.py | 38 ----------
 8 files changed, 88 insertions(+), 58 deletions(-)
 create mode 100644 examples/pooling/token_embed/jina_embeddings_v4.py

diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md
index e2d427e8a459..32ffcf96fabe 100644
--- a/docs/models/pooling_models.md
+++ b/docs/models/pooling_models.md
@@ -33,8 +33,8 @@ shown in the table below.
 | Architecture | `--convert` | Supported pooling tasks |
 |-------------------------------------------------|-------------|---------------------------------------|
 | `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` |
+| `*ForRewardModeling`, `*RewardModel` | `embed` | `token_embed`, `embed` |
 | `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` |
-| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` |
 
 !!! tip
     You can explicitly set `--convert <type>` to specify how to convert the model.
@@ -70,7 +70,6 @@ the pooler assigned to each task has the following attributes by default:
 
 | Task | Pooling Type | Normalization | Softmax |
 |------------|--------------|---------------|---------|
-| `reward` | `ALL` | ❌ | ❌ |
 | `embed` | `LAST` | ✅︎ | ❌ |
 | `classify` | `LAST` | ❌ | ✅︎ |
 
@@ -318,3 +317,10 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
 ### Remove softmax from PoolingParams
 
 We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
+
+### as_reward_model
+
+Pooling models now support all pooling tasks by default; you can use them without any extra settings.
+
+- To extract hidden states, prefer the `token_embed` task.
+- Reward models should prefer the `token_classify` task.
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index ec3ba4474c19..d0166060c267 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -581,16 +581,9 @@ These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward)
 | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
 |--------------|--------|-------------------|----------------------|---------------------------|
 | `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ |
-| `LlamaForCausalLM`C | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
+| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ |
 | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ |
 | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ |
-| `*Model`C, `*ForCausalLM`C, etc. 
| Generative models | N/A | \* | \* | - -C Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion)) -\* Feature support is the same as that of the original model. - -If your model is not in the above list, we will try to automatically convert the model using -[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. !!! important For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, diff --git a/examples/pooling/token_embed/jina_embeddings_v4.py b/examples/pooling/token_embed/jina_embeddings_v4.py new file mode 100644 index 000000000000..83d4c446d426 --- /dev/null +++ b/examples/pooling/token_embed/jina_embeddings_v4.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm import LLM +from vllm.inputs.data import TextPrompt +from vllm.multimodal.utils import fetch_image + +# Initialize model +model = LLM( + model="jinaai/jina-embeddings-v4-vllm-text-matching", + runner="pooling", + max_model_len=1024, + gpu_memory_utilization=0.8, +) + +# Create text prompts +text1 = "Ein wunderschöner Sonnenuntergang am Strand" +text1_prompt = TextPrompt(prompt=f"Query: {text1}") + +text2 = "浜辺に沈む美しい夕日" +text2_prompt = TextPrompt(prompt=f"Query: {text2}") + +# Create image prompt +image = fetch_image( + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501 +) +image_prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501 + multi_modal_data={"image": image}, +) + +# Encode all prompts +prompts = [text1_prompt, text2_prompt, image_prompt] +outputs = model.encode(prompts, pooling_task="token_embed") + + +def get_embeddings(outputs): + VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653 + + embeddings = [] + for output in outputs: + if VISION_START_TOKEN_ID in output.prompt_token_ids: + # Gather only vision tokens + img_start_pos = torch.where( + torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID + )[0][0] + img_end_pos = torch.where( + torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID + )[0][0] + embeddings_tensor = output.outputs.data.detach().clone()[ + img_start_pos : img_end_pos + 1 + ] + else: + # Use all tokens for text-only prompts + embeddings_tensor = output.outputs.data.detach().clone() + + # Pool and normalize embeddings + pooled_output = ( + embeddings_tensor.sum(dim=0, dtype=torch.float32) + / embeddings_tensor.shape[0] + ) + embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1)) + return embeddings + + +embeddings = get_embeddings(outputs) + +for embedding in embeddings: + print(embedding.shape) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 9017a0fd9140..a089696e10ff 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -13,7 +13,6 @@ from vllm.model_executor.models import ( ) from vllm.model_executor.models.adapters import ( as_embedding_model, - as_reward_model, as_seq_cls_model, ) from vllm.model_executor.models.registry import ( @@ -46,7 +45,6 @@ def test_registry_imports(model_arch): # All vLLM models should be convertible to a pooling model assert is_pooling_model(as_seq_cls_model(model_cls)) assert 
is_pooling_model(as_embedding_model(model_cls))
-        assert is_pooling_model(as_reward_model(model_cls))
 
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
diff --git a/tests/test_config.py b/tests/test_config.py
index 203447cd531f..77d3a7115978 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -97,7 +97,7 @@ def test_update_config():
         ("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
         ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
-        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"),
+        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
         ("openai/whisper-small", "generate", "none", "transcription"),
     ],
 )
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 583904a949ea..764bdf700056 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -516,7 +516,11 @@ class ModelConfig:
         if task == "classify":
             return "classify"
         if task == "reward":
-            return "reward"
+            logger.warning(
+                "Pooling models now support all pooling tasks by default; "
+                "you can use them without any extra settings."
+            )
+            return "embed"
         if task == "score":
             new_task = self._get_default_pooling_task(architectures)
             return "classify" if new_task == "classify" else "embed"
@@ -1899,8 +1903,8 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
     ("ForImageClassification", ("pooling", "classify")),
     ("ForVideoClassification", ("pooling", "classify")),
     ("ClassificationModel", ("pooling", "classify")),
-    ("ForRewardModeling", ("pooling", "reward")),
-    ("RewardModel", ("pooling", "reward")),
+    ("ForRewardModeling", ("pooling", "embed")),
+    ("RewardModel", ("pooling", "embed")),
     # Let other `*Model`s take priority
     ("Model", ("pooling", "embed")),
 ]
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index eeb2444150ee..74b02e4c6258 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -167,7 +167,6 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
 def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
     from vllm.model_executor.models.adapters import (
         as_embedding_model,
-        as_reward_model,
         as_seq_cls_model,
         try_create_mm_pooling_model_cls,
     )
@@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
     elif convert_type == "classify":
         logger.debug_once("Converting to sequence classification model.")
         model_cls = as_seq_cls_model(model_cls)
-    elif convert_type == "reward":
-        logger.debug_once("Converting to reward model.")
-        model_cls = as_reward_model(model_cls)
     else:
         assert_never(convert_type)
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 007d847ac3b7..70f203b9f7c6 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -346,44 +346,6 @@ def as_seq_cls_model(cls: _T) -> _T:
     return ModelForSequenceClassification  # type: ignore
 
 
-def as_reward_model(cls: _T) -> _T:
-    """
-    Subclass an existing vLLM model to support reward modeling.
-
-    By default, we return the hidden states of each token directly.
-
-    Note:
-        We assume that no extra layers are added to the original model;
-        please implement your own model if this is not the case.
- """ - # Avoid modifying existing reward models - if is_pooling_model(cls): - return cls - - # Lazy import - from vllm.model_executor.layers.pooler import DispatchPooler, Pooler - - from .interfaces_base import default_pooling_type - - @default_pooling_type("ALL") - class ModelForReward(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config=pooler_config - ) - } - ) - - ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") - - return ModelForReward # type: ignore - - class SequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: