[Frontend] MQLLMEngine supports profiling. (#8761)

This commit is contained in:
科英 2024-09-26 00:37:41 +08:00 committed by GitHub
parent 28e1299e60
commit 64840dfae4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 7 deletions

View File

@ -107,7 +107,13 @@ class RPCStartupResponse:
tracing_enabled: bool tracing_enabled: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] class RPCUProfileRequest(Enum):
START_PROFILE = 1
STOP_PROFILE = 2
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

View File

@ -21,7 +21,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest, RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse) RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType from vllm.inputs import PromptType
@ -38,10 +39,10 @@ logger = init_logger(__name__)
class MQClientClosedError(Exception): class MQClientClosedError(Exception):
"""Exception class raised when the client is used post-close. """Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace. causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it. So, we throw this error such that we can suppress it.
""" """
@ -345,7 +346,7 @@ class MQLLMEngineClient:
async def check_health(self): async def check_health(self):
""" """
The check health loop probes the health status of the The check health loop probes the health status of the
Engine's health every N seconds and sets _errored_with Engine's health every N seconds and sets _errored_with
if the engine is unhealthy. if the engine is unhealthy.
""" """
if self._errored_with is not None: if self._errored_with is not None:
@ -561,3 +562,15 @@ class MQLLMEngineClient:
await self.abort(request_id) await self.abort(request_id)
finally: finally:
self.output_queues.pop(request_id) self.output_queues.pop(request_id)
async def start_profile(self) -> None:
"""Start profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

View File

@ -18,9 +18,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest, RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse) RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
@ -249,6 +251,11 @@ class MQLLMEngine:
self._handle_process_request(request) self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest): elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request) self._handle_abort_request(request)
elif isinstance(request, RPCUProfileRequest):
if request == RPCUProfileRequest.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
else: else:
raise ValueError("Unknown RPCRequest Type: " raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}") f"{type(request)}")
@ -356,6 +363,18 @@ class MQLLMEngine:
def _alive(self): def _alive(self):
self._last_alive_time = time.time() self._last_alive_time = time.time()
def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
def stop_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str): ipc_path: str):