[misc] Add Torch profiler support (#7451)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Authored by William Lin on 2024-08-21 15:39:26 -07:00; committed by GitHub
parent 970dfdc01d
commit dd53c4b023
12 changed files with 191 additions and 2 deletions


@@ -225,8 +225,8 @@ async def async_request_openai_completions(
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
-        "completions"
-    ), "OpenAI Completions API URL must end with 'completions'."
+        ("completions", "profile")
+    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
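The relaxed assertion relies on ``str.endswith`` accepting a tuple of suffixes, which lets the existing OpenAI-completions request function be reused for the new profiling endpoints. A tiny standalone illustration (the URLs are placeholders):

    # str.endswith with a tuple accepts any of the listed suffixes.
    for url in ("http://localhost:8000/v1/completions",
                "http://localhost:8000/start_profile",
                "http://localhost:8000/v1/embeddings"):
        ok = url.endswith(("completions", "profile"))
        print(url, "->", "accepted" if ok else "rejected")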


@@ -295,6 +295,7 @@ def calculate_metrics(
async def benchmark(
    backend: str,
    api_url: str,
+    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
@@ -302,6 +303,7 @@ async def benchmark(
    use_beam_search: bool,
    request_rate: float,
    disable_tqdm: bool,
+    profile: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@ async def benchmark(
            f"are correctly specified. Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting main benchmark run...")
+
+    if profile:
+        print("Starting profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/start_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler started")
+
    print(f"Traffic request rate: {request_rate}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@ async def benchmark(
                pbar=pbar)))
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+    if profile:
+        print("Stopping profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/stop_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler stopped")

    if pbar is not None:
        pbar.close()
@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
+        base_url = f"{args.base_url}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+        base_url = f"http://{args.host}:{args.port}"

    tokenizer = get_tokenizer(tokenizer_id,
                              trust_remote_code=args.trust_remote_code)
@@ -506,6 +541,7 @@ def main(args: argparse.Namespace):
        benchmark(
            backend=backend,
            api_url=api_url,
+            base_url=base_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
@@ -513,6 +549,7 @@ def main(args: argparse.Namespace):
            use_beam_search=args.use_beam_search,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
+            profile=args.profile,
        ))

    # Save config and results to json
@@ -693,6 +730,12 @@ if __name__ == "__main__":
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+    )
    parser.add_argument(
        "--save-result",
        action="store_true",


@@ -0,0 +1,33 @@
Profiling vLLM
=================================

We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces, e.g. ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``.

The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.

When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.

.. warning::

    Only enable profiling in a development environment.

Traces can be visualized using https://ui.perfetto.dev/.

.. tip::

    Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to decompress the traces; they can be viewed directly.

Example commands:

OpenAI Server:

.. code-block:: bash

    VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B

benchmark_serving.py:

.. code-block:: bash

    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2


@@ -136,6 +136,7 @@ Documentation
   dev/input_processing/model_inputs_index
   dev/multimodal/multimodal_index
   dev/dockerfile/dockerfile
+   dev/profiling/profiling_index

.. toctree::
   :maxdepth: 1


@@ -1266,3 +1266,9 @@ class AsyncLLMEngine:
                    logger_name=logger_name))
        else:
            self.engine.remove_logger(logger_name=logger_name)
+
+    async def start_profile(self) -> None:
+        self.engine.model_executor._run_workers("start_profile")
+
+    async def stop_profile(self) -> None:
+        self.engine.model_executor._run_workers("stop_profile")
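``start_profile``/``stop_profile`` on the async engine simply broadcast a method name to every worker via the executor. A toy, self-contained sketch of that fan-out pattern (``ToyExecutor`` and ``ToyWorker`` are invented names; the real ``_run_workers`` also handles Ray and distributed dispatch):

    # Miniature of the _run_workers fan-out: look up the named method on each
    # worker and call it, collecting the results.
    class ToyWorker:
        def __init__(self, rank: int):
            self.rank = rank

        def start_profile(self) -> str:
            return f"worker {self.rank}: profiler started"

        def stop_profile(self) -> str:
            return f"worker {self.rank}: profiler stopped"


    class ToyExecutor:
        def __init__(self, workers):
            self.workers = workers

        def _run_workers(self, method: str, *args, **kwargs):
            return [getattr(w, method)(*args, **kwargs) for w in self.workers]


    executor = ToyExecutor([ToyWorker(rank) for rank in range(2)])
    print(executor._run_workers("start_profile"))
    print(executor._run_workers("stop_profile"))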


@@ -91,3 +91,11 @@ class AsyncEngineClient(Protocol):
    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...
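Since ``AsyncEngineClient`` is a ``typing.Protocol``, the in-process ``AsyncLLMEngine`` and the RPC client both satisfy it just by exposing matching ``start_profile``/``stop_profile`` coroutines. A minimal illustration of that structural typing (class names below are invented):

    import asyncio
    from typing import Protocol


    class ProfilingClient(Protocol):
        async def start_profile(self) -> None:
            ...

        async def stop_profile(self) -> None:
            ...


    class DummyClient:  # never inherits from ProfilingClient
        async def start_profile(self) -> None:
            print("profiler started")

        async def stop_profile(self) -> None:
            print("profiler stopped")


    async def main() -> None:
        # Accepted by static type checkers because the signatures match.
        client: ProfilingClient = DummyClient()
        await client.start_profile()
        await client.stop_profile()


    asyncio.run(main())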


@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
        assert_never(generator)

+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile():
+        logger.info("Starting profiler...")
+        await async_engine_client.start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile():
+        logger.info("Stopping profiler...")
+        await async_engine_client.stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
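The routes are only registered when ``VLLM_TORCH_PROFILER_DIR`` is set, so a server started without it returns 404 for these paths. A self-contained sketch of that conditional-registration pattern (a toy app, not the vLLM server itself; assumes ``fastapi`` is installed):

    import os

    from fastapi import APIRouter, FastAPI, Response

    router = APIRouter()

    if os.getenv("VLLM_TORCH_PROFILER_DIR"):

        @router.post("/start_profile")
        async def start_profile() -> Response:
            # The real server awaits async_engine_client.start_profile() here.
            return Response(status_code=200)

        @router.post("/stop_profile")
        async def stop_profile() -> Response:
            # The real server awaits async_engine_client.stop_profile() here.
            return Response(status_code=200)

    app = FastAPI()
    app.include_router(router)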


@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
    DO_LOG_STATS = 7
    IS_SERVER_HEALTHY = 8
    IS_TRACING_ENABLED = 9
+    START_PROFILE = 10
+    STOP_PROFILE = 11

RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,


@@ -400,3 +400,17 @@ class AsyncEngineRPCClient:
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.START_PROFILE,
+            error_message="RPCRequest START_PROFILE failed.")
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.STOP_PROFILE,
+            error_message="RPCRequest STOP_PROFILE failed.")


@@ -124,6 +124,26 @@ class AsyncEngineRPCServer:
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+
+    async def start_profile(self, identity):
+        logger.info("Starting profiler...")
+        await self.engine.start_profile()
+        logger.info("Profiler started.")
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
+    async def stop_profile(self, identity):
+        logger.info("Stopping profiler...")
+        await self.engine.stop_profile()
+        logger.info("Profiler stopped.")
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@ class AsyncEngineRPCServer:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
+            elif request == RPCUtilityRequest.START_PROFILE:
+                return self.start_profile(identity)
+            elif request == RPCUtilityRequest.STOP_PROFILE:
+                return self.stop_profile(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")


@@ -58,6 +58,7 @@ if TYPE_CHECKING:
    VLLM_TEST_FORCE_FP8_MARLIN: bool = False
    VLLM_ALLOW_ENGINE_USE_RAY: bool = False
    VLLM_PLUGINS: Optional[List[str]] = None
+    VLLM_TORCH_PROFILER_DIR: Optional[str] = None

def get_default_cache_root():
@@ -384,6 +385,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_PLUGINS":
    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
        "VLLM_PLUGINS"].split(","),
+
+    # Enables torch profiler if set. Path to the directory where torch profiler
+    # traces are saved. Note that it must be an absolute path.
+    "VLLM_TORCH_PROFILER_DIR":
+    lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
+             .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
}

# end-env-vars-definition
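As with the other entries in ``environment_variables``, the value is a zero-argument lambda, so the variable is read lazily at access time rather than at import time. A quick standalone illustration of the resolution logic (the helper name below is ours, not vLLM's):

    import os


    def resolve_profiler_dir():
        value = os.getenv("VLLM_TORCH_PROFILER_DIR")
        # Unset -> None (profiling disabled); otherwise expand "~" to a full path.
        return None if value is None else os.path.expanduser(value)


    os.environ.pop("VLLM_TORCH_PROFILER_DIR", None)
    print(resolve_profiler_dir())  # None

    os.environ["VLLM_TORCH_PROFILER_DIR"] = "~/traces"
    print(resolve_profiler_dir())  # e.g. /home/<user>/traces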


@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.distributed

+import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig,
@@ -13,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
+from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -27,6 +29,8 @@ from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput

+logger = init_logger(__name__)
+

class Worker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.
@@ -113,6 +117,33 @@ class Worker(LocalOrDistributedWorkerBase):
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+
    def _is_encoder_decoder_model(self):
        return self.model_config.is_encoder_decoder_model
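The worker wires ``start_profile``/``stop_profile`` directly to a ``torch.profiler.profile`` object configured at construction time. The same configuration can be reproduced outside vLLM; the snippet below (the trace directory and toy workload are placeholders) writes a gzipped trace that Perfetto or TensorBoard can open:

    import torch

    trace_dir = "/tmp/vllm-style-traces"  # stands in for VLLM_TORCH_PROFILER_DIR

    # Mirror the worker's configuration: CPU (+ CUDA when available) activities,
    # Python stacks, and a TensorBoard trace handler that gzips the output.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    profiler = torch.profiler.profile(
        activities=activities,
        with_stack=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir,
                                                                use_gzip=True))

    profiler.start()  # what Worker.start_profile() does
    x = torch.randn(1024, 1024)
    for _ in range(10):
        torch.mm(x, x)  # toy workload standing in for model execution
    profiler.stop()  # what Worker.stop_profile() does; the handler flushes the trace
    print(f"Trace written under {trace_dir}")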