[Frontend] Add LLM.reward specific to reward models (#21720)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent 1b0a155534
commit 65f311ce59
@@ -45,14 +45,14 @@ Each pooling model in vLLM supports one or more of these tasks according to
[Pooler.get_supported_tasks][vllm.model_executor.layers.pooler.Pooler.get_supported_tasks],
enabling the corresponding APIs:

| Task       | APIs               |
|------------|--------------------|
| `encode`   | `encode`           |
| `embed`    | `embed`, `score`\* |
| `classify` | `classify`         |
| `score`    | `score`            |

| Task       | APIs                                 |
|------------|--------------------------------------|
| `encode`   | `LLM.reward(...)`                    |
| `embed`    | `LLM.embed(...)`, `LLM.score(...)`\* |
| `classify` | `LLM.classify(...)`                  |
| `score`    | `LLM.score(...)`                     |

\* The `score` API falls back to `embed` task if the model does not support `score` task.
\* The `LLM.score(...)` API falls back to `embed` task if the model does not support `score` task.
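For instance, an embedding-only model takes the fallback path. A minimal sketch, assuming `intfloat/e5-small` (an embedding model used elsewhere on this page):

```python
from vllm import LLM

# e5-small only supports the embed task, so LLM.score falls back to
# embedding both texts and scoring them by similarity.
llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.score("What is the capital of France?",
                      "The capital of France is Paris.")
print(f"Score: {output.outputs.score}")
```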

### Pooler Configuration

@@ -66,11 +66,11 @@ you can override some of its attributes via the `--override-pooler-config` option
If the model has been converted via `--convert` (see above),
the pooler assigned to each task has the following attributes by default:

| Task       | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `encode`   | `ALL`        | ❌            | ❌      |
| `embed`    | `LAST`       | ✅︎            | ❌      |
| `classify` | `LAST`       | ❌            | ✅︎      |

| Task       | Pooling Type | Normalization | Softmax |
|------------|--------------|---------------|---------|
| `reward`   | `ALL`        | ❌            | ❌      |
| `embed`    | `LAST`       | ✅︎            | ❌      |
| `classify` | `LAST`       | ❌            | ✅︎      |
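Any of these defaults can be changed through the `--override-pooler-config` option mentioned above. A minimal sketch, assuming the Python-side `override_pooler_config` keyword and the `PoolerConfig` fields mirror the table's columns (treat the exact names as assumptions):

```python
from vllm import LLM
from vllm.config import PoolerConfig

# Hypothetical override: mean pooling without normalization for the
# embed task. Field names are assumed to match the columns above.
llm = LLM(
    model="intfloat/e5-small",
    runner="pooling",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False),
)
(output,) = llm.embed("Hello, my name is")
print(len(output.outputs.embedding))
```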

When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.
@@ -83,21 +83,6 @@ which takes priority over both the model's and Sentence Transformers's defaults.
The [LLM][vllm.LLM] class provides various methods for offline inference.
See [configuration][configuration] for a list of options when initializing the model.

### `LLM.encode`

The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
It returns the extracted hidden states directly, which is useful for reward models.

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", runner="pooling")
(output,) = llm.encode("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")
```

### `LLM.embed`

The [embed][vllm.LLM.embed] method outputs an embedding vector for each prompt.
@@ -106,7 +91,7 @@ It is primarily designed for embedding models.
```python
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", runner="pooling")
llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
@@ -154,6 +139,46 @@ print(f"Score: {score}")

A code example can be found here: <gh-file:examples/offline_inference/basic/score.py>

### `LLM.reward`

The [reward][vllm.LLM.reward] method is available to all reward models in vLLM.
It returns the extracted hidden states directly.

```python
from vllm import LLM

llm = LLM(model="internlm/internlm2-1_8b-reward", runner="pooling", trust_remote_code=True)
(output,) = llm.reward("Hello, my name is")

data = output.outputs.data
print(f"Data: {data!r}")
```

A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py>

### `LLM.encode`

The [encode][vllm.LLM.encode] method is available to all pooling models in vLLM.
It returns the extracted hidden states directly.

!!! note
    Please use one of the more specific methods or set the task directly when using `LLM.encode`:

    - For embeddings, use `LLM.embed(...)` or `pooling_task="embed"`.
    - For classification logits, use `LLM.classify(...)` or `pooling_task="classify"`.
    - For rewards, use `LLM.reward(...)` or `pooling_task="reward"`.
    - For similarity scores, use `LLM.score(...)`.

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")
(output,) = llm.encode("Hello, my name is", pooling_task="embed")

data = output.outputs.data
print(f"Data: {data!r}")
```

## Online Serving

Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs:
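As a rough illustration of that correspondence, a pooling request against a running server might look like the sketch below. The endpoint path, payload shape, and server flags are assumptions rather than something specified in this change (for example, a server started with `vllm serve internlm/internlm2-1_8b-reward --runner pooling --trust-remote-code`):

```python
# Hypothetical client for the online pooling endpoint that mirrors the
# offline LLM.reward / LLM.encode calls shown above.
import requests

response = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "internlm/internlm2-1_8b-reward",
        "input": "Hello, my name is",
    },
)
print(response.json())
```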
@@ -12,10 +12,9 @@ def parse_args():
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="intfloat/e5-mistral-7b-instruct",
        model="intfloat/e5-small",
        runner="pooling",
        enforce_eager=True,
        max_model_len=1024,
    )
    return parser.parse_args()

examples/offline_inference/basic/reward.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="internlm/internlm2-1_8b-reward",
        runner="pooling",
        enforce_eager=True,
        max_model_len=1024,
        trust_remote_code=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for reward models
    llm = LLM(**vars(args))

    # Generate rewards. The output is a list of PoolingRequestOutput.
    outputs = llm.reward(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        rewards = output.outputs.data
        rewards_trimmed = (
            (str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
        )
        print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
        print("-" * 60)


if __name__ == "__main__":
    args = parse_args()
    main(args)
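Since the `reward` task uses `ALL` pooling (see the pooler table earlier in the diff), `output.outputs.data` here holds one value per prompt token. How to reduce that to a single scalar depends on the reward model; the sketch below reads the final position, which is a hypothetical convention rather than something this change prescribes.

```python
# Hypothetical post-processing of LLM.reward output. Whether the last
# token, the mean, or another reduction is appropriate depends on how
# the reward model was trained.
from vllm import LLM

llm = LLM(
    model="internlm/internlm2-1_8b-reward",
    runner="pooling",
    trust_remote_code=True,
    max_model_len=1024,
)
(output,) = llm.reward("Hello, my name is")

per_token = output.outputs.data
print(f"Values per token: {len(per_token)}")
print(f"Last-position value: {float(per_token[-1]):.4f}")
```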
@@ -1053,6 +1053,10 @@ class VllmRunner:
        req_outputs = self.llm.encode(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def reward(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.llm.reward(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, list[str]],
@@ -95,7 +95,7 @@ def test_prm_models(
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.encode(math_step_prompts)
        vllm_outputs = vllm_model.reward(math_step_prompts)

    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
        hf_model = step_reward_patch_hf_model(hf_model)
@@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner,

    with vllm_runner(model_name, runner="pooling",
                     max_model_len=max_model_len) as vllm_model:
        vllm_output = vllm_model.llm.encode(
        vllm_output = vllm_model.llm.embed(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

    prompt_tokens = vllm_output[0].prompt_token_ids

@@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner,

    with vllm_runner(model_name, runner="pooling",
                     max_model_len=max_model_len) as vllm_model:
        vllm_output = vllm_model.llm.encode(
        vllm_output = vllm_model.llm.embed(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

    prompt_tokens = vllm_output[0].prompt_token_ids

@@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner,
            model_name, runner="pooling",
            max_model_len=max_model_len) as vllm_model:

        llm_output = vllm_model.llm.encode(
        llm_output = vllm_model.llm.embed(
            input_str, truncate_prompt_tokens=truncate_prompt_tokens)

        assert llm_output == f"""truncate_prompt_tokens value
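The tests above exercise prompt truncation through the embedding API; a minimal standalone sketch of the same keyword (the model and token counts are illustrative only):

```python
from vllm import LLM

# truncate_prompt_tokens keeps only the last N prompt tokens before pooling,
# mirroring the keyword used in the tests above.
llm = LLM(model="intfloat/e5-small", runner="pooling", max_model_len=128)
(output,) = llm.embed("Hello, my name is " * 50, truncate_prompt_tokens=10)

print(len(output.prompt_token_ids))  # expected to be at most 10
print(len(output.outputs.embedding))
```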
@@ -1037,7 +1037,7 @@ class LLM:
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        pooling_task: Optional[PoolingTask] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input

@@ -1069,6 +1069,25 @@ class LLM:
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
        """
        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                " - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                " - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                " - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                " - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
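In practice this means an embedding model now resolves to `pooling_task="embed"` when no task is given (with a one-time warning); passing the task explicitly keeps the call unambiguous. A minimal sketch, assuming an embedding model:

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-small", runner="pooling")

# No task given: encode() now resolves to "embed" and logs a one-time warning.
(implicit,) = llm.encode("Hello, my name is")

# Explicit task: same result, no warning.
(explicit,) = llm.encode("Hello, my name is", pooling_task="embed")
print(len(implicit.outputs.data), len(explicit.outputs.data))
```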
@@ -1207,6 +1226,45 @@ class LLM:

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="encode",
        )

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,