[Frontend][V1] Online serving performance improvements (#12287)
commit aea94362c9
parent 7206ce4ce1
@@ -1,5 +1,6 @@
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -104,6 +105,11 @@ async def lifespan(app: FastAPI):
             task.add_done_callback(_running_tasks.remove)
         else:
             task = None
+
+        # Mark the startup heap as static so that it's ignored by GC.
+        # Reduces pause times of oldest generation collections.
+        gc.collect()
+        gc.freeze()
         try:
             yield
         finally:
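The two added gc calls implement a common CPython trick: run one full collection at the end of startup, then freeze the survivors so the oldest-generation collector never re-traverses them. A minimal, runnable sketch of the same pattern outside vLLM (the startup_state dict is purely illustrative):

import gc

# Illustrative long-lived startup state (routes, model registry, etc.).
startup_state = {"routes": [f"/v{i}" for i in range(1000)]}

# One collection moves surviving startup objects into the oldest generation,
# then freeze() marks them as permanent so later collections skip them.
gc.collect()
gc.freeze()

print("objects excluded from future collections:", gc.get_freeze_count())

The trade-off is that frozen objects are never reclaimed, which is acceptable here because the server keeps this state alive for its whole lifetime anyway.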
@@ -3,7 +3,7 @@
 import re
 import time
 from argparse import Namespace
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union

 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
     # OpenAI API does allow extra fields
     model_config = ConfigDict(extra="allow")

+    # Cache class field names
+    field_names: ClassVar[Optional[Set[str]]] = None
+
     @model_validator(mode="before")
     @classmethod
     def __log_extra_fields__(cls, data):
-        if isinstance(data, dict):
+        field_names = cls.field_names
+        if field_names is None:
+            if not isinstance(data, dict):
+                return data
             # Get all class field names and their potential aliases
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if hasattr(field, 'alias') and field.alias:
-                    field_names.add(field.alias)
-
-            # Compare against both field names and aliases
-            extra_fields = data.keys() - field_names
-            if extra_fields:
-                logger.warning(
-                    "The following fields were present in the request "
-                    "but ignored: %s", extra_fields)
+                if alias := getattr(field, 'alias', None):
+                    field_names.add(alias)
+            cls.field_names = field_names
+
+        # Compare against both field names and aliases
+        if any(k not in field_names for k in data):
+            logger.warning(
+                "The following fields were present in the request "
+                "but ignored: %s",
+                data.keys() - field_names)
         return data
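The rewritten validator computes the set of declared field names (plus aliases) once per model class and caches it in the field_names ClassVar, rather than rebuilding the set for every incoming request. A self-contained sketch of the same caching pattern with a toy pydantic model (LenientModel and ChatRequest are illustrative names, not part of vLLM):

import logging
from typing import Any, ClassVar, Optional, Set

from pydantic import BaseModel, ConfigDict, model_validator

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class LenientModel(BaseModel):
    # Extra fields are accepted (as in OpenAIBaseModel) but logged once seen.
    model_config = ConfigDict(extra="allow")

    # Cached per subclass on first validation.
    field_names: ClassVar[Optional[Set[str]]] = None

    @model_validator(mode="before")
    @classmethod
    def _warn_extra_fields(cls, data: Any) -> Any:
        field_names = cls.field_names
        if field_names is None:
            if not isinstance(data, dict):
                return data
            field_names = set()
            for name, field in cls.model_fields.items():
                field_names.add(name)
                if alias := getattr(field, "alias", None):
                    field_names.add(alias)
            cls.field_names = field_names
        if any(k not in field_names for k in data):
            logger.warning("Fields present but ignored: %s",
                           data.keys() - field_names)
        return data


class ChatRequest(LenientModel):
    model: str
    prompt: str


ChatRequest(model="m", prompt="hi", temprature=0.2)  # warns about the typo

Because the cache is written onto cls, each subclass gets its own cached set the first time it validates a request.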
vllm/envs.py  (11 changes)
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
     VLLM_SERVER_DEV_MODE: bool = False
+    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128


 def get_default_cache_root():
@@ -474,6 +475,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # e.g. `/reset_prefix_cache`
     "VLLM_SERVER_DEV_MODE":
     lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),

+    # Controls the maximum number of requests to handle in a
+    # single asyncio task when processing per-token outputs in the
+    # V1 AsyncLLM interface. It is applicable when handling a high
+    # concurrency of streaming requests.
+    # Setting this too high can result in a higher variance of
+    # inter-message latencies. Setting it too low can negatively impact
+    # TTFT and overall throughput.
+    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
 }

 # end-env-vars-definition
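vllm/envs.py keeps a name-to-lambda registry so each variable is parsed from the environment lazily, at the moment it is read; the new VLLM_V1_OUTPUT_PROC_CHUNK_SIZE entry follows that pattern with a default of 128. A stripped-down sketch of the registry idea (get_env is an illustrative helper; the real module exposes the values as module attributes):

import os
from typing import Any, Callable, Dict

# Parsing is deferred to call time, so changes to the process environment
# before first use are picked up and unset variables fall back to defaults.
environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
}


def get_env(name: str) -> Any:
    return environment_variables[name]()


os.environ["VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"] = "64"
print(get_env("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"))  # 64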
@@ -2,9 +2,12 @@ import asyncio
 import os
 from typing import AsyncGenerator, List, Mapping, Optional, Type, Union

+import numpy as np
+
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
+from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -16,7 +19,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import kill_process_tree
+from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
@@ -205,17 +208,15 @@ class AsyncLLM(EngineClient):

             # The output_handler task pushes items into the queue.
             # This task pulls from the queue and yields to caller.
-            while True:
+            finished = False
+            while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+                out = q.get_nowait() if not q.empty() else await q.get()

                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
-                if out.finished:
-                    yield out
-                    break
-
+                finished = out.finished
                 yield out

             # If the request is disconnected by the client, the
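The generator now tracks finished itself and, more importantly, only awaits the queue when it is empty: q.get_nowait() on a non-empty queue avoids a task switch per token under load (and not q.empty() is a cheaper check than q.qsize() > 0). A small runnable sketch of the drain-without-await pattern, using None as a toy sentinel in place of out.finished:

import asyncio


async def consume(q: asyncio.Queue) -> None:
    finished = False
    while not finished:
        # Take the fast, non-awaiting path whenever an item is already
        # buffered; fall back to awaiting only when the queue is empty.
        out = q.get_nowait() if not q.empty() else await q.get()
        finished = out is None
        print("got:", out)


async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for item in ("tok1", "tok2", "tok3", None):
        q.put_nowait(item)
    await consume(q)


asyncio.run(main())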
@@ -233,11 +234,29 @@ class AsyncLLM(EngineClient):
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()

-                # 2) Process EngineCoreOutputs.
-                processed_outputs = self.output_processor.process_outputs(
-                    outputs.outputs)
-                # NOTE: RequestOutputs are pushed to their queues.
-                assert len(processed_outputs.request_outputs) == 0
+                # Split outputs into chunks of at most
+                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                # event loop for too long.
+                num_outputs = len(outputs.outputs)
+                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                    slices = (outputs.outputs, )
+                else:
+                    slices = np.array_split(
+                        outputs.outputs,
+                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+
+                iteration_stats = None
+                for i, outputs_slice in enumerate(slices):
+                    # 2) Process EngineCoreOutputs.
+                    processed_outputs = self.output_processor.process_outputs(
+                        outputs_slice, iteration_stats)
+                    # NOTE: RequestOutputs are pushed to their queues.
+                    assert not processed_outputs.request_outputs
+                    iteration_stats = processed_outputs.iteration_stats
+
+                    # Allow other asyncio tasks to run between chunks
+                    if i + 1 < len(slices):
+                        await asyncio.sleep(0)

                 # 3) Abort any reqs that finished due to stop strings.
                 await self.engine_core.abort_requests_async(
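Splitting a large batch of per-request outputs into VLLM_V1_OUTPUT_PROC_CHUNK_SIZE-sized slices and calling await asyncio.sleep(0) between them lets other coroutines (such as the per-request streaming generators above) run, so one big engine step cannot monopolize the event loop. A standalone sketch of the same chunked loop; CHUNK_SIZE and the per-chunk work are illustrative, while cdiv mirrors the helper imported from vllm.utils:

import asyncio

import numpy as np

CHUNK_SIZE = 128  # stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE


def cdiv(a: int, b: int) -> int:
    # Ceiling division, as vllm.utils.cdiv does.
    return -(a // -b)


async def handle_outputs(outputs: list) -> None:
    num_outputs = len(outputs)
    if num_outputs <= CHUNK_SIZE:
        slices = (outputs, )
    else:
        slices = np.array_split(outputs, cdiv(num_outputs, CHUNK_SIZE))

    for i, outputs_slice in enumerate(slices):
        # Toy per-chunk work; in AsyncLLM this is
        # OutputProcessor.process_outputs(outputs_slice, iteration_stats).
        _ = [str(o) for o in outputs_slice]

        # Yield to the event loop between chunks so other tasks can run.
        if i + 1 < len(slices):
            await asyncio.sleep(0)


asyncio.run(handle_outputs(list(range(1000))))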
@@ -246,9 +265,10 @@ class AsyncLLM(EngineClient):
                 # 4) Logging.
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once we add Prometheus.
+                assert iteration_stats is not None
                 self._log_stats(
                     scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=processed_outputs.iteration_stats,
+                    iteration_stats=iteration_stats,
                 )

         except Exception as e:
@@ -1,8 +1,9 @@
+import asyncio
 import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Type
+from typing import List, Optional, Type

 import msgspec
 import zmq
@@ -255,10 +256,24 @@ class AsyncMPClient(MPClient):
             log_stats=True,
         )

-    async def get_output_async(self) -> EngineCoreOutputs:
+        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.queue_task: Optional[asyncio.Task] = None

-        frames = await self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frames[0].buffer)
+    async def get_output_async(self) -> EngineCoreOutputs:
+        if self.outputs_queue is None:
+            # Perform IO in separate task to parallelize as much as possible
+            self.outputs_queue = asyncio.Queue()
+
+            async def process_outputs_socket():
+                assert self.outputs_queue is not None
+                while True:
+                    (frame, ) = await self.output_socket.recv_multipart(
+                        copy=False)
+                    self.outputs_queue.put_nowait(frame.buffer)
+
+            self.queue_task = asyncio.create_task(process_outputs_socket())
+
+        return self.decoder.decode(await self.outputs_queue.get())

     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: EngineCoreRequestUnion) -> None:
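get_output_async no longer blocks on the ZMQ socket directly; a lazily started background task reads frames and pushes their buffers into an asyncio.Queue, so socket IO overlaps with the caller's decoding and output processing. A self-contained sketch of the same reader-task shape using a fake in-process source instead of a ZMQ socket (AsyncReader and fake_socket are illustrative names):

import asyncio
from typing import AsyncIterator, Optional


async def fake_socket() -> AsyncIterator[bytes]:
    # Stand-in for the ZMQ output socket: one frame every few milliseconds.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield f"frame-{i}".encode()


class AsyncReader:
    """One background task owns the reads; consumers only await a queue."""

    def __init__(self) -> None:
        self.outputs_queue: Optional[asyncio.Queue] = None
        self.queue_task: Optional[asyncio.Task] = None

    async def get_output_async(self) -> bytes:
        if self.outputs_queue is None:
            # Started lazily so the task runs on the caller's event loop.
            self.outputs_queue = asyncio.Queue()

            async def process_outputs_socket() -> None:
                assert self.outputs_queue is not None
                async for frame in fake_socket():
                    self.outputs_queue.put_nowait(frame)

            self.queue_task = asyncio.create_task(process_outputs_socket())

        return await self.outputs_queue.get()


async def main() -> None:
    reader = AsyncReader()
    for _ in range(3):
        print(await reader.get_output_async())


asyncio.run(main())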
@@ -101,6 +101,7 @@ class OutputProcessor:
     def process_outputs(
         self,
         engine_core_outputs: List[EngineCoreOutput],
+        iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -133,6 +134,7 @@ class OutputProcessor:

         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
+        if not iteration_stats:
            iteration_stats = IterationStats(self.log_stats)
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
@@ -175,8 +177,8 @@ class OutputProcessor:
             iteration_stats=iteration_stats,
         )

+    @staticmethod
     def _make_request_output(
-        self,
         request_state: RequestState,
         detokenizer_output: Optional[DetokenizerOutput],
     ) -> Optional[RequestOutput]:
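process_outputs now takes an optional IterationStats so the chunked calls made by AsyncLLM within one engine step keep accumulating into a single stats object; a fresh one is created only when the caller passes None. A toy sketch of that accumulator-argument pattern (ToyStats and process_chunk are illustrative, not vLLM's IterationStats API):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyStats:
    num_tokens: int = 0


def process_chunk(outputs: List[str],
                  iteration_stats: Optional[ToyStats] = None) -> ToyStats:
    # Reuse the caller's stats object when given, so an engine step processed
    # in several chunks still yields exactly one stats record.
    if not iteration_stats:
        iteration_stats = ToyStats()
    iteration_stats.num_tokens += len(outputs)
    return iteration_stats


stats = None
for chunk in (["a", "b"], ["c"], ["d", "e", "f"]):
    stats = process_chunk(chunk, stats)
print(stats.num_tokens)  # 6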
@@ -64,6 +64,12 @@ class Request:
         # recomputing.
         self._kv_block_hashes: List[BlockHashType] = []

+        # Read-only views
+        # Prevent directly appending to these lists since
+        # they should also be updated simultaneously.
+        self.output_token_ids = ConstantList(self._output_token_ids)
+        self.all_token_ids = ConstantList(self._all_token_ids)
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
@@ -79,18 +85,6 @@ class Request:
             lora_request=request.lora_request,
         )

-    @property
-    def output_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the output_token_ids since
-        # all_token_ids should also be updated simultaneously.
-        return ConstantList(self._output_token_ids)
-
-    @property
-    def all_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the all_token_ids since
-        # output_token_ids should also be updated simultaneously
-        return ConstantList(self._all_token_ids)
-
     def append_output_token_ids(
         self,
         token_ids: Union[int, List[int]],
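Instead of building a fresh ConstantList on every property access, Request now creates the read-only views once in __init__; they stay current because they wrap the very lists the request mutates internally. A small sketch of such a live read-only view (ReadOnlyView is an illustrative stand-in for vLLM's ConstantList):

from typing import Iterator, List, Sequence


class ReadOnlyView(Sequence[int]):
    """A live, read-only view over a list that its owner keeps mutating."""

    def __init__(self, backing: List[int]) -> None:
        self._backing = backing

    def __getitem__(self, idx):
        return self._backing[idx]

    def __len__(self) -> int:
        return len(self._backing)

    def __iter__(self) -> Iterator[int]:
        return iter(self._backing)


_output_token_ids: List[int] = []
output_token_ids = ReadOnlyView(_output_token_ids)  # built once

_output_token_ids.append(42)   # the owner appends internally
print(list(output_token_ids))  # [42] - the view stays in sync
# The view exposes no append/extend, so callers cannot mutate it directly.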