mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 22:44:29 +08:00
[Feat] Iteration-level profiling for Torch and CUDA profiler (#28987)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
3168285fca
commit
fcbcba6c70
203
tests/v1/worker/test_gpu_profiler.py
Normal file
203
tests/v1/worker/test_gpu_profiler.py
Normal file
@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.profiler.gpu_profiler import WorkerProfiler
|
||||
|
||||
|
||||
class ConcreteWorkerProfiler(WorkerProfiler):
|
||||
"""
|
||||
A basic implementation of a worker profiler for testing purposes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_call_count = 0
|
||||
self.stop_call_count = 0
|
||||
self.should_fail_start = False
|
||||
super().__init__()
|
||||
|
||||
def _start(self) -> None:
|
||||
if self.should_fail_start:
|
||||
raise RuntimeError("Simulated start failure")
|
||||
self.start_call_count += 1
|
||||
|
||||
def _stop(self) -> None:
|
||||
self.stop_call_count += 1
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_mocks():
|
||||
"""Fixture to reset mocks and env variables before each test."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 0
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 0
|
||||
|
||||
|
||||
def test_immediate_start_stop():
|
||||
"""Test standard start without delay."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
profiler.start()
|
||||
assert profiler._running is True
|
||||
assert profiler._active is True
|
||||
assert profiler.start_call_count == 1
|
||||
|
||||
profiler.stop()
|
||||
assert profiler._running is False
|
||||
assert profiler._active is False
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_delayed_start():
|
||||
"""Test that profiler waits for N steps before actually starting."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
# User requests start
|
||||
profiler.start()
|
||||
|
||||
# Should be active (request accepted) but not running (waiting for delay)
|
||||
assert profiler._active is True
|
||||
assert profiler._running is False
|
||||
assert profiler.start_call_count == 0
|
||||
|
||||
# Step 1
|
||||
profiler.step()
|
||||
assert profiler._running is False
|
||||
|
||||
# Step 2 (Threshold reached)
|
||||
profiler.step()
|
||||
assert profiler._running is True
|
||||
assert profiler.start_call_count == 1
|
||||
|
||||
|
||||
def test_max_iterations():
|
||||
"""Test that profiler stops automatically after max iterations."""
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
profiler.start()
|
||||
assert profiler._running is True
|
||||
|
||||
# Iteration 1
|
||||
profiler.step() # profiling_count becomes 1
|
||||
assert profiler._running is True
|
||||
|
||||
# Iteration 2
|
||||
profiler.step() # profiling_count becomes 2
|
||||
assert profiler._running is True
|
||||
|
||||
# Iteration 3 (Exceeds max)
|
||||
profiler.step() # profiling_count becomes 3
|
||||
|
||||
# Should have stopped now
|
||||
assert profiler._running is False
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_delayed_start_and_max_iters():
|
||||
"""Test combined delayed start and max iterations."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
envs.VLLM_PROFILER_MAX_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
profiler.start()
|
||||
|
||||
# Step 1
|
||||
profiler.step()
|
||||
assert profiler._running is False
|
||||
assert profiler._active is True
|
||||
|
||||
# Step 2 (Starts now)
|
||||
profiler.step()
|
||||
assert profiler._profiling_for_iters == 1
|
||||
assert profiler._running is True
|
||||
assert profiler._active is True
|
||||
|
||||
# Next iteration
|
||||
profiler.step()
|
||||
assert profiler._profiling_for_iters == 2
|
||||
assert profiler._running is True
|
||||
|
||||
# Iteration 2 (exceeds max)
|
||||
profiler.step()
|
||||
|
||||
# Should have stopped now
|
||||
assert profiler._running is False
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_idempotency():
|
||||
"""Test that calling start/stop multiple times doesn't break logic."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
# Double Start
|
||||
profiler.start()
|
||||
profiler.start()
|
||||
assert profiler.start_call_count == 1 # Should only start once
|
||||
|
||||
# Double Stop
|
||||
profiler.stop()
|
||||
profiler.stop()
|
||||
assert profiler.stop_call_count == 1 # Should only stop once
|
||||
|
||||
|
||||
def test_step_inactive():
|
||||
"""Test that stepping while inactive does nothing."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 2
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
# Not started yet
|
||||
profiler.step()
|
||||
profiler.step()
|
||||
|
||||
# Even though we stepped 2 times, start shouldn't happen because active=False
|
||||
assert profiler.start_call_count == 0
|
||||
|
||||
|
||||
def test_start_failure():
|
||||
"""Test behavior when the underlying _start method raises exception."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
profiler.should_fail_start = True
|
||||
|
||||
profiler.start()
|
||||
|
||||
# Exception caught in _call_start
|
||||
assert profiler._running is False # Should not mark as running
|
||||
assert profiler._active is True # Request is still considered active
|
||||
assert profiler.start_call_count == 0 # Logic failed inside start
|
||||
|
||||
|
||||
def test_shutdown():
|
||||
"""Test that shutdown calls stop only if running."""
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
# Case 1: Not running
|
||||
profiler.shutdown()
|
||||
assert profiler.stop_call_count == 0
|
||||
|
||||
# Case 2: Running
|
||||
profiler.start()
|
||||
profiler.shutdown()
|
||||
assert profiler.stop_call_count == 1
|
||||
|
||||
|
||||
def test_mixed_delay_and_stop():
|
||||
"""Test manual stop during the delay period."""
|
||||
envs.VLLM_PROFILER_DELAY_ITERS = 5
|
||||
profiler = ConcreteWorkerProfiler()
|
||||
|
||||
profiler.start()
|
||||
profiler.step()
|
||||
profiler.step()
|
||||
|
||||
# User cancels before delay finishes
|
||||
profiler.stop()
|
||||
assert profiler._active is False
|
||||
|
||||
# Further steps should not trigger start
|
||||
profiler.step()
|
||||
profiler.step()
|
||||
profiler.step()
|
||||
|
||||
assert profiler.start_call_count == 0
|
||||
16
vllm/envs.py
16
vllm/envs.py
@ -92,11 +92,14 @@ if TYPE_CHECKING:
|
||||
VLLM_TORCH_PROFILER_DIR: str | None = None
|
||||
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
|
||||
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
|
||||
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False
|
||||
VLLM_USE_AOT_COMPILE: bool = False
|
||||
VLLM_USE_BYTECODE_HOOK: bool = False
|
||||
VLLM_FORCE_AOT_LOAD: bool = False
|
||||
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
|
||||
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
|
||||
VLLM_PROFILER_DELAY_ITERS: int = 0
|
||||
VLLM_PROFILER_MAX_ITERS: int = 0
|
||||
VLLM_USE_TRITON_AWQ: bool = False
|
||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||
VLLM_SKIP_P2P_CHECK: bool = False
|
||||
@ -872,6 +875,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
|
||||
),
|
||||
# Disable torch profiling of the AsyncLLMEngine process.
|
||||
# If set to 1, will not profile the engine process.
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
|
||||
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
|
||||
),
|
||||
# Delay number of iterations before starting profiling when using
|
||||
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
|
||||
"VLLM_PROFILER_DELAY_ITERS": lambda: int(
|
||||
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0")
|
||||
),
|
||||
# Maximum number of iterations to profile when using the torch/torch CUDA profiler.
|
||||
# If set to 0, will not limit the number of iterations.
|
||||
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
|
||||
# If set, vLLM will use Triton implementations of AWQ.
|
||||
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
|
||||
# If set, allow loading or unloading lora adapters in runtime,
|
||||
|
||||
@ -1,37 +1,212 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudaProfilerWrapper:
|
||||
class WorkerProfiler(ABC):
|
||||
def __init__(self) -> None:
|
||||
self._profiler_running = False
|
||||
self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
|
||||
if self._delay_iters > 0:
|
||||
logger.info_once(
|
||||
"GPU profiling will start "
|
||||
f"{self._delay_iters} steps after start_profile."
|
||||
)
|
||||
|
||||
self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
|
||||
if self._max_iters > 0:
|
||||
logger.info_once(
|
||||
"GPU profiling will stop "
|
||||
f"after {self._max_iters} worker steps, "
|
||||
"or when stop_profile is received."
|
||||
)
|
||||
|
||||
# Track when the profiler gets triggered by start_profile
|
||||
self._active_iteration_count = 0
|
||||
self._active = False
|
||||
|
||||
# Track when the profiler is actually running
|
||||
self._profiling_for_iters = 0
|
||||
self._running = False
|
||||
|
||||
@abstractmethod
|
||||
def _start(self) -> None:
|
||||
"""Start the profiler."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _stop(self) -> None:
|
||||
"""Stop the profiler."""
|
||||
pass
|
||||
|
||||
def _call_start(self) -> None:
|
||||
"""Call _start with error handling but no safeguards."""
|
||||
try:
|
||||
self._start()
|
||||
self._running = True # Only mark as running if start succeeds
|
||||
except Exception as e:
|
||||
logger.warning("Failed to start profiler: %s", e)
|
||||
|
||||
def _call_stop(self) -> None:
|
||||
"""Call _stop with error handling but no safeguards."""
|
||||
try:
|
||||
self._stop()
|
||||
logger.info("Profiler stopped successfully.")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to stop profiler: %s", e)
|
||||
self._running = False # Always mark as not running, assume stop worked
|
||||
|
||||
def start(self) -> None:
|
||||
"""Attempt to start the profiler, accounting for delayed starts."""
|
||||
if self._active:
|
||||
logger.debug(
|
||||
"start_profile received when profiler is already active. "
|
||||
"Ignoring request."
|
||||
)
|
||||
return
|
||||
self._active = True
|
||||
if self._delay_iters == 0:
|
||||
self._call_start()
|
||||
|
||||
def step(self) -> None:
|
||||
"""Update the profiler state at each worker step,
|
||||
to handle delayed starts and max iteration limits."""
|
||||
if not self._active:
|
||||
return
|
||||
|
||||
self._active_iteration_count += 1
|
||||
|
||||
if (
|
||||
not self._running
|
||||
and self._delay_iters > 0
|
||||
and self._active_iteration_count == self._delay_iters
|
||||
):
|
||||
logger.info("Starting profiler after delay...")
|
||||
self._call_start()
|
||||
|
||||
if self._running:
|
||||
self._profiling_for_iters += 1
|
||||
|
||||
if (
|
||||
self._max_iters > 0
|
||||
and self._running
|
||||
and self._profiling_for_iters > self._max_iters
|
||||
):
|
||||
# Automatically stop the profiler after max iters
|
||||
# will be marked as not running, but leave as active so that stop
|
||||
# can clean up properly
|
||||
logger.info("Max profiling iterations reached. Stopping profiler...")
|
||||
self._call_stop()
|
||||
return
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Attempt to stop the profiler, accounting for overlapped calls."""
|
||||
if not self._active:
|
||||
logger.debug(
|
||||
"stop_profile received when profiler is not active. Ignoring request."
|
||||
)
|
||||
return
|
||||
self._active = False
|
||||
self._active_iteration_count = 0
|
||||
self._profiling_for_iters = 0
|
||||
|
||||
if self._running:
|
||||
self._call_stop()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Ensure profiler is stopped when shutting down."""
|
||||
logger.info_once("Shutting down profiler")
|
||||
if self._running:
|
||||
self.stop()
|
||||
|
||||
def annotate_context_manager(self, name: str):
|
||||
"""Return a context manager to annotate profiler traces."""
|
||||
return nullcontext()
|
||||
|
||||
|
||||
class TorchProfilerWrapper(WorkerProfiler):
|
||||
def __init__(self, worker_name: str, local_rank: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.local_rank = local_rank
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info(
|
||||
"Torch profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir,
|
||||
)
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
def _start(self) -> None:
|
||||
self.profiler.start()
|
||||
|
||||
@override
|
||||
def _stop(self) -> None:
|
||||
self.profiler.stop()
|
||||
|
||||
rank = self.local_rank
|
||||
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
|
||||
sort_key = "self_cuda_time_total"
|
||||
table = self.profiler.key_averages().table(sort_by=sort_key)
|
||||
|
||||
with open(profiler_out_file, "w") as f:
|
||||
print(table, file=f)
|
||||
|
||||
# only print profiler results on rank 0
|
||||
if rank == 0:
|
||||
print(table)
|
||||
|
||||
@override
|
||||
def annotate_context_manager(self, name: str):
|
||||
return torch.profiler.record_function(name)
|
||||
|
||||
|
||||
class CudaProfilerWrapper(WorkerProfiler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Note: lazy import to avoid dependency issues if CUDA is not available.
|
||||
import torch.cuda.profiler as cuda_profiler
|
||||
|
||||
self._cuda_profiler = cuda_profiler
|
||||
|
||||
def start(self) -> None:
|
||||
try:
|
||||
self._cuda_profiler.start()
|
||||
self._profiler_running = True
|
||||
logger.info_once("Started CUDA profiler")
|
||||
except Exception as e:
|
||||
logger.warning_once("Failed to start CUDA profiler: %s", e)
|
||||
@override
|
||||
def _start(self) -> None:
|
||||
self._cuda_profiler.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._profiler_running:
|
||||
try:
|
||||
self._cuda_profiler.stop()
|
||||
logger.info_once("Stopped CUDA profiler")
|
||||
except Exception as e:
|
||||
logger.warning_once("Failed to stop CUDA profiler: %s", e)
|
||||
finally:
|
||||
self._profiler_running = False
|
||||
@override
|
||||
def _stop(self) -> None:
|
||||
self._cuda_profiler.stop()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Ensure profiler is stopped when shutting down."""
|
||||
self.stop()
|
||||
@override
|
||||
def annotate_context_manager(self, name: str):
|
||||
return torch.cuda.nvtx.range(name)
|
||||
|
||||
@ -160,11 +160,23 @@ class AsyncLLM(EngineClient):
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
if (
|
||||
envs.VLLM_TORCH_PROFILER_DIR
|
||||
and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM
|
||||
):
|
||||
logger.info(
|
||||
"Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501
|
||||
envs.VLLM_TORCH_PROFILER_DIR,
|
||||
)
|
||||
if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0:
|
||||
logger.warning_once(
|
||||
"Torch profiler received max_iters or delay_iters setting. These "
|
||||
"are not compatible with the AsyncLLM profiler and will be ignored "
|
||||
"for the AsyncLLM process. Engine process profiling will still "
|
||||
"respect these settings. Consider setting "
|
||||
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable "
|
||||
"AsyncLLM profiling."
|
||||
)
|
||||
worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm"
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
|
||||
@ -36,7 +36,7 @@ from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.models.interfaces import is_mixture_of_experts
|
||||
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.gpu_profiler import CudaProfilerWrapper
|
||||
from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
@ -90,32 +90,9 @@ class Worker(WorkerBase):
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
|
||||
logger.info(
|
||||
"Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir,
|
||||
)
|
||||
logger.debug(
|
||||
"Profiler config: record_shapes=%s,"
|
||||
"profile_memory=%s,with_stack=%s,with_flops=%s",
|
||||
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
|
||||
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
|
||||
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
|
||||
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
|
||||
),
|
||||
self.profiler = TorchProfilerWrapper(
|
||||
worker_name=worker_name, local_rank=self.local_rank
|
||||
)
|
||||
elif envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
self.profiler = CudaProfilerWrapper()
|
||||
@ -526,10 +503,12 @@ class Worker(WorkerBase):
|
||||
if not self.profiler:
|
||||
return nullcontext()
|
||||
|
||||
self.profiler.step()
|
||||
|
||||
num_new = len(scheduler_output.scheduled_new_reqs)
|
||||
num_cached = len(scheduler_output.scheduled_cached_reqs.req_ids)
|
||||
|
||||
return torch.profiler.record_function(
|
||||
return self.profiler.annotate_context_manager(
|
||||
f"execute_new_{num_new}_cached_{num_cached}"
|
||||
)
|
||||
|
||||
@ -587,24 +566,11 @@ class Worker(WorkerBase):
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
raise RuntimeError("Profiling is not enabled.")
|
||||
if is_start:
|
||||
self.profiler.start()
|
||||
else:
|
||||
self.profiler.stop()
|
||||
if isinstance(self.profiler, torch.profiler.profile):
|
||||
rank = self.local_rank
|
||||
profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
|
||||
sort_key = "self_cuda_time_total"
|
||||
table = self.profiler.key_averages().table(sort_by=sort_key)
|
||||
|
||||
with open(profiler_out_file, "w") as f:
|
||||
print(table, file=f)
|
||||
|
||||
# only print profiler results on rank 0
|
||||
if rank == 0:
|
||||
print(table)
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.model_runner._dummy_run(1, uniform_decode=True)
|
||||
@ -865,6 +831,8 @@ class Worker(WorkerBase):
|
||||
def shutdown(self) -> None:
|
||||
if runner := getattr(self, "model_runner", None):
|
||||
runner.ensure_kv_transfer_shutdown()
|
||||
if self.profiler is not None:
|
||||
self.profiler.shutdown()
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user