[Frontend][V1] Online serving performance improvements (#12287)

Author: Nick Hill, 2025-01-22 14:22:12 -08:00 (committed by GitHub)
parent 7206ce4ce1
commit aea94362c9
7 changed files with 100 additions and 44 deletions

View File

@@ -1,5 +1,6 @@
import asyncio
import atexit
import gc
import importlib
import inspect
import multiprocessing
@@ -104,6 +105,11 @@ async def lifespan(app: FastAPI):
task.add_done_callback(_running_tasks.remove)
else:
task = None
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc.collect()
gc.freeze()
try:
yield
finally:
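A minimal sketch of the freeze-after-startup idea in isolation; build_startup_state is an illustrative stand-in, not part of this diff:

import gc

def build_startup_state():
    # Stand-in for expensive startup allocation (model config, routes, etc.)
    return {i: str(i) for i in range(100_000)}

state = build_startup_state()

# Collect whatever garbage startup produced, then freeze the survivors.
# Frozen objects are moved to a permanent generation that the cyclic GC
# never scans, so later oldest-generation collections pause for less time.
gc.collect()
gc.freeze()

print(gc.get_freeze_count(), "objects frozen")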

View File

@@ -3,7 +3,7 @@
import re
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
# OpenAI API does allow extra fields
model_config = ConfigDict(extra="allow")
# Cache class field names
field_names: ClassVar[Optional[Set[str]]] = None
@model_validator(mode="before")
@classmethod
def __log_extra_fields__(cls, data):
if isinstance(data, dict):
field_names = cls.field_names
if field_names is None:
if not isinstance(data, dict):
return data
# Get all class field names and their potential aliases
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if hasattr(field, 'alias') and field.alias:
field_names.add(field.alias)
if alias := getattr(field, 'alias', None):
field_names.add(alias)
cls.field_names = field_names
# Compare against both field names and aliases
extra_fields = data.keys() - field_names
if extra_fields:
logger.warning(
"The following fields were present in the request "
"but ignored: %s", extra_fields)
# Compare against both field names and aliases
if any(k not in field_names for k in data):
logger.warning(
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names)
return data
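The change above computes the field/alias set once per class instead of on every request. A standalone sketch of the same caching idea; ExampleModel and its fields are illustrative, only the validator pattern mirrors the diff:

from typing import ClassVar, Optional, Set

from pydantic import BaseModel, ConfigDict, Field, model_validator


class ExampleModel(BaseModel):
    model_config = ConfigDict(extra="allow")

    # Computed lazily on first validation, then shared by all instances.
    field_names: ClassVar[Optional[Set[str]]] = None

    name: str
    max_tokens: int = Field(default=16, alias="max_completion_tokens")

    @model_validator(mode="before")
    @classmethod
    def _warn_extra_fields(cls, data):
        if not isinstance(data, dict):
            return data
        field_names = cls.field_names
        if field_names is None:
            # Collect declared names and any aliases exactly once per class.
            field_names = set()
            for field_name, field in cls.model_fields.items():
                field_names.add(field_name)
                if alias := getattr(field, 'alias', None):
                    field_names.add(alias)
            cls.field_names = field_names
        if any(k not in field_names for k in data):
            print("ignored fields:", data.keys() - field_names)
        return data


ExampleModel(name="a", max_completion_tokens=8, foo=1)  # -> ignored fields: {'foo'}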

View File

@@ -73,6 +73,7 @@ if TYPE_CHECKING:
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
def get_default_cache_root():
@@ -474,6 +475,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# e.g. `/reset_prefix_cache`
"VLLM_SERVER_DEV_MODE":
lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
# Controls the maximum number of requests to handle in a
# single asyncio task when processing per-token outputs in the
# V1 AsyncLLM interface. It is applicable when handling a high
# concurrency of streaming requests.
# Setting this too high can result in a higher variance of
# inter-message latencies. Setting it too low can negatively impact
# TTFT and overall throughput.
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
}
# end-env-vars-definition
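To illustrate how the new entry behaves, a self-contained sketch of the lazy parsing pattern shown above, using only the dict shape and variable from this diff; the in-process override stands in for exporting the variable before launching the server:

import os
from typing import Any, Callable, Dict

# Each entry maps a variable name to a zero-argument parser, so the
# environment is read (and defaulted) only when the value is looked up.
environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
}

print(environment_variables["VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"]())  # -> 128 (default)

os.environ["VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"] = "64"
print(environment_variables["VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"]())  # -> 64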

View File

@@ -2,9 +2,12 @@ import asyncio
import os
from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
import numpy as np
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -16,7 +19,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import kill_process_tree
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.processor import Processor
@@ -205,17 +208,15 @@ class AsyncLLM(EngineClient):
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
while True:
finished = False
while not finished:
# Note: drain the queue without awaiting when possible (avoids
# task switching under load, which helps performance).
out = q.get_nowait() if q.qsize() > 0 else await q.get()
out = q.get_nowait() if not q.empty() else await q.get()
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
if out.finished:
yield out
break
finished = out.finished
yield out
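A reduced sketch of the drain-without-await consumption pattern used above; the None sentinel and the consume/main names are illustrative:

import asyncio

async def consume(q: asyncio.Queue) -> None:
    finished = False
    while not finished:
        # Take items synchronously while the queue is non-empty; only
        # await (and allow a task switch) when there is nothing buffered.
        out = q.get_nowait() if not q.empty() else await q.get()
        finished = out is None  # None marks end-of-stream in this sketch
        if not finished:
            print("got", out)

async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for item in ("a", "b", "c", None):
        q.put_nowait(item)
    await consume(q)

asyncio.run(main())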
# If the request is disconnected by the client, the
@@ -233,22 +234,41 @@
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await self.engine_core.get_output_async()
# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs.outputs)
# NOTE: RequestOutputs are pushed to their queues.
assert len(processed_outputs.request_outputs) == 0
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
num_outputs = len(outputs.outputs)
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
slices = np.array_split(
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
# 3) Abort any reqs that finished due to stop strings.
await self.engine_core.abort_requests_async(
processed_outputs.reqs_to_abort)
iteration_stats = None
for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs_slice, iteration_stats)
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs
iteration_stats = processed_outputs.iteration_stats
# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
await asyncio.sleep(0)
# 3) Abort any reqs that finished due to stop strings.
await self.engine_core.abort_requests_async(
processed_outputs.reqs_to_abort)
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once we add Prometheus.
assert iteration_stats is not None
self._log_stats(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=processed_outputs.iteration_stats,
iteration_stats=iteration_stats,
)
except Exception as e:
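The handler above now processes at most VLLM_V1_OUTPUT_PROC_CHUNK_SIZE outputs before yielding control with asyncio.sleep(0), so per-request streams keep flowing while a large batch is post-processed. A reduced sketch of that chunking pattern, with CHUNK_SIZE standing in for the env value, math.ceil standing in for cdiv, and a no-op body where output processing would go:

import asyncio
from math import ceil

import numpy as np

CHUNK_SIZE = 128  # stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

async def handle(outputs: list) -> None:
    # Split into chunks of at most CHUNK_SIZE so one large batch
    # cannot monopolize the event loop.
    if len(outputs) <= CHUNK_SIZE:
        slices = (outputs, )
    else:
        slices = np.array_split(outputs, ceil(len(outputs) / CHUNK_SIZE))

    for i, outputs_slice in enumerate(slices):
        for out in outputs_slice:
            pass  # per-request output processing would happen here
        # Let other asyncio tasks (e.g. per-request generators) run
        # between chunks, but not after the final one.
        if i + 1 < len(slices):
            await asyncio.sleep(0)

asyncio.run(handle(list(range(1000))))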

View File

@@ -1,8 +1,9 @@
import asyncio
import os
import signal
import weakref
from abc import ABC, abstractmethod
from typing import List, Type
from typing import List, Optional, Type
import msgspec
import zmq
@@ -255,10 +256,24 @@ class AsyncMPClient(MPClient):
log_stats=True,
)
async def get_output_async(self) -> EngineCoreOutputs:
self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
self.queue_task: Optional[asyncio.Task] = None
frames = await self.output_socket.recv_multipart(copy=False)
return self.decoder.decode(frames[0].buffer)
async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None:
# Perform IO in separate task to parallelize as much as possible
self.outputs_queue = asyncio.Queue()
async def process_outputs_socket():
assert self.outputs_queue is not None
while True:
(frame, ) = await self.output_socket.recv_multipart(
copy=False)
self.outputs_queue.put_nowait(frame.buffer)
self.queue_task = asyncio.create_task(process_outputs_socket())
return self.decoder.decode(await self.outputs_queue.get())
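Here the first call lazily spawns a task that keeps reading the output socket and buffering raw frames in an asyncio.Queue, so socket IO overlaps with decoding and downstream work. A simplified sketch of the lazily-started reader pattern; a plain asyncio.Queue stands in for the ZMQ socket and the Client class is illustrative:

import asyncio
from typing import Optional

class Client:
    def __init__(self, source: asyncio.Queue) -> None:
        self.source = source                      # stands in for the ZMQ socket
        self.outputs_queue: Optional[asyncio.Queue] = None
        self.queue_task: Optional[asyncio.Task] = None

    async def get_output(self):
        if self.outputs_queue is None:
            # Started lazily on first use; keeps draining the producer in the
            # background even while the consumer is busy elsewhere.
            self.outputs_queue = asyncio.Queue()

            async def read_socket() -> None:
                assert self.outputs_queue is not None
                while True:
                    item = await self.source.get()
                    self.outputs_queue.put_nowait(item)

            self.queue_task = asyncio.create_task(read_socket())
        return await self.outputs_queue.get()

async def main() -> None:
    source: asyncio.Queue = asyncio.Queue()
    client = Client(source)
    for i in range(3):
        source.put_nowait(i)
    for _ in range(3):
        print(await client.get_output())
    client.queue_task.cancel()

asyncio.run(main())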
async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:

View File

@@ -101,6 +101,7 @@ class OutputProcessor:
def process_outputs(
self,
engine_core_outputs: List[EngineCoreOutput],
iteration_stats: Optional[IterationStats] = None,
) -> OutputProcessorOutput:
"""
Process the EngineCoreOutputs:
@@ -133,7 +134,8 @@
request_outputs: List[RequestOutput] = []
reqs_to_abort: List[str] = []
iteration_stats = IterationStats(self.log_stats)
if not iteration_stats:
iteration_stats = IterationStats(self.log_stats)
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
req_state = self.request_states.get(req_id)
@@ -175,8 +177,8 @@ class OutputProcessor:
iteration_stats=iteration_stats,
)
@staticmethod
def _make_request_output(
self,
request_state: RequestState,
detokenizer_output: Optional[DetokenizerOutput],
) -> Optional[RequestOutput]:
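Because process_outputs may now run several times per engine step (once per chunk), the stats object is created on the first call and threaded back in for the rest, so the step is still logged as a single iteration. A tiny sketch of that optional-accumulator pattern; this IterationStats is a stand-in, not the vLLM class:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class IterationStats:
    # Stand-in accumulator: the real class tracks token/request counts.
    num_outputs: int = 0

def process_chunk(chunk: List[int],
                  iteration_stats: Optional[IterationStats] = None) -> IterationStats:
    # Create the accumulator only on the first chunk of the iteration.
    if not iteration_stats:
        iteration_stats = IterationStats()
    iteration_stats.num_outputs += len(chunk)
    return iteration_stats

stats = None
for chunk in ([1, 2, 3], [4, 5], [6]):
    stats = process_chunk(chunk, stats)
print(stats.num_outputs)  # -> 6 for the whole iteration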

View File

@@ -64,6 +64,12 @@ class Request:
# recomputing.
self._kv_block_hashes: List[BlockHashType] = []
# Read-only views
# Prevent directly appending to these lists, since they
# must always be updated together.
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
@@ -79,18 +85,6 @@ class Request:
lora_request=request.lora_request,
)
@property
def output_token_ids(self) -> ConstantList[int]:
# Prevent directly appending to the output_token_ids since
# all_token_ids should also be updated simultaneously.
return ConstantList(self._output_token_ids)
@property
def all_token_ids(self) -> ConstantList[int]:
# Prevent directly appending to the all_token_ids since
# output_token_ids should also be updated simultaneously
return ConstantList(self._all_token_ids)
def append_output_token_ids(
self,
token_ids: Union[int, List[int]],
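
Replacing the two properties with views built once in __init__ avoids constructing a new wrapper object on every attribute access in the hot path. A minimal sketch of a precomputed read-only view over lists that the owner keeps in sync; this ConstantList is a simplified stand-in for vLLM's:

from collections.abc import Sequence
from typing import List

class ConstantList(Sequence):
    """Read-only view over a list that the owning object keeps mutating."""

    def __init__(self, backing: List[int]) -> None:
        self._backing = backing

    def __getitem__(self, idx):
        return self._backing[idx]

    def __len__(self) -> int:
        return len(self._backing)

class Request:
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self._output_token_ids: List[int] = []
        self._all_token_ids: List[int] = list(prompt_token_ids)
        # Built once; callers can read but not append, so the two backing
        # lists can only be updated together via append_output_token_ids.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

    def append_output_token_ids(self, token_id: int) -> None:
        self._output_token_ids.append(token_id)
        self._all_token_ids.append(token_id)

req = Request([1, 2, 3])
req.append_output_token_ids(7)
print(len(req.all_token_ids), list(req.output_token_ids))  # -> 4 [7]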