Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-24 16:15:54 +08:00)

[V0 Deprecation] Remove V0 Output Processor (#25320)

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

This commit is contained in:
parent 52c2a8d4ad
commit 86647d1cd0
@@ -1,39 +0,0 @@ (entire file removed)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import Cython.Compiler.Options
from Cython.Build import cythonize
from setuptools import setup

Cython.Compiler.Options.annotate = True

infiles = []

infiles += [
    "vllm/engine/llm_engine.py",
    "vllm/transformers_utils/detokenizer.py",
    "vllm/engine/output_processor/single_step.py",
    "vllm/outputs.py",
    "vllm/engine/output_processor/stop_checker.py",
]

infiles += [
    "vllm/core/scheduler.py",
    "vllm/sequence.py",
    "vllm/core/block_manager.py",
]

infiles += [
    "vllm/model_executor/layers/sampler.py",
    "vllm/sampling_params.py",
    "vllm/utils/__init__.py",
]

setup(ext_modules=cythonize(infiles,
                            annotate=False,
                            force=True,
                            compiler_directives={
                                'language_level': "3",
                                'infer_types': True
                            }))

# example usage: python3 build_cython.py build_ext --inplace
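The removed build script compiled a hand-picked list of V0 hot-path modules with Cython. As a side note, here is a minimal sketch (the is_compiled helper is illustrative, not part of vLLM) for checking whether a given module was actually loaded from a compiled extension rather than from plain Python source:

# Minimal sketch, assuming an already-built tree: report whether a module was
# imported from a compiled Cython extension (.so/.pyd) or from a .py file.
import importlib

def is_compiled(module_name: str) -> bool:
    mod = importlib.import_module(module_name)
    origin = getattr(mod, "__file__", "") or ""
    return origin.endswith((".so", ".pyd"))

if __name__ == "__main__":
    # Any importable module name works here; stdlib packages are plain Python.
    print(is_compiled("importlib"))  # -> False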
@@ -1,59 +0,0 @@ (entire file removed)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import List

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter


class SequenceGroupOutputProcessor(ABC):
    """Interface for logic that processes new token ids in sequence groups,
    managing detokenization, stop checking, and freeing/forking sequences with
    the scheduler.

    This is highly coupled with the LLMEngine and should be seen as an extension
    of it. The logic is separated to simplify the LLMEngine class and allow
    separate implementations for single-step decoding (which supports beam
    search sequence forking) and multi-step decoding (which does not support
    beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        stop_checker: "StopChecker",
    ):
        """Create an output processor.

        Multi-step scheduling is no longer supported. Always return a
        single-step output processor.
        """
        from vllm.engine.output_processor.single_step import (
            SingleStepOutputProcessor)
        return SingleStepOutputProcessor(scheduler_config, detokenizer,
                                         scheduler, seq_counter, stop_checker)

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.
        """
        pass

    @abstractmethod
    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Update prompt logprobs received from outputs to seq_group."""
        pass
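The removed interface combined an ABC with a static factory that always returned the single-step implementation. A self-contained sketch of that pattern follows; the names are illustrative stand-ins, not vLLM APIs:

# Sketch of the ABC-plus-factory pattern used by the removed interface.
from abc import ABC, abstractmethod

class OutputProcessor(ABC):
    """Illustrative stand-in; not a vLLM class."""

    @staticmethod
    def create() -> "OutputProcessor":
        # The removed code resolved its concrete class lazily (deferred
        # import); here the concrete class is simply defined below and is
        # looked up at call time.
        return _SingleStep()

    @abstractmethod
    def process_outputs(self, token_ids: list[int]) -> None:
        ...

class _SingleStep(OutputProcessor):
    def process_outputs(self, token_ids: list[int]) -> None:
        assert len(token_ids) == 1, "single-step expects one token per call"

proc = OutputProcessor.create()
proc.process_outputs([42])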
@@ -1,145 +0,0 @@ (entire file removed)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
                           SequenceGroupOutput)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

logger = init_logger(__name__)


def single_step_process_prompt_logprob(
        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
        output: CompletionSequenceGroupOutput) -> None:
    """Process prompt logprobs associated with the
    [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step.

    Do nothing if the output has no prompt logprobs.

    Account for the fact that transformers do not compute first-token logprobs.

    Args:
      sg_output_proc:
        [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor]
        instance
      seq_group: the output is associated with this
        [`SequenceGroup`][vllm.sequence.SequenceGroup]
      output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
        for a single scheduler step
    """
    prompt_logprobs = output.prompt_logprobs

    # If this is the first (or only) "chunk" of the prefill, we need
    # to prepend None to the list of prompt logprobs. The reason for this
    # is that for N prompt tokens, the Sampler will generate N-1 total
    # prompt logprobs during prefill since the token at idx 0 will not
    # have a logprob associated with it.
    if prompt_logprobs is not None:
        if not seq_group.prompt_logprobs:
            prompt_logprobs = [None] + prompt_logprobs
            seq_group.prompt_logprobs = []

        assert hasattr(sg_output_proc, 'detokenizer')
        if (seq_group.sampling_params.detokenize
                and sg_output_proc.detokenizer):
            sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
                seq_group,
                prompt_logprobs,
                position_offset=len(seq_group.prompt_logprobs))

        seq_group.prompt_logprobs.extend(prompt_logprobs)


class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. This enables beam
    search sampling, which requires forking/finishing/freeing sequences in a way
    that is currently difficult to schedule multiple steps ahead of time.
    """

    def __init__(self, scheduler_config: SchedulerConfig,
                 detokenizer: Detokenizer, scheduler: List[Scheduler],
                 seq_counter: Counter, stop_checker: StopChecker):
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Append all new tokens to sequences in the sequence group. Fork any
        surviving beam candidates; free any unsurviving ones.

        Invokes detokenizer to detokenize new tokens, and also marks sequences
        as finished if they meet stop conditions.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        return self._process_sequence_group_outputs(sequence_group, outputs[0],
                                                    is_async)

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with one step of a single-step-
        scheduled computation.

        Args:
          seq_group: the output is associated with this
            [`SequenceGroup`][vllm.sequence.SequenceGroup]
          outputs: the
            [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
            for a single scheduler step
        """
        assert len(outputs) == 1, "Single step should only have 1 output."
        output = outputs[0]
        assert isinstance(output, CompletionSequenceGroupOutput)
        single_step_process_prompt_logprob(self, seq_group, output)

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        sampling_params = seq_group.sampling_params

        sample = outputs.samples[0]
        seq = seq_group.first_seq
        if not is_async:
            seq.append_token_id(sample.output_token, sample.logprobs,
                                sample.output_embed)
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)
        else:
            new_char_count = 0
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count,
            sampling_params,
            lora_req=seq_group.lora_request,
        )
        if seq.is_finished():
            for scheduler in self.scheduler:
                scheduler.free_seq(seq)
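The None-prepend in single_step_process_prompt_logprob exists because the sampler returns N-1 prompt logprobs for N prompt tokens (the first token has no logprob). A toy illustration with made-up numbers:

# Toy illustration (made-up values): a 4-token prompt yields 3 prompt
# logprobs, so prepending None keeps index i aligned with prompt token i.
prompt_token_ids = [101, 7592, 2088, 102]        # 4 prompt tokens
sampler_prompt_logprobs = [-2.3, -0.7, -1.1]     # N - 1 = 3 values
aligned = [None] + sampler_prompt_logprobs
assert len(aligned) == len(prompt_token_ids)
assert aligned[0] is None                        # token 0 has no logprob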
@@ -1,139 +0,0 @@ (entire file removed)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional, Tuple

from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus


class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.
    """

    def __init__(
        self,
        max_model_len: int,
        reasoner: Optional[ReasoningParser] = None,
    ):
        # Do not use it directly, but use `self._get_max_model_len`.
        self._max_model_len = max_model_len
        self.reasoner = reasoner

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        else:
            return self._max_model_len

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Stop the finished sequences.

        new_char_count is the number of chars added to the
        sequence's output text for the newly generated token
        """

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            # Remove the last EOS token unless explicitly specified
            # This prevents unintended exposure of the EOS token
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Skip stop string/token checks if in reasoning content generation
        if self.reasoner is not None and \
                not self.reasoner.is_reasoning_end(seq.get_token_ids()):
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()
        if last_token_id in (sampling_params.stop_token_ids or ()):
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is not None:
            stop_str, truncate_to = stop
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() >= self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def check_stop_strings(
        output_text: str,
        new_char_count: int,
        stop: List[str],
        include_in_output: bool,
    ) -> Optional[Tuple[str, int]]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Returns tuple (stop_string, offset) if matched or else None.

        Where stop_string is the matched stop string and offset is the
        length to which output_text should be truncated, or -1 for no
        truncation.
        """
        if not new_char_count or not stop:
            return None

        for stop_str in stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text.
            stop_index = output_text.find(stop_str,
                                          1 - new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if include_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(output_text):
                    # No truncation required.
                    return stop_str, -1

            # Truncate the output text to either the beginning
            # or end of the stop string.
            return stop_str, stop_index
        return None
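The "Avoid searching already-searched text" comment in check_stop_strings relies on str.find accepting a negative start index, which restricts the search to the tail of the text that could contain a stop string completed by the newly appended characters. A small self-contained demo of that window, with made-up values:

# Toy demo of the negative-start str.find window (values are made up).
output_text = "Hello world<|end|>"
new_char_count = 3                 # pretend the last 3 chars were just appended
stop_str = "<|end|>"
start = 1 - new_char_count - len(stop_str)      # -9 -> len(output_text) - 9 == 9
assert output_text.find(stop_str, start) == 11  # match found inside the window
assert output_text.find(stop_str, 12) == -1     # a later start misses the match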
@@ -9,7 +9,6 @@ from tokenizers import Tokenizer
 from tokenizers.decoders import DecodeStream
 from transformers import PreTrainedTokenizerFast
 
-from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
@@ -129,7 +128,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
         # 2) Evaluate stop strings.
         stop_string = None
         if self.stop and len(self.output_token_ids) > self.min_tokens:
-            stop = StopChecker.check_stop_strings(
+            stop = check_stop_strings(
                 output_text=self.output_text,
                 new_char_count=len(self.output_text) - stop_check_offset,
                 stop=self.stop,
@@ -309,3 +308,42 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
         self.read_offset = read_offset
 
         return decoded_text
+
+
+def check_stop_strings(
+    output_text: str,
+    new_char_count: int,
+    stop: list[str],
+    include_in_output: bool,
+) -> Optional[tuple[str, int]]:
+    """Check if any stop strings are matched and truncate sequence
+    output text accordingly.
+
+    Returns tuple (stop_string, offset) if matched or else None.
+
+    Where stop_string is the matched stop string and offset is the
+    length to which output_text should be truncated, or -1 for no
+    truncation.
+    """
+    if not new_char_count or not stop:
+        return None
+
+    for stop_str in stop:
+        stop_string_len = len(stop_str)
+        # Avoid searching already-searched text.
+        stop_index = output_text.find(stop_str,
+                                      1 - new_char_count - stop_string_len)
+        if stop_index == -1:
+            continue
+
+        if include_in_output:
+            # Truncate to end of stop string.
+            stop_index += stop_string_len
+            if stop_index >= len(output_text):
+                # No truncation required.
+                return stop_str, -1
+
+        # Truncate the output text to either the beginning
+        # or end of the stop string.
+        return stop_str, stop_index
+    return None
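The module-level check_stop_strings added here keeps the same return contract as the removed StopChecker.check_stop_strings: a (matched_stop_string, truncation_offset) tuple, with -1 meaning no truncation. A usage sketch; the import path below is an assumption, since this hunk does not name the file it modifies:

# Usage sketch; the import path is assumed (the hunk shows only class names
# from the V1 incremental detokenizer, not the file path).
from vllm.v1.engine.detokenizer import check_stop_strings  # assumed location

# Stop string fully emitted; caller should truncate output_text to index 11.
assert check_stop_strings("Hello world<|end|>", 7, ["<|end|>"], False) == ("<|end|>", 11)
# include_in_output=True and the stop string ends the text: -1 means keep everything.
assert check_stop_strings("Hello world<|end|>", 7, ["<|end|>"], True) == ("<|end|>", -1)
# No stop strings configured: nothing to do.
assert check_stop_strings("Hello world", 5, [], False) is None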