Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[misc] Add Torch profiler support (#7451)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

parent 970dfdc01d
commit dd53c4b023
@@ -225,8 +225,8 @@ async def async_request_openai_completions(
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
-        "completions"
-    ), "OpenAI Completions API URL must end with 'completions'."
+        ("completions", "profile")
+    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
@@ -295,6 +295,7 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
@@ -302,6 +303,7 @@ async def benchmark(
     use_beam_search: bool,
     request_rate: float,
     disable_tqdm: bool,
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@ async def benchmark(
             f"are correctly specified. Error: {test_output.error}")
     else:
         print("Initial test run completed. Starting main benchmark run...")
+
+    if profile:
+        print("Starting profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/start_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler started")
+
     print(f"Traffic request rate: {request_rate}")
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@
                                               pbar=pbar)))
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    if profile:
+        print("Stopping profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/stop_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()
 
@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):
 
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
+        base_url = f"{args.base_url}"
     else:
         api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+        base_url = f"http://{args.host}:{args.port}"
 
     tokenizer = get_tokenizer(tokenizer_id,
                               trust_remote_code=args.trust_remote_code)
@@ -506,6 +541,7 @@ def main(args: argparse.Namespace):
         benchmark(
             backend=backend,
             api_url=api_url,
+            base_url=base_url,
             model_id=model_id,
             tokenizer=tokenizer,
             input_requests=input_requests,
@@ -513,6 +549,7 @@ def main(args: argparse.Namespace):
             use_beam_search=args.use_beam_search,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
+            profile=args.profile,
         ))
 
     # Save config and results to json
@@ -693,6 +730,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Specify to disable tqdm progress bar.",
    )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+    )
     parser.add_argument(
         "--save-result",
         action="store_true",
docs/source/dev/profiling/profiling_index.rst (new file, +33)
@@ -0,0 +1,33 @@
+Profiling vLLM
+=================================
+
+We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``
+
+The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.
+
+When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.
+
+.. warning::
+
+    Only enable profiling in a development environment.
+
+
+Traces can be visualized using https://ui.perfetto.dev/.
+
+.. tip::
+
+    Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to decompress the traces; they can be viewed directly.
+
+Example commands:
+
+OpenAI Server:
+
+.. code-block:: bash
+
+    VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
+
+benchmark_serving.py:
+
+.. code-block:: bash
+
+    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
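For reference, the profiler can also be toggled without benchmark_serving.py by POSTing to the endpoints this commit adds. A minimal sketch using only the Python standard library, assuming a server started with VLLM_TORCH_PROFILER_DIR is listening on localhost:8000 (the host and port are placeholders, not part of the commit):

    # Sketch: drive the new /start_profile and /stop_profile endpoints directly.
    # Assumes the OpenAI-compatible server runs on http://localhost:8000 and was
    # launched with VLLM_TORCH_PROFILER_DIR set.
    import urllib.request

    def post(path: str) -> int:
        req = urllib.request.Request(f"http://localhost:8000{path}", method="POST")
        with urllib.request.urlopen(req) as resp:
            return resp.status

    print("start_profile ->", post("/start_profile"))
    # ... send a few requests through the server here ...
    print("stop_profile ->", post("/stop_profile"))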
@@ -136,6 +136,7 @@ Documentation
    dev/input_processing/model_inputs_index
    dev/multimodal/multimodal_index
    dev/dockerfile/dockerfile
+   dev/profiling/profiling_index
 
 .. toctree::
    :maxdepth: 1
 
@@ -1266,3 +1266,9 @@ class AsyncLLMEngine:
                     logger_name=logger_name))
         else:
             self.engine.remove_logger(logger_name=logger_name)
+
+    async def start_profile(self) -> None:
+        self.engine.model_executor._run_workers("start_profile")
+
+    async def stop_profile(self) -> None:
+        self.engine.model_executor._run_workers("stop_profile")
@@ -91,3 +91,11 @@ class AsyncEngineClient(Protocol):
     async def check_health(self) -> None:
         """Raise if unhealthy"""
         ...
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...
@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)
 
 
+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile():
+        logger.info("Starting profiler...")
+        await async_engine_client.start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile():
+        logger.info("Stopping profiler...")
+        await async_engine_client.stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
+
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
     DO_LOG_STATS = 7
     IS_SERVER_HEALTHY = 8
     IS_TRACING_ENABLED = 9
+    START_PROFILE = 10
+    STOP_PROFILE = 11
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
@@ -400,3 +400,17 @@ class AsyncEngineRPCClient:
                 **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.START_PROFILE,
+            error_message="RPCRequest START_PROFILE failed.")
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.STOP_PROFILE,
+            error_message="RPCRequest STOP_PROFILE failed.")
@@ -124,6 +124,26 @@ class AsyncEngineRPCServer:
         except Exception as e:
             await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
+    async def start_profile(self, identity):
+        logger.info("Starting profiler...")
+        await self.engine.start_profile()
+        logger.info("Profiler started.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
+    async def stop_profile(self, identity):
+        logger.info("Stopping profiler...")
+        await self.engine.stop_profile()
+        logger.info("Profiler stopped.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
     def _make_handler_coro(self, identity,
                            message) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@ class AsyncEngineRPCServer:
             return self.check_health(identity)
         elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
             return self.is_tracing_enabled(identity)
+        elif request == RPCUtilityRequest.START_PROFILE:
+            return self.start_profile(identity)
+        elif request == RPCUtilityRequest.STOP_PROFILE:
+            return self.stop_profile(identity)
         else:
             raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
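The RPC plumbing above follows a simple pattern: the client sends a one-way utility request, and the server's _make_handler_coro maps the enum value to a handler coroutine. A toy, self-contained sketch of that dispatch pattern (not vLLM's actual zmq code; all names here are illustrative):

    # Toy sketch of enum-based coroutine dispatch, mirroring the elif-chain
    # added to _make_handler_coro above. Names are placeholders.
    import asyncio
    from enum import Enum

    class UtilityRequest(Enum):
        START_PROFILE = 10
        STOP_PROFILE = 11

    async def start_profile() -> str:
        return "profiler started"

    async def stop_profile() -> str:
        return "profiler stopped"

    def make_handler_coro(request: UtilityRequest):
        # Pick the coroutine for the request; raise on unknown values.
        if request == UtilityRequest.START_PROFILE:
            return start_profile()
        elif request == UtilityRequest.STOP_PROFILE:
            return stop_profile()
        raise ValueError(f"Unknown UtilityRequest type: {request}")

    print(asyncio.run(make_handler_coro(UtilityRequest.START_PROFILE)))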
@@ -58,6 +58,7 @@ if TYPE_CHECKING:
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
+    VLLM_TORCH_PROFILER_DIR: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -384,6 +385,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_PLUGINS":
     lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
         "VLLM_PLUGINS"].split(","),
+
+    # Enables torch profiler if set. Path to the directory where torch profiler
+    # traces are saved. Note that it must be an absolute path.
+    "VLLM_TORCH_PROFILER_DIR":
+    lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
+             .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
 }
 
 # end-env-vars-definition
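To illustrate how the new entry resolves, here is a sketch mirroring the lambda above (not vLLM code): when the variable is unset the value stays None, otherwise it is passed through os.path.expanduser, so a leading ~ is expanded while relative paths pass through unchanged, which is why the comment asks for an absolute path.

    # Sketch of the resolution behaviour of VLLM_TORCH_PROFILER_DIR shown above.
    import os

    def resolve_torch_profiler_dir():
        raw = os.getenv("VLLM_TORCH_PROFILER_DIR")
        return None if raw is None else os.path.expanduser(raw)

    os.environ["VLLM_TORCH_PROFILER_DIR"] = "~/traces"
    print(resolve_torch_profiler_dir())  # e.g. /home/<user>/traces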
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Type, Union
 import torch
 import torch.distributed
 
+import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
@@ -13,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -27,6 +29,8 @@ from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
 
+logger = init_logger(__name__)
+
 
 class Worker(LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a GPU.
@@ -113,6 +117,33 @@ class Worker(LocalOrDistributedWorkerBase):
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
         self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
 
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+
     def _is_encoder_decoder_model(self):
         return self.model_config.is_encoder_decoder_model
 
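The Worker wires the profiler up as an explicitly started and stopped torch.profiler.profile that writes gzipped traces through the TensorBoard trace handler. A standalone sketch of that pattern follows; the matmul workload and the /tmp/vllm_traces directory are placeholders, not part of the commit.

    # Sketch of the profiler pattern used by Worker above: explicit start/stop
    # around a workload, traces written gzipped to a target directory.
    import torch

    trace_dir = "/tmp/vllm_traces"  # stands in for VLLM_TORCH_PROFILER_DIR
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        # CUDA activity is only added when a GPU is present.
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    profiler = torch.profiler.profile(
        activities=activities,
        with_stack=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            trace_dir, use_gzip=True))

    profiler.start()       # what Worker.start_profile() triggers
    x = torch.randn(512, 512)
    for _ in range(10):
        x = x @ x          # placeholder workload
    profiler.stop()        # what Worker.stop_profile() triggers; the trace is flushed here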