[V1] Do not detokenize if sampling param detokenize is False (#14224)

Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Himanshu Jaju 2025-03-06 19:40:24 +01:00 committed by GitHub
parent 9f1710f1ac
commit cd579352bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 69 additions and 27 deletions

View File

@ -14,7 +14,10 @@ PROMPT = "Hello my name is Robert and I"
@pytest.fixture(scope="module")
def model() -> LLM:
return LLM(MODEL, enforce_eager=True)
# Disable prefix caching so that we can test prompt logprobs.
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949
# is merged
return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
def test_n_gt_1(model):
@ -87,9 +90,33 @@ def test_stop_token_ids(model):
stop_token_ids = [stop_token_id_0, stop_token_id_1]
params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
output = model.generate(PROMPT, params)
assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
def test_detokenize_false(model):
"""Check that detokenize=False option works."""
output = model.generate(PROMPT, SamplingParams(detokenize=False))
assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0
output = model.generate(
PROMPT, SamplingParams(detokenize=False, logprobs=3,
prompt_logprobs=3))
assert len(output[0].outputs[0].token_ids) > 0
assert len(output[0].outputs[0].text) == 0
prompt_logprobs = output[0].prompt_logprobs
sampled_logprobs = output[0].outputs[0].logprobs
assert len(prompt_logprobs) > 1
assert len(sampled_logprobs) > 1
for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
for logprobs in all_logprobs:
assert 3 <= len(logprobs) <= 4
assert all(lp.decoded_token is None for lp in logprobs.values())
def test_bad_words(model):
"""Check that we respect bad words."""

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from vllm.engine.output_processor.stop_checker import StopChecker
@ -16,41 +16,46 @@ logger = init_logger(__name__)
class IncrementalDetokenizer:
# Generation data
output_text: str
tokens: list[str]
token_ids: list[int]
prompt_len: int
output_text: str = ""
tokens: list[str] = field(default_factory=list)
prompt_len: int = 0
# Stop strings
stop: list[str]
include_stop_str_in_output: bool
stop: list[str] = field(default_factory=list)
include_stop_str_in_output: bool = False
# Metadata for incremental detokenization
prefix_offset: int
read_offset: int
prefix_offset: int = 0
read_offset: int = 0
# Parameters for detokenization
skip_special_tokens: bool
spaces_between_special_tokens: bool
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Tokenizer for this request
tokenizer: AnyTokenizer
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer] = None
# Accounting for stop string buffering
stop_buffer_length: int
stop_buffer_length: int = 0
_last_output_text_offset: int = 0
@property
def output_token_ids(self) -> list[int]:
return self.token_ids[self.prompt_len:]
return self.token_ids if not self.prompt_len else (
self.token_ids[self.prompt_len:])
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
if tokenizer is None:
return cls(token_ids=[])
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
@ -66,7 +71,6 @@ class IncrementalDetokenizer:
stop_buffer_length = 0
return cls(
output_text="",
tokens=tokens,
# Detokenizer mutates this list, so need a unique copy.
# NOTE(Nick): could we take ownership of it though?
@ -93,6 +97,10 @@ class IncrementalDetokenizer:
Return matched stop string or None.
"""
if self.tokenizer is None:
self.token_ids.extend(new_token_ids)
return None
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
@ -13,12 +14,15 @@ from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
NONES = itertools.repeat(None)
@dataclass
class LogprobsProcessor:
# Tokenizer for this request
tokenizer: AnyTokenizer
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer]
# Logprobs for this request
logprobs: Optional[SampleLogprobs]
@ -30,7 +34,7 @@ class LogprobsProcessor:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
num_logprobs = request.sampling_params.logprobs
@ -66,8 +70,8 @@ class LogprobsProcessor:
token_ids_lst):
# Detokenize (non-incrementally).
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer, token_ids)
decoded_tokens = NONES if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer, token_ids))
# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
@ -103,9 +107,9 @@ class LogprobsProcessor:
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer,
token_ids.flatten().tolist())
decoded_tokens = None if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer,
token_ids.flatten().tolist()))
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
@ -121,7 +125,8 @@ class LogprobsProcessor:
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
decoded_tokens_for_pos = NONES \
if decoded_tokens is None else decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
@ -153,7 +158,7 @@ class LogprobsProcessor:
def _make_logprob_dict(
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: list[str],
decoded_tokens: Iterable[Optional[str]],
rank: int,
num_logprobs: int,
) -> dict[int, Logprob]:

View File

@ -68,6 +68,8 @@ class RequestState:
queue: Optional[asyncio.Queue[RequestOutput]],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
tokenizer = None
return cls(
request_id=request.request_id,
parent_req=parent_req,