[Frontend] Add LLM.reward specific to reward models (#21720)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-07-30 11:56:03 +08:00 committed by GitHub
parent 1b0a155534
commit 65f311ce59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 174 additions and 35 deletions

View File

@ -45,14 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
enabling the corresponding APIs:
| Task | APIs |
|------------|--------------------|
| `encode` | `encode` |
| `embed` | `embed`, `score`\* |
| `classify` | `classify` |
| `score` | `score` |
| Task | APIs |
|------------|--------------------------------------|
| `encode` | `LLM.reward(...)` |
| `embed` | `LLM.embed(...)`, `LLM.score(...)`\* |
| `classify` | `LLM.classify(...)` |
| `score` | `LLM.score(...)` |
\* The `score` API falls back to `embed` task if the model does not support `score` task.
\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
### Pooler Configuration
@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` optio
If the model has been converted via `--convert` (see above),
the pooler assigned to each task has the following attributes by default:
| Task | Pooling Type | Normalization | Softmax |
|------------|----------------|---------------|---------|
| `encode` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |
| Task | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `reward` | `ALL` | ❌ | ❌ |
| `embed` | `LAST` | ✅︎ | ❌ |
| `classify` | `LAST` | ❌ | ✅︎ |
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults.
The [LLM][vllm.LLM] class provides various methods for offline inference.
See [configuration][configuration] for a list of options when initializing the model.
### `LLM.encode`
The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
It returns the extracted hidden states directly, which is useful for reward models.
```python
from vllm import LLM
llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", runner="pooling")
(output,) = llm.encode("Hello, my name is")
data = output.outputs.data
print(f"Data: {data!r}")
```
### `LLM.embed`
The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
@ -106,7 +91,7 @@ It is primarily designed for embedding models.
```python
from vllm import LLM
llm = LLM(model="intfloat/e5-mistral-7b-instruct", runner="pooling")
llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.embed("Hello, my name is")
embeds = output.outputs.embedding
@ -154,6 +139,46 @@ print(f"Score: {score}")
A code example can be found here: <gh-file:examples/offline_inference/basic/score.py>
### `LLM.reward`
The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
It returns the extracted hidden states directly.
```python
from vllm import LLM
llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
(output,) = llm.reward("Hello, my name is")
data = output.outputs.data
print(f"Data: {data!r}")
```
A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py>
### `LLM.encode`
The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
It returns the extracted hidden states directly.
!!! note
Please use one of the more specific methods or set the task directly when using `LLM.encode`:
- For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
- For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
- For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
- For similarity scores, use `LLM.score(...)`.
```python
from vllm import LLM
llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="embed")
data = output.outputs.data
print(f"Data: {data!r}")
```
## Online Serving
Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:

View File

@ -12,10 +12,9 @@ def parse_args():
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="intfloat/e5-mistral-7b-instruct",
model="intfloat/e5-small",
runner="pooling",
enforce_eager=True,
max_model_len=1024,
)
return parser.parse_args()

View File

@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="internlm/internlm2-1_8b-reward",
runner="pooling",
enforce_eager=True,
max_model_len=1024,
trust_remote_code=True,
)
return parser.parse_args()
def main(args: Namespace):
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create an LLM.
# You should pass runner="pooling" for reward models
llm = LLM(**vars(args))
# Generate rewards. The output is a list of PoolingRequestOutput.
outputs = llm.reward(prompts)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
rewards = output.outputs.data
rewards_trimmed = (
(str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
)
print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
print("-" * 60)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@ -1053,6 +1053,10 @@ class VllmRunner:
req_outputs = self.llm.encode(prompts)
return [req_output.outputs.data for req_output in req_outputs]
def reward(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.reward(prompts)
return [req_output.outputs.data for req_output in req_outputs]
def score(
self,
text_1: Union[str, list[str]],

View File

@ -95,7 +95,7 @@ def test_prm_models(
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.encode(math_step_prompts)
vllm_outputs = vllm_model.reward(math_step_prompts)
with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
hf_model = step_reward_patch_hf_model(hf_model)

View File

@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner,
with vllm_runner(model_name, runner="pooling",
max_model_len=max_model_len) as vllm_model:
vllm_output = vllm_model.llm.encode(
vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)
prompt_tokens = vllm_output[0].prompt_token_ids
@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner,
with vllm_runner(model_name, runner="pooling",
max_model_len=max_model_len) as vllm_model:
vllm_output = vllm_model.llm.encode(
vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)
prompt_tokens = vllm_output[0].prompt_token_ids
@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner,
model_name, runner="pooling",
max_model_len=max_model_len) as vllm_model:
llm_output = vllm_model.llm.encode(
llm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens)
assert llm_output == f"""truncate_prompt_tokens value

View File

@ -1037,7 +1037,7 @@ class LLM:
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
pooling_task: PoolingTask = "encode",
pooling_task: Optional[PoolingTask] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input
@ -1069,6 +1069,25 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
if pooling_task is None:
if "embed" in self.supported_tasks:
pooling_task = "embed"
else:
pooling_task = "encode"
logger.warning_once(
"`LLM.encode` is currently using `pooling_task = %s`.\n"
"Please use one of the more specific methods or set the "
"task directly when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
"or `pooling_task=\"embed\"`.\n"
" - For classification logits, use `LLM.classify(...)` "
"or `pooling_task=\"classify\"`.\n"
" - For rewards, use `LLM.reward(...)` "
"or `pooling_task=\"reward\"`\n"
" - For similarity scores, use `LLM.score(...)`.",
pooling_task)
model_config = self.llm_engine.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
@ -1207,6 +1226,45 @@ class LLM:
return [ClassificationRequestOutput.from_base(item) for item in items]
def reward(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[PoolingRequestOutput]:
"""
Generate rewards for each prompt.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""
return self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="encode",
)
def _embedding_score(
self,
tokenizer: AnyTokenizer,