vllm/vllm/engine/protocol.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to python source files as
    recommended to the project by the Linux Foundation. These headers
    provide a concise way that is both human and machine readable for
    communicating license information for each source file. It helps
    avoid any ambiguity about the license of the code and can also be
    easily used by tools to help manage license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that
    tool do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00
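
The second commit above adds a pre-commit check for these headers. As a rough illustration only (a hypothetical sketch, not vLLM's actual hook script), such a check can be a short script that pre-commit runs against the staged Python files and that fails when the header is absent:

#!/usr/bin/env python3
# Hypothetical SPDX header check, sketched for illustration; vLLM's real
# pre-commit hook may differ. pre-commit passes the staged files as arguments.
import sys

EXPECTED = "# SPDX-License-Identifier: Apache-2.0"


def main(paths: list[str]) -> int:
    missing = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            first_line = f.readline().rstrip("\n")
            # Tolerate a shebang line before the SPDX header.
            if first_line.startswith("#!"):
                first_line = f.readline().rstrip("\n")
        if first_line != EXPECTED:
            missing.append(path)
    for path in missing:
        print(f"{path}: missing SPDX license header", file=sys.stderr)
    # A non-zero exit code makes the pre-commit check fail.
    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))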


# SPDX-License-Identifier: Apache-2.0

import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import collect_from_async_generator, random_uuid

logger = init_logger(__name__)


class EngineClient(ABC):
    """Protocol class for Clients to Engine"""

    @property
    @abstractmethod
    def is_running(self) -> bool:
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...
    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        preprocessor = await self.get_input_preprocessor()
        tokenizer_group = preprocessor.get_tokenizer_group()
        tokenizer = await tokenizer_group.get_lora_tokenizer_async()

        if is_explicit_encoder_decoder_prompt(prompt):
            raise NotImplementedError
        else:
            processed_inputs = preprocessor._prompt_to_llm_inputs(
                prompt,
                request_id=request_id,
            )

        prompt_token_ids = processed_inputs["prompt_token_ids"]
        prompt_text = processed_inputs.get("prompt")
        multi_modal_data = processed_inputs.get("multi_modal_data")
        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
        tokenized_length = len(prompt_token_ids)

        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

        # Each search step asks the engine for exactly one new token per beam,
        # requesting 2 * beam_width logprobs so there are enough non-EOS
        # candidates to keep beam_width beams alive.
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
        )
        # Start from a single beam containing only the prompt tokens.
        all_beams = [
            BeamSearchSequence(tokens=prompt_token_ids,
                               cum_logprob=0,
                               logprobs=[],
                               multi_modal_data=multi_modal_data,
                               mm_processor_kwargs=mm_processor_kwargs)
        ]
        completed = []
        for _ in range(max_tokens):
            # Expand every live beam by one token: send a single-token
            # generate() request per beam and run them concurrently.
            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens,
                             multi_modal_data=beam.multi_modal_data,
                             mm_processor_kwargs=beam.mm_processor_kwargs)
                for beam in all_beams
            ]

            tasks = []
            request_id = f"beam_search-{random_uuid()}"
            for i, individual_prompt in enumerate(prompts_batch):
                request_id_item = f"{request_id}-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
                tasks.append(task)

            output = await asyncio.gather(*tasks)
            output = [x[0] for x in output]

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        if token_id == tokenizer.eos_token_id and \
                            not ignore_eos:
                            # The beam hit EOS: record it as completed rather
                            # than expanding it further.
                            completed.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens +
                                    [token_id] if include_stop_str_in_output
                                    else current_beam.tokens,
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    finish_reason="stop",
                                    stop_reason=tokenizer.eos_token_id))
                        else:
                            new_beams.append(
                                BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs))

            # Keep only the top beam_width candidates for the next step.
            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

        # Any beams still alive after max_tokens steps also count as results.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        # Decode only the newly generated tokens; the prompt is stripped.
        for beam in best_beams:
            if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos):
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        beam_search_output = RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(text=beam.text,
                                 cumulative_logprob=beam.cum_logprob,
                                 token_ids=beam.tokens[tokenized_length:],
                                 index=i,
                                 logprobs=beam.logprobs,
                                 finish_reason=beam.finish_reason if
                                 beam.finish_reason is not None else "length",
                                 stop_reason=beam.stop_reason)
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None)

        yield beam_search_output
    @abstractmethod
    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """Generate outputs for a request from a pooling model."""
        ...

    @abstractmethod
    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """
        ...

    @abstractmethod
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Get the input preprocessor of the vLLM engine."""
        ...

    @abstractmethod
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Get the appropriate tokenizer for the request"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool:
        ...

    @abstractmethod
    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        ...

    @abstractmethod
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...

    @abstractmethod
    async def start_profile(self) -> None:
        """Start profiling the engine"""
        ...

    @abstractmethod
    async def stop_profile(self) -> None:
        """Stop profiling the engine"""
        ...

    @abstractmethod
    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache"""
        ...

    @abstractmethod
    async def add_lora(self, lora_request: LoRARequest) -> None:
        """Load a new LoRA adapter into the engine for future requests."""
        ...
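
For reference, a concrete implementation of this protocol is consumed by iterating the async generators above. A minimal usage sketch follows; how `client` is constructed is deliberately left out, since that depends on which EngineClient implementation is in use (an assumption, not shown in this module):

# Usage sketch only: `client` is assumed to be some concrete EngineClient
# implementation constructed elsewhere; building one is out of scope here.
from vllm.engine.protocol import EngineClient
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.utils import random_uuid


async def demo(client: EngineClient) -> None:
    # Streamed sampling: each yielded RequestOutput carries the text
    # generated so far; the last one arrives with finished=True.
    sampling = SamplingParams(max_tokens=32, temperature=0.8)
    async for out in client.generate("Hello, my name is",
                                     sampling,
                                     request_id=random_uuid()):
        if out.finished:
            print(out.outputs[0].text)

    # Beam search: a single RequestOutput is yielded at the end, holding
    # the beam_width best completions, already decoded to text.
    async for out in client.beam_search("Hello, my name is",
                                        request_id=random_uuid(),
                                        params=BeamSearchParams(beam_width=4,
                                                                max_tokens=32)):
        for i, completion in enumerate(out.outputs):
            print(i, completion.text)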