Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 12:05:48 +08:00)
[misc] Add Torch profiler support (#7451)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

commit dd53c4b023 (parent 970dfdc01d)
@@ -225,8 +225,8 @@ async def async_request_openai_completions(
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
-        "completions"
-    ), "OpenAI Completions API URL must end with 'completions'."
+        ("completions", "profile")
+    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
@@ -295,6 +295,7 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
@@ -302,6 +303,7 @@ async def benchmark(
     use_beam_search: bool,
     request_rate: float,
     disable_tqdm: bool,
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -326,6 +328,22 @@
                 f"are correctly specified. Error: {test_output.error}")
     else:
         print("Initial test run completed. Starting main benchmark run...")
+
+    if profile:
+        print("Starting profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/start_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler started")
+
     print(f"Traffic request rate: {request_rate}")
 
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -349,6 +367,21 @@
                     pbar=pbar)))
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    if profile:
+        print("Stopping profiler...")
+        profile_input = RequestFuncInput(
+            model=model_id,
+            prompt=test_prompt,
+            api_url=base_url + "/stop_profile",
+            prompt_len=test_prompt_len,
+            output_len=test_output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        profile_output = await request_func(request_func_input=profile_input)
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()
 
@@ -433,8 +466,10 @@ def main(args: argparse.Namespace):
 
     if args.base_url is not None:
         api_url = f"{args.base_url}{args.endpoint}"
+        base_url = f"{args.base_url}"
     else:
         api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+        base_url = f"http://{args.host}:{args.port}"
 
     tokenizer = get_tokenizer(tokenizer_id,
                               trust_remote_code=args.trust_remote_code)
@@ -506,6 +541,7 @@ def main(args: argparse.Namespace):
         benchmark(
             backend=backend,
             api_url=api_url,
+            base_url=base_url,
             model_id=model_id,
             tokenizer=tokenizer,
             input_requests=input_requests,
@@ -513,6 +549,7 @@ def main(args: argparse.Namespace):
             use_beam_search=args.use_beam_search,
             request_rate=args.request_rate,
             disable_tqdm=args.disable_tqdm,
+            profile=args.profile,
         ))
 
     # Save config and results to json
@@ -693,6 +730,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Specify to disable tqdm progress bar.",
     )
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
+    )
     parser.add_argument(
         "--save-result",
         action="store_true",
docs/source/dev/profiling/profiling_index.rst (new file, 33 lines)
@@ -0,0 +1,33 @@
+Profiling vLLM
+=================================
+
+We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/``
+
+The OpenAI server also needs to be started with the ``VLLM_TORCH_PROFILER_DIR`` environment variable set.
+
+When using ``benchmarks/benchmark_serving.py``, you can enable profiling by passing the ``--profile`` flag.
+
+.. warning::
+
+    Only enable profiling in a development environment.
+
+
+Traces can be visualized using https://ui.perfetto.dev/.
+
+.. tip::
+
+    Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to untar the traces; they can be viewed directly.
+
+Example commands:
+
+OpenAI Server:
+
+.. code-block:: bash
+
+    VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
+
+benchmark_serving.py:
+
+.. code-block:: bash
+
+    python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
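The ``--profile`` flag drives the new ``/start_profile`` and ``/stop_profile`` endpoints automatically, but they can also be exercised by hand against a running server. A minimal sketch, assuming the server was started with ``VLLM_TORCH_PROFILER_DIR`` set and listens on localhost:8000; the ``requests`` package is used only for illustration and is not a dependency added by this commit:

    import requests  # illustration only; not part of this change

    base_url = "http://localhost:8000"  # assumed server address

    # Begin recording a torch.profiler trace on every worker.
    assert requests.post(f"{base_url}/start_profile").status_code == 200

    # ... send a small number of completion requests here ...

    # Stop recording; the trace is written under VLLM_TORCH_PROFILER_DIR.
    assert requests.post(f"{base_url}/stop_profile").status_code == 200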
@@ -136,6 +136,7 @@ Documentation
    dev/input_processing/model_inputs_index
    dev/multimodal/multimodal_index
    dev/dockerfile/dockerfile
+   dev/profiling/profiling_index
 
 .. toctree::
    :maxdepth: 1
@@ -1266,3 +1266,9 @@ class AsyncLLMEngine:
                 logger_name=logger_name))
         else:
             self.engine.remove_logger(logger_name=logger_name)
+
+    async def start_profile(self) -> None:
+        self.engine.model_executor._run_workers("start_profile")
+
+    async def stop_profile(self) -> None:
+        self.engine.model_executor._run_workers("stop_profile")
@@ -91,3 +91,11 @@ class AsyncEngineClient(Protocol):
     async def check_health(self) -> None:
         """Raise if unhealthy"""
         ...
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...
@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)
 
 
+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile():
+        logger.info("Starting profiler...")
+        await async_engine_client.start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile():
+        logger.info("Stopping profiler...")
+        await async_engine_client.stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
+
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
     DO_LOG_STATS = 7
     IS_SERVER_HEALTHY = 8
     IS_TRACING_ENABLED = 9
+    START_PROFILE = 10
+    STOP_PROFILE = 11
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
@@ -400,3 +400,17 @@ class AsyncEngineRPCClient:
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.START_PROFILE,
+            error_message="RPCRequest START_PROFILE failed.")
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.STOP_PROFILE,
+            error_message="RPCRequest STOP_PROFILE failed.")
@@ -124,6 +124,26 @@ class AsyncEngineRPCServer:
         except Exception as e:
             await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
+    async def start_profile(self, identity):
+        logger.info("Starting profiler...")
+        await self.engine.start_profile()
+        logger.info("Profiler started.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
+    async def stop_profile(self, identity):
+        logger.info("Stopping profiler...")
+        await self.engine.stop_profile()
+        logger.info("Profiler stopped.")
+
+        await self.socket.send_multipart([
+            identity,
+            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ])
+
     def _make_handler_coro(self, identity,
                            message) -> Coroutine[Any, Any, Never]:
         """Route the zmq message to the handler coroutine."""
@@ -153,6 +173,10 @@ class AsyncEngineRPCServer:
             return self.check_health(identity)
         elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
             return self.is_tracing_enabled(identity)
+        elif request == RPCUtilityRequest.START_PROFILE:
+            return self.start_profile(identity)
+        elif request == RPCUtilityRequest.STOP_PROFILE:
+            return self.stop_profile(identity)
         else:
             raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
 
@@ -58,6 +58,7 @@ if TYPE_CHECKING:
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
+    VLLM_TORCH_PROFILER_DIR: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -384,6 +385,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_PLUGINS":
     lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
         "VLLM_PLUGINS"].split(","),
+
+    # Enables torch profiler if set. Path to the directory where torch profiler
+    # traces are saved. Note that it must be an absolute path.
+    "VLLM_TORCH_PROFILER_DIR":
+    lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
+             .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
 }
 
 # end-env-vars-definition
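For reference, the new resolver only applies ``os.path.expanduser`` to the raw value, so a value such as ~/traces is expanded while other paths pass through unchanged. A minimal sketch of the same logic outside vLLM, using a hypothetical example value:

    import os

    # Hypothetical value for illustration; the lambda expands "~" but does not
    # make relative paths absolute, hence the "absolute path" note above.
    os.environ["VLLM_TORCH_PROFILER_DIR"] = "~/traces"

    value = os.getenv("VLLM_TORCH_PROFILER_DIR", None)
    resolved = None if value is None else os.path.expanduser(value)
    print(resolved)  # e.g. /home/<user>/traces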
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Type, Union
 import torch
 import torch.distributed
 
+import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig,
@@ -13,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -27,6 +29,8 @@ from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
 
+logger = init_logger(__name__)
+
 
 class Worker(LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a GPU.
@@ -113,6 +117,33 @@ class Worker(LocalOrDistributedWorkerBase):
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
         self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
 
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+
     def _is_encoder_decoder_model(self):
         return self.model_config.is_encoder_decoder_model
 
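The worker wires up ``torch.profiler`` exactly as shown above, and the same configuration can be tried in isolation to confirm that gzipped traces land in the chosen directory. A minimal standalone sketch, assuming a CUDA-capable machine and a hypothetical trace directory /tmp/vllm_traces:

    import torch

    trace_dir = "/tmp/vllm_traces"  # hypothetical; stands in for VLLM_TORCH_PROFILER_DIR

    # Same profiler configuration as Worker.__init__ above.
    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        with_stack=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            trace_dir, use_gzip=True))

    profiler.start()
    # Stand-in for model execution; any CPU/CUDA work here is captured.
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
    torch.cuda.synchronize()
    profiler.stop()  # on_trace_ready writes a gzipped trace into trace_dir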