From dd53c4b023056cda6174cc32dc3d31bc01e8646a Mon Sep 17 00:00:00 2001
From: William Lin
Date: Wed, 21 Aug 2024 15:39:26 -0700
Subject: [PATCH] [misc] Add Torch profiler support (#7451)

Co-authored-by: Cody Yu
---
 benchmarks/backend_request_func.py            |  4 +-
 benchmarks/benchmark_serving.py               | 43 +++++++++++++++++++
 docs/source/dev/profiling/profiling_index.rst | 33 ++++++++++++++
 docs/source/index.rst                         |  1 +
 vllm/engine/async_llm_engine.py               |  6 +++
 vllm/engine/protocol.py                       |  8 ++++
 vllm/entrypoints/openai/api_server.py         | 20 +++++++++
 vllm/entrypoints/openai/rpc/__init__.py       |  2 +
 vllm/entrypoints/openai/rpc/client.py         | 14 ++++++
 vllm/entrypoints/openai/rpc/server.py         | 24 +++++++++++
 vllm/envs.py                                  |  7 +++
 vllm/worker/worker.py                         | 31 +++++++++++++
 12 files changed, 191 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/dev/profiling/profiling_index.rst

diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 3b4e31eaa712..f7d67692f697 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -225,8 +225,8 @@ async def async_request_openai_completions(
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
-        "completions"
-    ), "OpenAI Completions API URL must end with 'completions'."
+        ("completions", "profile")
+    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index fc0dbf77f16b..fe687da49290 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -295,6 +295,7 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
@@ -302,6 +303,7 @@ async def benchmark(
     use_beam_search: bool,
     request_rate: float,
     disable_tqdm: bool,
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@
             f"are correctly specified. Error: {test_output.error}")
     else:
         print("Initial test run completed. Starting main benchmark run...")
+
+    if profile:
+        print("Starting profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/start_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler started")
+
     print(f"Traffic request rate: {request_rate}")
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@
                              pbar=pbar)))
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    if profile:
+        print("Stopping profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/stop_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()
 
@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):
 
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
+        base_url = f"{args.base_url}"
     else:
         api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+        base_url = f"http://{args.host}:{args.port}"
 
     tokenizer = get_tokenizer(tokenizer_id,
                               trust_remote_code=args.trust_remote_code)
@@ -506,6 +541,7 @@
         benchmark(
             backend=backend,
             api_url=api_url,
+            base_url=base_url,
             model_id=model_id,
             tokenizer=tokenizer,
             input_requests=input_requests,
@@ -513,6 +549,7 @@
             use_beam_search=args.use_beam_search,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
+            profile=args.profile,
         ))
 
     # Save config and results to json
@@ -693,6 +730,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Specify to disable tqdm progress bar.",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "VLLM_TORCH_PROFILER_DIR to enable the profiler.",
+    )
     parser.add_argument(
         "--save-result",
         action="store_true",
diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst
new file mode 100644
index 000000000000..af3c78c3b5a5
--- /dev/null
+++ b/docs/source/dev/profiling/profiling_index.rst
@@ -0,0 +1,33 @@
+Profiling vLLM
+=================================
+
+We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``
+
+The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.
+
+When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.
+
+.. warning::
+
+    Only enable profiling in a development environment.
+
+
+Traces can be visualized using https://ui.perfetto.dev/.
+
+.. tip::
+
+    Only send a few requests through vLLM when profiling, as the traces can get quite large. There is also no need to decompress the traces; they can be viewed directly.
+
+Example commands:
+
+OpenAI Server:
+
+.. code-block:: bash
+
+    VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
+
+benchmark_serving.py:
+
+.. code-block:: bash
+
+    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4e79871e6e78..4b817c4ba949 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -136,6 +136,7 @@ Documentation
    dev/input_processing/model_inputs_index
    dev/multimodal/multimodal_index
    dev/dockerfile/dockerfile
+   dev/profiling/profiling_index
 
 .. toctree::
    :maxdepth: 1
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 9911cc9bdd84..8812b853c066 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1266,3 +1266,9 @@ class AsyncLLMEngine:
                     logger_name=logger_name))
         else:
             self.engine.remove_logger(logger_name=logger_name)
+
+    async def start_profile(self) -> None:
+        self.engine.model_executor._run_workers("start_profile")
+
+    async def stop_profile(self) -> None:
+        self.engine.model_executor._run_workers("stop_profile")
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 6c7fd96a7f8e..1deb75167bc7 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -91,3 +91,11 @@ class AsyncEngineClient(Protocol):
     async def check_health(self) -> None:
         """Raise if unhealthy"""
         ...
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 94d8525e429c..8e8371ef1559 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)
 
 
+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile():
+        logger.info("Starting profiler...")
+        await async_engine_client.start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile():
+        logger.info("Stopping profiler...")
+        await async_engine_client.stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
+
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py
index 981dfbfc6670..571dca5f61fa 100644
--- a/vllm/entrypoints/openai/rpc/__init__.py
+++ b/vllm/entrypoints/openai/rpc/__init__.py
@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
     DO_LOG_STATS = 7
     IS_SERVER_HEALTHY = 8
     IS_TRACING_ENABLED = 9
+    START_PROFILE = 10
+    STOP_PROFILE = 11
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index 7e360d1defb1..1f26348c74d6 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -400,3 +400,17 @@ class AsyncEngineRPCClient:
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.START_PROFILE,
+            error_message="RPCRequest START_PROFILE failed.")
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.STOP_PROFILE,
+            error_message="RPCRequest STOP_PROFILE failed.")
\ No newline at end of file
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 580b83277cfb..738d12bbef05 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -124,6 +124,26 @@ class AsyncEngineRPCServer:
         except Exception as e:
             await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
+    async def start_profile(self, identity):
+        logger.info("Starting profiler...")
+        await self.engine.start_profile()
+        logger.info("Profiler started.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
+    async def stop_profile(self, identity):
+        logger.info("Stopping profiler...")
+        await self.engine.stop_profile()
+        logger.info("Profiler stopped.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
     def _make_handler_coro(self, identity,
                            message) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@
                 return self.check_health(identity)
             elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                 return self.is_tracing_enabled(identity)
+            elif request == RPCUtilityRequest.START_PROFILE:
+                return self.start_profile(identity)
+            elif request == RPCUtilityRequest.STOP_PROFILE:
+                return self.stop_profile(identity)
             else:
                 raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
diff --git a/vllm/envs.py b/vllm/envs.py
index 115ead01f537..e4cf6a028ac1 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -58,6 +58,7 @@ if TYPE_CHECKING:
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
+    VLLM_TORCH_PROFILER_DIR: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -384,6 +385,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_PLUGINS":
     lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
         "VLLM_PLUGINS"].split(","),
+
+    # Enables torch profiler if set. Path to the directory where torch profiler
+    # traces are saved. Note that it must be an absolute path.
+    "VLLM_TORCH_PROFILER_DIR":
+    lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
+             .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
 }
 
 # end-env-vars-definition
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 97be68934be4..331a805caba9 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Type, Union
 import torch
 import torch.distributed
 
+import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
@@ -13,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -27,6 +29,8 @@ from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
 
+logger = init_logger(__name__)
+
 
 class Worker(LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a GPU.
@@ -113,6 +117,33 @@ class Worker(LocalOrDistributedWorkerBase):
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
         self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
 
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+
     def _is_encoder_decoder_model(self):
         return self.model_config.is_encoder_decoder_model
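
As a usage sketch (assumptions: the API server from this patch is running locally on the default port 8000 and was launched with ``VLLM_TORCH_PROFILER_DIR`` set; adjust host and port as needed), the new ``/start_profile`` and ``/stop_profile`` endpoints can also be driven directly, without going through ``benchmark_serving.py``:

.. code-block:: bash

    # Assumption: server at localhost:8000 started with VLLM_TORCH_PROFILER_DIR set.
    curl -X POST http://localhost:8000/start_profile
    # ... send a small number of requests to the server here ...
    curl -X POST http://localhost:8000/stop_profile
    # Gzipped traces are written to $VLLM_TORCH_PROFILER_DIR by the worker's
    # torch.profiler tensorboard_trace_handler and can be opened in ui.perfetto.dev.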