From 4ebc513fc15cf9f0b251751e2f87bb611b9db451 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 23 Sep 2025 18:50:09 -0400 Subject: [PATCH] Add `VLLM_NVTX_SCOPES_FOR_PROFILING=1` to enable `nvtx.annotate` scopes (#25501) Signed-off-by: Corey Lowman Signed-off-by: yewentao256 --- vllm/envs.py | 5 +++++ vllm/v1/utils.py | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 1c6c1e78ac9b1..33dae0be05f8d 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -187,6 +187,7 @@ if TYPE_CHECKING: VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024 @@ -1387,6 +1388,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), + # Add optional NVTX scopes for profiling; off by default to avoid overhead + "VLLM_NVTX_SCOPES_FOR_PROFILING": + lambda: bool(int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0"))), + # Represent block hashes in KV cache events as 64-bit integers instead of # raw bytes. Defaults to True for backward compatibility. 
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index fd84b4a111f58..ec4417290f611 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -375,8 +375,22 @@ def report_usage_stats( }) +_PROFILER_FUNC = None + + def record_function_or_nullcontext(name: str) -> AbstractContextManager: + global _PROFILER_FUNC + + # Fast path: reuse the factory resolved on a previous call + if _PROFILER_FUNC is not None: + return _PROFILER_FUNC(name) + + func = contextlib.nullcontext if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: - return record_function(name) - else: - return contextlib.nullcontext() + func = record_function + elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: + import nvtx + func = nvtx.annotate + + _PROFILER_FUNC = func + return func(name)