[misc] Add Torch profiler support (#7451)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
William Lin 2024-08-21 15:39:26 -07:00 committed by GitHub
parent 970dfdc01d
commit dd53c4b023
12 changed files with 191 additions and 2 deletions


@@ -225,8 +225,8 @@ async def async_request_openai_completions(
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
-    "completions"
-), "OpenAI Completions API URL must end with 'completions'."
+    ("completions", "profile")
+), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search


@@ -295,6 +295,7 @@ def calculate_metrics(
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
@@ -302,6 +303,7 @@ async def benchmark(
use_beam_search: bool,
request_rate: float,
disable_tqdm: bool,
profile: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@
f"are correctly specified. Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/start_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
print(f"Traffic request rate: {request_rate}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@
pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
print("Stopping profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/stop_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler stopped")
if pbar is not None:
pbar.close()
@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
tokenizer = get_tokenizer(tokenizer_id,
trust_remote_code=args.trust_remote_code)
@@ -506,6 +541,7 @@ def main(args: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
@@ -513,6 +549,7 @@
use_beam_search=args.use_beam_search,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
))
# Save config and results to json
@@ -693,6 +730,12 @@ if __name__ == "__main__":
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--save-result",
action="store_true",


@@ -0,0 +1,33 @@
Profiling vLLM
=================================

We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``.

The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.

When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.

.. warning::

    Only enable profiling in a development environment.

Traces can be visualized using https://ui.perfetto.dev/.

.. tip::

    Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to untar the traces; they can be viewed directly.

Example commands:

OpenAI Server:

.. code-block:: bash

    VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B

benchmark_serving.py:

.. code-block:: bash

    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
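The ``/start_profile`` and ``/stop_profile`` endpoints added by this commit can also be driven by hand instead of through ``--profile``. A minimal sketch, assuming the server was launched with ``VLLM_TORCH_PROFILER_DIR`` set and is reachable at ``http://localhost:8000`` (host, port, and the use of the ``requests`` library are illustrative):

.. code-block:: python

    import requests

    BASE_URL = "http://localhost:8000"  # placeholder address for a locally running server

    # Ask every worker to start collecting a torch.profiler trace.
    requests.post(f"{BASE_URL}/start_profile").raise_for_status()

    # ... send a small number of normal completion requests here ...

    # Stop profiling; gzipped traces are written to VLLM_TORCH_PROFILER_DIR.
    requests.post(f"{BASE_URL}/stop_profile").raise_for_status()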


@@ -136,6 +136,7 @@ Documentation
dev/input_processing/model_inputs_index
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile
dev/profiling/profiling_index
.. toctree::
:maxdepth: 1


@@ -1266,3 +1266,9 @@ class AsyncLLMEngine:
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None:
self.engine.model_executor._run_workers("start_profile")
async def stop_profile(self) -> None:
self.engine.model_executor._run_workers("stop_profile")


@@ -91,3 +91,11 @@ class AsyncEngineClient(Protocol):
async def check_health(self) -> None:
"""Raise if unhealthy"""
...
async def start_profile(self) -> None:
"""Start profiling the engine"""
...
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
...


@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")
@router.post("/start_profile")
async def start_profile():
logger.info("Starting profiler...")
await async_engine_client.start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile():
logger.info("Stopping profiler...")
await async_engine_client.stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
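The profiling routes above are registered only when ``VLLM_TORCH_PROFILER_DIR`` is set, so an unconfigured deployment never exposes them. A standalone sketch of that gating pattern, with a placeholder FastAPI app and handler bodies instead of the real engine calls (nothing below is vLLM code):

import os

from fastapi import APIRouter, FastAPI, Response

router = APIRouter()

if os.getenv("VLLM_TORCH_PROFILER_DIR"):
    # The endpoints exist only when the profiler directory is configured.
    @router.post("/start_profile")
    async def start_profile() -> Response:
        # vLLM forwards this to async_engine_client.start_profile().
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile() -> Response:
        # vLLM forwards this to async_engine_client.stop_profile().
        return Response(status_code=200)

app = FastAPI()
app.include_router(router)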


@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
DO_LOG_STATS = 7
IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9
START_PROFILE = 10
STOP_PROFILE = 11
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,


@@ -400,3 +400,17 @@ class AsyncEngineRPCClient:
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
async def start_profile(self) -> None:
"""Start profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.START_PROFILE,
error_message="RPCRequest START_PROFILE failed.")
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")


@@ -124,6 +124,26 @@ class AsyncEngineRPCServer:
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@ class AsyncEngineRPCServer:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
elif request == RPCUtilityRequest.START_PROFILE:
return self.start_profile(identity)
elif request == RPCUtilityRequest.STOP_PROFILE:
return self.stop_profile(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")


@@ -58,6 +58,7 @@ if TYPE_CHECKING:
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
def get_default_cache_root():
@@ -384,6 +385,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_PLUGINS":
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
"VLLM_PLUGINS"].split(","),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR":
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
}
# end-env-vars-definition
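For reference, the lambda above yields ``None`` when the variable is unset and otherwise expands a leading ``~``; a small standalone sketch of the same parsing logic (the helper name is made up for illustration):

import os
from typing import Optional

def torch_profiler_dir() -> Optional[str]:
    # Mirrors the envs.py entry: unset -> profiling disabled,
    # set -> user-expanded path handed to the workers.
    raw = os.getenv("VLLM_TORCH_PROFILER_DIR")
    return None if raw is None else os.path.expanduser(raw)

os.environ.pop("VLLM_TORCH_PROFILER_DIR", None)
assert torch_profiler_dir() is None

os.environ["VLLM_TORCH_PROFILER_DIR"] = "~/traces"
assert torch_profiler_dir() == os.path.expanduser("~/traces")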


@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
@@ -13,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -27,6 +29,8 @@ from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
logger = init_logger(__name__)
class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.
@@ -113,6 +117,33 @@ class Worker(LocalOrDistributedWorkerBase):
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
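The worker's profiler lifecycle can be reproduced outside vLLM in a few lines; a minimal sketch with an arbitrary trace directory and a toy workload standing in for model execution:

import torch

trace_dir = "/tmp/profiler_demo"  # arbitrary directory for this sketch

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

profiler = torch.profiler.profile(
    activities=activities,
    with_stack=True,
    # Same handler the worker uses: writes a gzipped trace viewable in Perfetto.
    on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir,
                                                            use_gzip=True))

profiler.start()   # what start_profile() triggers on each worker
x = torch.randn(512, 512)
y = x @ x          # toy workload in place of a model forward pass
profiler.stop()    # what stop_profile() triggers; the trace is flushed here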