mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 19:47:06 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as
recommended to the project by the Linux Foundation. These headers provide
a concise, human- and machine-readable way to communicate license
information for each source file. They help avoid any ambiguity about the
license of the code and can also be easily used by tools to help manage
license compliance.
The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that tool do
its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
75 lines
2.9 KiB
Python
75 lines
2.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Callable, List
|
|
|
|
from vllm.config import SchedulerConfig
|
|
from vllm.core.scheduler import Scheduler
|
|
from vllm.engine.output_processor.stop_checker import StopChecker
|
|
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
|
|
from vllm.transformers_utils.detokenizer import Detokenizer
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
from vllm.utils import Counter
|
|
|
|
|
|
class SequenceGroupOutputProcessor(ABC):
    """Abstract interface for handling newly generated token ids.

    Implementations process new token ids in sequence groups and take care
    of detokenization, stop checking, and freeing/forking sequences with the
    scheduler.

    The interface is tightly coupled with the LLMEngine and is best read as
    an extension of it. It lives in its own class to keep LLMEngine simpler
    and to permit two separate implementations:

    * single-step decoding, which supports beam search sequence forking;
    * multi-step decoding, which does not support beam search but does
      support speculative decoding.
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: "StopChecker",
    ):
        """Build the output processor matching the scheduler configuration.

        A single-step processor is returned when
        ``scheduler_config.num_lookahead_slots`` equals zero; in every other
        case a multi-step processor is returned.

        Args:
            scheduler_config: Configuration whose ``num_lookahead_slots``
                selects the implementation. It is also handed to the
                single-step processor.
            detokenizer: Forwarded to the chosen processor.
            scheduler: List of schedulers, forwarded to the chosen processor.
            seq_counter: Forwarded to the chosen processor.
            get_tokenizer_for_seq: Callable mapping a ``Sequence`` to its
                tokenizer. Only the multi-step processor receives it.
            stop_checker: Forwarded to the chosen processor.

        Returns:
            A ``SingleStepOutputProcessor`` or a ``MultiStepOutputProcessor``.
        """
        if scheduler_config.num_lookahead_slots == 0:
            # Imported lazily to avoid an import cycle.
            from vllm.engine.output_processor.single_step import (
                SingleStepOutputProcessor,
            )

            return SingleStepOutputProcessor(
                scheduler_config,
                detokenizer,
                scheduler,
                seq_counter,
                stop_checker,
            )

        # Imported lazily to avoid an import cycle.
        from vllm.engine.output_processor.multi_step import (
            MultiStepOutputProcessor,
        )

        return MultiStepOutputProcessor(
            detokenizer,
            scheduler,
            seq_counter,
            get_tokenizer_for_seq,
            stop_checker,
        )

    @abstractmethod
    def process_outputs(
        self,
        sequence_group: SequenceGroup,
        outputs: List[SequenceGroupOutput],
        is_async: bool,
    ) -> None:
        """Process new token ids for ``sequence_group``.

        Concrete implementations handle logic such as detokenization, stop
        checking, and freeing/forking sequences in the scheduler.

        Args:
            sequence_group: The sequence group the outputs belong to.
            outputs: The new outputs to process for that group.
            is_async: Boolean flag supplied by the caller; its exact meaning
                is defined by the concrete implementations.
        """

    @abstractmethod
    def process_prompt_logprob(
        self,
        seq_group: SequenceGroup,
        outputs: List[SequenceGroupOutput],
    ) -> None:
        """Update ``seq_group`` with the prompt logprobs found in ``outputs``."""
|