[Frontend][V1] Online serving performance improvements (#12287)

Author: Nick Hill, 2025-01-22 14:22:12 -08:00, committed by GitHub
parent 7206ce4ce1
commit aea94362c9
7 changed files with 100 additions and 44 deletions

View File

@@ -1,5 +1,6 @@
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -104,6 +105,11 @@ async def lifespan(app: FastAPI):
         task.add_done_callback(_running_tasks.remove)
     else:
         task = None
+
+    # Mark the startup heap as static so that it's ignored by GC.
+    # Reduces pause times of oldest generation collections.
+    gc.collect()
+    gc.freeze()
     try:
         yield
     finally:
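
The gc.freeze() call is a standard CPython trick: everything allocated during startup is long-lived, so moving it into the permanent generation keeps the cyclic collector from rescanning it on every oldest-generation pass. A minimal standalone sketch of the pattern (nothing below is vLLM-specific):

import gc

# ... perform all startup work first: import modules, build the
# application object, load configuration, warm caches ...

gc.collect()  # reclaim garbage produced during startup
gc.freeze()   # move all surviving objects into the permanent
              # generation, which future collections never scan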

View File

@@ -3,7 +3,7 @@
 import re
 import time
 from argparse import Namespace
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union

 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
     # OpenAI API does allow extra fields
     model_config = ConfigDict(extra="allow")

+    # Cache class field names
+    field_names: ClassVar[Optional[Set[str]]] = None
+
     @model_validator(mode="before")
     @classmethod
     def __log_extra_fields__(cls, data):
-        if isinstance(data, dict):
+        field_names = cls.field_names
+        if field_names is None:
+            if not isinstance(data, dict):
+                return data
             # Get all class field names and their potential aliases
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if hasattr(field, 'alias') and field.alias:
-                    field_names.add(field.alias)
+                if alias := getattr(field, 'alias', None):
+                    field_names.add(alias)
+            cls.field_names = field_names

-            # Compare against both field names and aliases
-            extra_fields = data.keys() - field_names
-            if extra_fields:
-                logger.warning(
-                    "The following fields were present in the request "
-                    "but ignored: %s", extra_fields)
+        # Compare against both field names and aliases
+        if any(k not in field_names for k in data):
+            logger.warning(
+                "The following fields were present in the request "
+                "but ignored: %s",
+                data.keys() - field_names)
         return data
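
The change above replaces a per-request set construction with a once-per-subclass cache in a ClassVar. A self-contained sketch of the same pattern, with illustrative class and field names (only the validator shape is taken from the diff; print stands in for the logger):

from typing import ClassVar, Optional, Set

from pydantic import BaseModel, ConfigDict, Field, model_validator

class LoggingBaseModel(BaseModel):
    model_config = ConfigDict(extra="allow")

    # Built once per subclass on first use, then reused.
    field_names: ClassVar[Optional[Set[str]]] = None

    @model_validator(mode="before")
    @classmethod
    def _warn_extra_fields(cls, data):
        if not isinstance(data, dict):
            return data
        names = cls.field_names
        if names is None:
            # First call for this subclass: collect declared field
            # names plus their aliases, then cache on the class.
            names = set()
            for name, field in cls.model_fields.items():
                names.add(name)
                if alias := getattr(field, "alias", None):
                    names.add(alias)
            cls.field_names = names
        if extra := data.keys() - names:
            print(f"ignored fields: {extra}")
        return data

class ChatRequest(LoggingBaseModel):
    prompt: str
    max_tokens: int = Field(default=16, alias="max_completion_tokens")

ChatRequest(prompt="hi", max_completion_tokens=8, foo=1)  # reports {'foo'}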

View File

@@ -73,6 +73,7 @@ if TYPE_CHECKING:
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
     VLLM_SERVER_DEV_MODE: bool = False
+    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128


 def get_default_cache_root():
@@ -474,6 +475,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # e.g. `/reset_prefix_cache`
     "VLLM_SERVER_DEV_MODE":
     lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
+
+    # Controls the maximum number of requests to handle in a
+    # single asyncio task when processing per-token outputs in the
+    # V1 AsyncLLM interface. It is applicable when handling a high
+    # concurrency of streaming requests.
+    # Setting this too high can result in a higher variance of
+    # inter-message latencies. Setting it too low can negatively impact
+    # TTFT and overall throughput.
+    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
 }

 # end-env-vars-definition
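# end-env-vars-definition

Since vllm.envs resolves these entries lazily (module-level attribute access invokes the lambda), the knob can be exercised by setting the variable before vLLM is imported. A hedged example of tuning it down for latency-sensitive streaming (the value 64 is arbitrary, not a recommendation):

import os

# Must happen before importing vllm: async_llm binds the value at
# import time via `from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE`.
os.environ["VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"] = "64"

import vllm.envs as envs
assert envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE == 64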

View File

@@ -2,9 +2,12 @@ import asyncio
 import os
 from typing import AsyncGenerator, List, Mapping, Optional, Type, Union

+import numpy as np
+
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
+from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -16,7 +19,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import kill_process_tree
+from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
@@ -205,17 +208,15 @@ class AsyncLLM(EngineClient):

             # The output_handler task pushes items into the queue.
             # This task pulls from the queue and yields to caller.
-            while True:
+            finished = False
+            while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+                out = q.get_nowait() if not q.empty() else await q.get()

                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
-                if out.finished:
-                    yield out
-                    break
-
+                finished = out.finished
                 yield out

         # If the request is disconnected by the client, the
@@ -233,11 +234,29 @@ class AsyncLLM(EngineClient):
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()

-                # 2) Process EngineCoreOutputs.
-                processed_outputs = self.output_processor.process_outputs(
-                    outputs.outputs)
-
-                # NOTE: RequestOutputs are pushed to their queues.
-                assert len(processed_outputs.request_outputs) == 0
+                # Split outputs into chunks of at most
+                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                # event loop for too long.
+                num_outputs = len(outputs.outputs)
+                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                    slices = (outputs.outputs, )
+                else:
+                    slices = np.array_split(
+                        outputs.outputs,
+                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+
+                iteration_stats = None
+                for i, outputs_slice in enumerate(slices):
+                    # 2) Process EngineCoreOutputs.
+                    processed_outputs = self.output_processor.process_outputs(
+                        outputs_slice, iteration_stats)
+
+                    # NOTE: RequestOutputs are pushed to their queues.
+                    assert not processed_outputs.request_outputs
+                    iteration_stats = processed_outputs.iteration_stats
+
+                    # Allow other asyncio tasks to run between chunks
+                    if i + 1 < len(slices):
+                        await asyncio.sleep(0)

                 # 3) Abort any reqs that finished due to stop strings.
                 await self.engine_core.abort_requests_async(
@@ -246,9 +265,10 @@ class AsyncLLM(EngineClient):
                 # 4) Logging.
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once we add Prometheus.
+                assert iteration_stats is not None
                 self._log_stats(
                     scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=processed_outputs.iteration_stats,
+                    iteration_stats=iteration_stats,
                 )

             except Exception as e:
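
The chunking itself is plain numpy plus a cooperative yield. A standalone sketch of the scheme under the same constant (process_chunk is a stand-in for output_processor.process_outputs):

import asyncio
import numpy as np

CHUNK_SIZE = 128  # mirrors VLLM_V1_OUTPUT_PROC_CHUNK_SIZE

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as provided by vllm.utils.cdiv.
    return -(a // -b)

async def process_in_chunks(outputs, process_chunk):
    # At most CHUNK_SIZE items per asyncio "turn", so one large batch
    # cannot block the event loop for its whole duration.
    if len(outputs) <= CHUNK_SIZE:
        slices = (outputs, )
    else:
        slices = np.array_split(outputs, cdiv(len(outputs), CHUNK_SIZE))
    for i, chunk in enumerate(slices):
        process_chunk(chunk)
        if i + 1 < len(slices):
            await asyncio.sleep(0)  # let other tasks (e.g. streamers) run

asyncio.run(process_in_chunks(list(range(1000)), lambda chunk: None))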

View File

@@ -1,8 +1,9 @@
+import asyncio
 import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Type
+from typing import List, Optional, Type

 import msgspec
 import zmq
@@ -255,10 +256,24 @@ class AsyncMPClient(MPClient):
             log_stats=True,
         )

+        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.queue_task: Optional[asyncio.Task] = None
+
     async def get_output_async(self) -> EngineCoreOutputs:
-
-        frames = await self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frames[0].buffer)
+        if self.outputs_queue is None:
+            # Perform IO in separate task to parallelize as much as possible
+            self.outputs_queue = asyncio.Queue()
+
+            async def process_outputs_socket():
+                assert self.outputs_queue is not None
+                while True:
+                    (frame, ) = await self.output_socket.recv_multipart(
+                        copy=False)
+                    self.outputs_queue.put_nowait(frame.buffer)
+
+            self.queue_task = asyncio.create_task(process_outputs_socket())
+
+        return self.decoder.decode(await self.outputs_queue.get())

     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: EngineCoreRequestUnion) -> None:
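
Moving the socket reads into their own task lets receive and decode overlap instead of alternating. The shape of the pattern, reduced to plain asyncio (the socket here is any object exposing an async recv(); all names are illustrative):

import asyncio
from typing import Optional

class OutputReader:
    def __init__(self, sock):
        self.sock = sock  # assumed to expose `async def recv()`
        self.outputs_queue: Optional[asyncio.Queue] = None
        self.queue_task: Optional[asyncio.Task] = None

    async def get_output(self):
        if self.outputs_queue is None:
            # Started lazily on first call, once an event loop exists.
            self.outputs_queue = asyncio.Queue()

            async def pump():
                # Drain the socket continuously so IO proceeds even
                # while the consumer is busy decoding a message.
                while True:
                    self.outputs_queue.put_nowait(await self.sock.recv())

            self.queue_task = asyncio.create_task(pump())
        return await self.outputs_queue.get()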

View File

@@ -101,6 +101,7 @@ class OutputProcessor:
     def process_outputs(
         self,
         engine_core_outputs: List[EngineCoreOutput],
+        iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -133,6 +134,7 @@ class OutputProcessor:

         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
-        iteration_stats = IterationStats(self.log_stats)
+        if not iteration_stats:
+            iteration_stats = IterationStats(self.log_stats)
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
@@ -175,8 +177,8 @@ class OutputProcessor:
             iteration_stats=iteration_stats,
         )

+    @staticmethod
     def _make_request_output(
-        self,
         request_state: RequestState,
         detokenizer_output: Optional[DetokenizerOutput],
     ) -> Optional[RequestOutput]:
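
Making iteration_stats a parameter is what lets the chunked loop in AsyncLLM accumulate a single stats object across all slices of one iteration. A runnable toy version of that handshake (IterationStats here is a counting stand-in, not vLLM's class):

from typing import List, Optional

class IterationStats:
    def __init__(self) -> None:
        self.num_outputs = 0

def process_outputs(outputs: List[int],
                    iteration_stats: Optional[IterationStats] = None):
    # The first chunk of an iteration creates the stats object;
    # later chunks keep adding to the same one.
    if not iteration_stats:
        iteration_stats = IterationStats()
    iteration_stats.num_outputs += len(outputs)
    return iteration_stats

stats = None
for chunk in ([1, 2, 3], [4, 5]):
    stats = process_outputs(chunk, stats)
assert stats.num_outputs == 5  # one object spans every chunk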

View File

@@ -64,6 +64,12 @@ class Request:
         # recomputing.
         self._kv_block_hashes: List[BlockHashType] = []

+        # Read-only views
+        # Prevent directly appending to these lists since
+        # they should also be updated simultaneously.
+        self.output_token_ids = ConstantList(self._output_token_ids)
+        self.all_token_ids = ConstantList(self._all_token_ids)
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
@@ -79,18 +85,6 @@ class Request:
             lora_request=request.lora_request,
         )

-    @property
-    def output_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the output_token_ids since
-        # all_token_ids should also be updated simultaneously.
-        return ConstantList(self._output_token_ids)
-
-    @property
-    def all_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the all_token_ids since
-        # output_token_ids should also be updated simultaneously.
-        return ConstantList(self._all_token_ids)
-
     def append_output_token_ids(
         self,
         token_ids: Union[int, List[int]],
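
Constructing the ConstantList views once in __init__ replaces two properties that allocated a fresh wrapper on every access, a hot path during token streaming. A minimal stand-in for the wrapper itself (vLLM ships its own ConstantList; this only illustrates the read-only-view idea):

from typing import Generic, List, TypeVar

T = TypeVar("T")

class ConstantList(Generic[T]):
    def __init__(self, items: List[T]) -> None:
        self._items = items  # shared with the owner, not copied

    def __getitem__(self, idx):
        return self._items[idx]

    def __len__(self) -> int:
        return len(self._items)

    def append(self, item: T) -> None:
        raise TypeError("Cannot modify a constant list")

tokens: List[int] = [1, 2]
view = ConstantList(tokens)
tokens.append(3)     # writes go through the owning Request only
assert view[2] == 3  # the view reflects the update without copying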