Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 17:15:23 +08:00)
[V1] Do not detokenize if sampling param detokenize is False (#14224)
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in: parent 9f1710f1ac, commit cd579352bf
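In short, setting SamplingParams(detokenize=False) makes the V1 engine skip detokenization for that request: outputs keep their token_ids (and logprobs, if requested), but the output text stays empty and logprob entries carry decoded_token=None. A minimal usage sketch, mirroring the new test below (the model name is an illustrative assumption):

from vllm import LLM, SamplingParams

# Illustrative model choice; any vLLM-supported model behaves the same way.
llm = LLM("facebook/opt-125m", enforce_eager=True)

# With detokenize=False the V1 engine skips detokenization entirely:
# token_ids (and logprobs, if requested) are returned, but no text.
params = SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3)
output = llm.generate("Hello my name is Robert and I", params)

assert len(output[0].outputs[0].token_ids) > 0
assert output[0].outputs[0].text == ""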
@@ -14,7 +14,10 @@ PROMPT = "Hello my name is Robert and I"
 
 @pytest.fixture(scope="module")
 def model() -> LLM:
-    return LLM(MODEL, enforce_eager=True)
+    # Disable prefix caching so that we can test prompt logprobs.
+    # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
+    # is merged
+    return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
 
 
 def test_n_gt_1(model):
@@ -87,9 +90,33 @@ def test_stop_token_ids(model):
 
     stop_token_ids = [stop_token_id_0, stop_token_id_1]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
+    output = model.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
 
 
+def test_detokenize_false(model):
+    """Check that detokenize=False option works."""
+
+    output = model.generate(PROMPT, SamplingParams(detokenize=False))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    output = model.generate(
+        PROMPT, SamplingParams(detokenize=False, logprobs=3,
+                               prompt_logprobs=3))
+    assert len(output[0].outputs[0].token_ids) > 0
+    assert len(output[0].outputs[0].text) == 0
+
+    prompt_logprobs = output[0].prompt_logprobs
+    sampled_logprobs = output[0].outputs[0].logprobs
+    assert len(prompt_logprobs) > 1
+    assert len(sampled_logprobs) > 1
+    for all_logprobs in (prompt_logprobs[1:], sampled_logprobs):
+        for logprobs in all_logprobs:
+            assert 3 <= len(logprobs) <= 4
+            assert all(lp.decoded_token is None for lp in logprobs.values())
+
+
 def test_bad_words(model):
     """Check that we respect bad words."""
 
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional
 
 from vllm.engine.output_processor.stop_checker import StopChecker
@@ -16,41 +16,46 @@ logger = init_logger(__name__)
 class IncrementalDetokenizer:
 
     # Generation data
-    output_text: str
-    tokens: list[str]
     token_ids: list[int]
-    prompt_len: int
+    output_text: str = ""
+    tokens: list[str] = field(default_factory=list)
+    prompt_len: int = 0
 
     # Stop strings
-    stop: list[str]
-    include_stop_str_in_output: bool
+    stop: list[str] = field(default_factory=list)
+    include_stop_str_in_output: bool = False
 
     # Metadata for incremental detokenization
-    prefix_offset: int
-    read_offset: int
+    prefix_offset: int = 0
+    read_offset: int = 0
 
     # Parameters for detokenization
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer] = None
 
     # Accounting for stop string buffering
-    stop_buffer_length: int
+    stop_buffer_length: int = 0
     _last_output_text_offset: int = 0
 
     @property
     def output_token_ids(self) -> list[int]:
-        return self.token_ids[self.prompt_len:]
+        return self.token_ids if not self.prompt_len else (
+            self.token_ids[self.prompt_len:])
 
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":
 
+        if tokenizer is None:
+            return cls(token_ids=[])
+
         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
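A note on the new defaults in IncrementalDetokenizer: dataclass fields with mutable defaults must go through field(default_factory=...), and giving every field except token_ids a default is what lets the tokenizer-less path build a bare instance from token ids alone. A minimal standalone sketch of that pattern (the class below is a hypothetical stand-in, not the vLLM class):

from dataclasses import dataclass, field

@dataclass
class MiniDetokenizer:
    # Only the token ids are required; everything else defaults to a
    # "detokenization disabled" state, mirroring the change above.
    token_ids: list[int]
    output_text: str = ""
    tokens: list[str] = field(default_factory=list)  # mutable default -> factory
    prompt_len: int = 0

# The tokenizer-less path can now construct an instance from token ids alone.
d = MiniDetokenizer(token_ids=[])
assert d.output_text == "" and d.tokens == []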
@@ -66,7 +71,6 @@ class IncrementalDetokenizer:
         stop_buffer_length = 0
 
         return cls(
-            output_text="",
             tokens=tokens,
             # Detokenizer mutates this list, so need a unique copy.
             # NOTE(Nick): could we take ownership of it though?
@@ -93,6 +97,10 @@ class IncrementalDetokenizer:
         Return matched stop string or None.
         """
 
+        if self.tokenizer is None:
+            self.token_ids.extend(new_token_ids)
+            return None
+
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
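With the early return above, update() reduces to token-id bookkeeping whenever no tokenizer is attached. A simplified sketch of that control flow, using hypothetical stand-in names rather than the actual vLLM implementation:

from typing import Optional

class MiniIncrementalDetokenizer:
    def __init__(self, tokenizer: Optional[object] = None):
        self.tokenizer = tokenizer
        self.token_ids: list[int] = []
        self.output_text = ""

    def update(self, new_token_ids: list[int]) -> Optional[str]:
        """Return a matched stop string, or None."""
        if self.tokenizer is None:
            # detokenize=False: record ids, skip decoding and stop-string checks.
            self.token_ids.extend(new_token_ids)
            return None
        ...  # incremental detokenization + stop-string matching would go here

d = MiniIncrementalDetokenizer(tokenizer=None)
assert d.update([1, 2, 3]) is None and d.token_ids == [1, 2, 3]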
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import itertools
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional
 
@@ -13,12 +14,15 @@ from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
 logger = init_logger(__name__)
 
+NONES = itertools.repeat(None)
+
 
 @dataclass
 class LogprobsProcessor:
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer]
 
     # Logprobs for this request
     logprobs: Optional[SampleLogprobs]
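The NONES = itertools.repeat(None) helper works because repeat(None) is an endless iterator and zip() stops at its shortest input, so every logprob still pairs with a None decoded token without allocating a list per request. A quick illustration:

import itertools

NONES = itertools.repeat(None)

logprobs = [-0.11, -1.42, -2.73]
token_ids = [101, 202, 303]

# With detokenization disabled, the decoded token is simply None per position.
for token_id, logprob, decoded in zip(token_ids, logprobs, NONES):
    assert decoded is None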
@@ -30,7 +34,7 @@ class LogprobsProcessor:
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
         num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@ class LogprobsProcessor:
                                              token_ids_lst):
 
             # Detokenize (non-incrementally).
-            decoded_tokens = convert_ids_list_to_tokens(
-                self.tokenizer, token_ids)
+            decoded_tokens = NONES if self.tokenizer is None else (
+                convert_ids_list_to_tokens(self.tokenizer, token_ids))
 
             # Sampler puts the sampled logprob in first.
             sampled_token_logprob = logprobs[0]
@@ -103,9 +107,9 @@ class LogprobsProcessor:
 
         # Detokenize non-incrementally.
         # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-        decoded_tokens = convert_ids_list_to_tokens(
-            self.tokenizer,
-            token_ids.flatten().tolist())
+        decoded_tokens = None if self.tokenizer is None else (
+            convert_ids_list_to_tokens(self.tokenizer,
+                                       token_ids.flatten().tolist()))
 
         # Recover shapes.
         num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@ class LogprobsProcessor:
             # Handle flattening.
             offset = pos * num_logprobs
             offset_end = offset + num_logprobs
-            decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
+            decoded_tokens_for_pos = NONES \
+                if decoded_tokens is None else decoded_tokens[offset:offset_end]
 
             # Update with the Logprob dictionary for this pos.
             self.prompt_logprobs.append(
@@ -153,7 +158,7 @@ class LogprobsProcessor:
     def _make_logprob_dict(
         logprobs: list[float],
         logprob_token_ids: list[int],
-        decoded_tokens: list[str],
+        decoded_tokens: Iterable[Optional[str]],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
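Since decoded_tokens is now an Iterable[Optional[str]], the per-position logprob dictionary can be built the same way from real strings or from the NONES placeholder. A simplified, self-contained sketch (the Logprob dataclass and function below are stand-ins for vLLM's own, with the rank/num_logprobs parameters omitted):

import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

@dataclass
class Logprob:  # stand-in with the fields used here
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None

def make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[Optional[str]],
) -> dict[int, Logprob]:
    return {
        tid: Logprob(logprob=lp, decoded_token=tok)
        for lp, tid, tok in zip(logprobs, logprob_token_ids, decoded_tokens)
    }

# Works identically with real strings or an endless stream of None.
with_text = make_logprob_dict([-0.5, -1.5], [7, 8], ["Hello", " world"])
without_text = make_logprob_dict([-0.5, -1.5], [7, 8], itertools.repeat(None))
assert all(lp.decoded_token is None for lp in without_text.values())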
@@ -68,6 +68,8 @@ class RequestState:
         queue: Optional[asyncio.Queue[RequestOutput]],
         log_stats: bool,
     ) -> "RequestState":
+        if not request.sampling_params.detokenize:
+            tokenizer = None
         return cls(
             request_id=request.request_id,
             parent_req=parent_req,
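Putting it together: when SamplingParams.detokenize is False, RequestState.from_new_request drops the tokenizer before constructing the per-request IncrementalDetokenizer and LogprobsProcessor, so both fall back to their token-id-only paths. A rough sketch of that wiring with a hypothetical helper name:

from typing import Optional

def pick_tokenizer(tokenizer: Optional[object], detokenize: bool) -> Optional[object]:
    # Mirrors the change above: a request with detokenize=False behaves as if
    # no tokenizer existed, so the detokenizer only stores raw token ids and
    # the logprobs processor leaves every decoded_token as None.
    return tokenizer if detokenize else None

assert pick_tokenizer(object(), detokenize=False) is None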