[TPU] Add example for profiling TPU inference (#12531)
Signed-off-by: mgoin <mgoin@redhat.com>
This commit is contained in:
parent 80fcc3ed1c
commit fbb5bd4cef
examples/offline_inference/profiling_tpu/README.md (Normal file, 67 lines)
@@ -0,0 +1,67 @@
# vLLM TPU Profiling
This script is used to profile the TPU performance of vLLM for specific prefill or decode token shapes.
Note: an actual running server is a mix of both prefill of many shapes and decode of many shapes.
We assume you are on a TPU already (this was tested on TPU v6e) and have installed vLLM according to the [installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/ai_accelerator/index.html).
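Before running the profiling script, it can be worth a quick sanity check that both vLLM and the TPU runtime are usable from your environment. This check is not part of the original example, just a suggested preflight step:

```bash
# Optional preflight check (not part of the original example):
# confirm vLLM imports cleanly and torch_xla can see TPU devices.
python3 -c "import vllm; print(vllm.__version__)"
python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices())"
```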
> In all the examples below, we run several warmups before profiling (so `--enforce-eager` is okay).
## Profile Examples
### Generate Prefill Trace
This example runs Qwen/Qwen2.5-7B-Instruct with a single request of 1024 input tokens. This is set up in an attempt to profile just the prefill time and operations.
```bash
export XLA_HLO_DEBUG=1
export MODEL=Qwen/Qwen2.5-7B-Instruct
export VLLM_TPU_PROFILE_DURATION_MS=3000
export VLLM_TPU_PROFILE_DELAY_MS=0

python3 profiling.py \
    --model $MODEL \
    --input-len 1024 --output-len 1 \
    --batch-size 1 --enforce-eager \
    --max-model-len 2048 \
    --tensor-parallel-size 1 \
    --profile-result-dir profiles
```
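After the run finishes, the trace should land under the directory passed to `--profile-result-dir`. The exact subdirectory layout depends on the torch_xla / TensorBoard profiler version, but you can quickly confirm that something was captured:

```bash
# List the captured profile artifacts (exact layout may vary by version;
# typically a plugins/profile/<timestamp>/ subtree containing .xplane.pb files).
ls -R profiles/
```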
### Generate Decode Trace
This example runs Llama 3.1 70B with a batch of 32 requests, where each has 1 input token and 128 output tokens. This is set up in an attempt to profile just the 32 decodes running in parallel, by using an extremely small prefill of 1 token and setting `VLLM_TPU_PROFILE_DELAY_MS=1000` to skip the first second of inference (hopefully the prefill).
```bash
export XLA_HLO_DEBUG=1
export MODEL=meta-llama/Llama-3.1-70B-Instruct
export VLLM_TPU_PROFILE_DURATION_MS=2000
export VLLM_TPU_PROFILE_DELAY_MS=1000

rm -rf ~/.cache/vllm/xla_cache
python3 profiling.py \
    --model $MODEL \
    --input-len 1 \
    --output-len 128 \
    --batch-size 32 \
    --enforce-eager \
    --profile-result-dir profiles \
    --max-model-len 2048 --tensor-parallel-size 8
```
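If you do not have 8 chips available for tensor parallelism, the same decode-heavy shape can be profiled with a smaller model. This variant is a hypothetical adaptation, not part of the original example:

```bash
# Hypothetical single-chip variant of the decode trace (not from the original example).
export MODEL=Qwen/Qwen2.5-7B-Instruct
python3 profiling.py \
    --model $MODEL \
    --input-len 1 --output-len 128 \
    --batch-size 32 --enforce-eager \
    --max-model-len 2048 \
    --tensor-parallel-size 1 \
    --profile-result-dir profiles
```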
## Visualizing the Profiles
Once you have collected your profiles with this script, you can visualize them using [TensorBoard](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).
You will most likely need to install the following dependencies:
```bash
pip install tensorflow-cpu tensorboard-plugin-profile etils importlib_resources
```
Then you just need to point TensorBoard to the directory where you saved the profiles and visit `http://localhost:6006/` in your browser:
```bash
tensorboard --logdir profiles/ --port 6006
```
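If you are profiling on a remote TPU VM, you will probably need to expose the TensorBoard port to your local machine first. SSH port forwarding is one common way to do this; the host placeholder below is an assumption about your setup, not part of the original example:

```bash
# Forward the remote TensorBoard port to your local machine,
# then open http://localhost:6006/ in a local browser.
ssh -L 6006:localhost:6006 <user>@<tpu-vm-address>
```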
examples/offline_inference/profiling_tpu/profiling.py (Normal file, 101 lines)
@@ -0,0 +1,101 @@
import argparse
import dataclasses
import os
import time
from typing import List

import numpy as np
import torch_xla.debug.profiler as xp
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser

DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000))
DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0))


def main(args: argparse.Namespace):
    print(args)

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))
    _ = xp.start_server(9012)

    sampling_params = SamplingParams(
        temperature=0.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def run_to_completion():
        start_time = time.perf_counter()
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        end_time = time.perf_counter()
        latency = end_time - start_time
        return latency

    # Warmup
    print("Warming up...")
    warmup_latencies = []
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        warmup_latencies.append(run_to_completion())
    print(f"Average warmup latency: {np.mean(warmup_latencies):.4f}s")

    # Profile
    profile_dir = args.profile_result_dir
    print(f"Profiling (results will be saved to '{profile_dir}')...")
    # Enable tracing on server
    xp.trace_detached("localhost:9012",
                      profile_dir,
                      delay_ms=DELAY_MS,
                      duration_ms=DURATION_MS)
    if DELAY_MS == 0:
        time.sleep(1.0)
    profile_latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profile iterations"):
        profile_latencies.append(run_to_completion())
    print(f"Average profile latency: {np.mean(profile_latencies):.4f}s")

    return


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=5,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=1,
                        help='Number of iterations to run for profiling.')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default="profiles",
        help=(
            'path to save the pytorch profiler output. Can be visualized '
            'with ui.perfetto.dev or Tensorboard '
            '(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).'
        ))

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)
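Because the script registers every vLLM engine argument through `EngineArgs.add_cli_args(parser)`, any standard engine flag can be appended to the profiling command in addition to the script's own options. The invocation below is only an illustrative sketch, not taken from the original README:

```bash
# Illustrative example: pass an extra engine argument (--dtype) through to vLLM.
python3 profiling.py \
    --model Qwen/Qwen2.5-7B-Instruct \
    --input-len 1024 --output-len 1 \
    --batch-size 1 --enforce-eager \
    --max-model-len 2048 \
    --dtype bfloat16 \
    --profile-result-dir profiles
```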