Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 06:45:01 +08:00)
[V1] Do not detokenize if sampling param detokenize is False (#14224)
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
9f1710f1ac
commit
cd579352bf
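
For context, a rough sketch of the user-facing behavior this change enables, based on the test added below. The model name is a placeholder for illustration only (the test's MODEL constant is not shown in this diff):

from vllm import LLM, SamplingParams

# Placeholder model, not part of the commit.
llm = LLM("facebook/opt-125m", enforce_eager=True, enable_prefix_caching=False)

# With detokenize=False, token ids (and optional logprobs) are returned,
# but no output text is produced and no stop-string matching is performed.
out = llm.generate("Hello my name is Robert and I",
                   SamplingParams(detokenize=False, logprobs=3))[0]
assert len(out.outputs[0].token_ids) > 0
assert out.outputs[0].text == ""
# Logprob entries are still present, but their decoded_token fields are None.
assert all(lp.decoded_token is None
           for lp in out.outputs[0].logprobs[0].values())
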
@@ -14,7 +14,10 @@ PROMPT = "Hello my name is Robert and I"
 
 @pytest.fixture(scope="module")
 def model() -> LLM:
-    return LLM(MODEL, enforce_eager=True)
+    # Disable prefix caching so that we can test prompt logprobs.
+    # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
+    # is merged
+    return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
 
 
 def test_n_gt_1(model):
@@ -87,9 +90,33 @@ def test_stop_token_ids(model):
 
     stop_token_ids = [stop_token_id_0, stop_token_id_1]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
 
 
+def test_detokenize_false(model):
+    """Check that detokenize=False option works."""
+
+    output = model.generate(PROMPT, SamplingParams(detokenize=False))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    output = model.generate(
+        PROMPT, SamplingParams(detokenize=False, logprobs=3,
+                               prompt_logprobs=3))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    prompt_logprobs = output[0].prompt_logprobs
+    sampled_logprobs = output[0].outputs[0].logprobs
+    assert len(prompt_logprobs) > 1
+    assert len(sampled_logprobs) > 1
+    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
+        for logprobs in all_logprobs:
+            assert 3 <= len(logprobs) <= 4
+            assert all(lp.decoded_token is None for lp in logprobs.values())
+
+
 def test_bad_words(model):
     """Check that we respect bad words."""
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional
 
 from vllm.engine.output_processor.stop_checker import StopChecker
@@ -16,41 +16,46 @@ logger = init_logger(__name__)
 class IncrementalDetokenizer:
 
     # Generation data
-    output_text: str
-    tokens: list[str]
     token_ids: list[int]
-    prompt_len: int
+    output_text: str = ""
+    tokens: list[str] = field(default_factory=list)
+    prompt_len: int = 0
 
     # Stop strings
-    stop: list[str]
-    include_stop_str_in_output: bool
+    stop: list[str] = field(default_factory=list)
+    include_stop_str_in_output: bool = False
 
     # Metadata for incremental detokenization
-    prefix_offset: int
-    read_offset: int
+    prefix_offset: int = 0
+    read_offset: int = 0
 
     # Parameters for detokenization
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer] = None
 
     # Accounting for stop string buffering
-    stop_buffer_length: int
+    stop_buffer_length: int = 0
     _last_output_text_offset: int = 0
 
     @property
     def output_token_ids(self) -> list[int]:
-        return self.token_ids[self.prompt_len:]
+        return self.token_ids if not self.prompt_len else (
+            self.token_ids[self.prompt_len:])
 
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":
 
+        if tokenizer is None:
+            return cls(token_ids=[])
+
         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
@@ -66,7 +71,6 @@
         stop_buffer_length = 0
 
         return cls(
-            output_text="",
             tokens=tokens,
             # Detokenizer mutates this list, so need a unique copy.
             # NOTE(Nick): could we take ownership of it though?
@@ -93,6 +97,10 @@
         Return matched stop string or None.
         """
 
+        if self.tokenizer is None:
+            self.token_ids.extend(new_token_ids)
+            return None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
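
In other words, when no tokenizer is supplied the detokenizer degenerates into a plain accumulator of token ids. A simplified, self-contained sketch of that pattern (not the actual vLLM class; the real one also tracks stop strings and read offsets):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _SketchDetokenizer:
    # Only token_ids is required; all other fields default, so a
    # detokenization-disabled instance is simply cls(token_ids=[]).
    token_ids: list[int]
    output_text: str = ""
    # field(default_factory=list) avoids sharing one mutable list
    # across instances.
    tokens: list[str] = field(default_factory=list)
    tokenizer: Optional[object] = None

    def update(self, new_token_ids: list[int]) -> Optional[str]:
        if self.tokenizer is None:
            # Detokenization disabled: record the ids and report
            # that no stop string was matched.
            self.token_ids.extend(new_token_ids)
            return None
        # ... incremental detokenization and stop-string checking ...
        raise NotImplementedError
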
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import itertools
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional
 
@@ -13,12 +14,15 @@ from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
 logger = init_logger(__name__)
 
+NONES = itertools.repeat(None)
+
 
 @dataclass
 class LogprobsProcessor:
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer]
 
     # Logprobs for this request
     logprobs: Optional[SampleLogprobs]
@@ -30,7 +34,7 @@
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
         num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@
                                                  token_ids_lst):
 
             # Detokenize (non-incrementally).
-            decoded_tokens = convert_ids_list_to_tokens(
-                self.tokenizer, token_ids)
+            decoded_tokens = NONES if self.tokenizer is None else (
+                convert_ids_list_to_tokens(self.tokenizer, token_ids))
 
             # Sampler puts the sampled logprob in first.
             sampled_token_logprob = logprobs[0]
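
A note on the NONES sentinel: itertools.repeat(None) is an endless iterator of None, so it can stand in wherever a sequence of decoded tokens is expected, and zip() stops at the shortest input, so no per-token list is ever allocated. A standalone illustration of the idiom (not vLLM code):

import itertools

NONES = itertools.repeat(None)

token_ids = [101, 7592, 2088]
logprob_values = [-0.1, -2.3, -4.7]

# Pair each token id with its logprob and a None "decoded token".
entries = {
    tid: (lp, decoded)
    for tid, lp, decoded in zip(token_ids, logprob_values, NONES)
}
print(entries)
# {101: (-0.1, None), 7592: (-2.3, None), 2088: (-4.7, None)}
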
@@ -103,9 +107,9 @@
 
         # Detokenize non-incrementally.
         # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-        decoded_tokens = convert_ids_list_to_tokens(
-            self.tokenizer,
-            token_ids.flatten().tolist())
+        decoded_tokens = None if self.tokenizer is None else (
+            convert_ids_list_to_tokens(self.tokenizer,
+                                       token_ids.flatten().tolist()))
 
         # Recover shapes.
         num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@
             # Handle flattening.
             offset = pos * num_logprobs
             offset_end = offset + num_logprobs
-            decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
+            decoded_tokens_for_pos = NONES \
+                if decoded_tokens is None else decoded_tokens[offset:offset_end]
 
             # Update with the Logprob dictionary for this pos.
             self.prompt_logprobs.append(
@@ -153,7 +158,7 @@
     def _make_logprob_dict(
         logprobs: list[float],
         logprob_token_ids: list[int],
-        decoded_tokens: list[str],
+        decoded_tokens: Iterable[Optional[str]],
        rank: int,
         num_logprobs: int,
     ) -> dict[int, Logprob]:
@@ -68,6 +68,8 @@ class RequestState:
         queue: Optional[asyncio.Queue[RequestOutput]],
         log_stats: bool,
     ) -> "RequestState":
+        if not request.sampling_params.detokenize:
+            tokenizer = None
         return cls(
             request_id=request.request_id,
             parent_req=parent_req,
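
Taken together, the wiring is small: when a request sets detokenize=False, the request state drops the tokenizer before handing it to the two helpers shown above. A condensed sketch of that flow, assuming (it is elided from this diff) that both helpers are constructed further down in the same factory method:

# Condensed sketch of RequestState.from_new_request; surrounding details elided.
if not request.sampling_params.detokenize:
    tokenizer = None  # disables text output and decoded logprob tokens

detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
logprobs_processor = LogprobsProcessor.from_new_request(tokenizer, request)
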