[Feature] Add Layer-wise NVTX Support (#29990)

Signed-off-by: Max Hu <hyoung2991@gmail.com>
Signed-off-by: Max Hu <maxhu@nvidia.com>
Co-authored-by: Max Hu <maxhu@nvidia.com>
5 changed files with 375 additions and 3 deletions
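
For context, a minimal usage sketch of the feature this PR adds. The `nsys` invocation and the model name are assumptions, not part of the diff; the flag and the CUDA-graph caveat come from the changes below.

# Hedged usage sketch -- run under `nsys profile -t cuda,nvtx python script.py`
# (the nsys command and the model are assumptions, not part of this PR).
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,  # the feature does not work with CUDA graphs enabled
    enable_layerwise_nvtx_tracing=True,  # new ObservabilityConfig flag from this PR
)
llm.generate("Hello")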

View File

@@ -14,6 +14,7 @@ import torch._C._dynamo.guards
import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context

logger = init_logger(__name__)
@@ -92,12 +93,29 @@ class TorchCompileWithNoGuardsWrapper:
            return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
        if self.layerwise_nvtx_tracing_enabled:
            args_list = list(args)
            kwargs_dict = dict(kwargs)
            with layerwise_nvtx_marker_context(
                "Torch Compiled Module (input):{}".format(self.__class__.__name__),
                self,
                in_tensor=args_list,
                kwargs=kwargs_dict,
            ) as ctx:
                ctx.result = callable_fn(*args, **kwargs)
            return ctx.result
        return callable_fn(*args, **kwargs)

    def __init__(self):
        self.compiled = False
        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        mode = vllm_config.compilation_config.mode
        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )

        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
@@ -168,13 +186,19 @@ class TorchCompileWithNoGuardsWrapper:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
-               return self._compiled_callable(*args, **kwargs)
+               return self._call_with_optional_nvtx_range(
+                   self._compiled_callable, *args, **kwargs
+               )
            else:
                with self._dispatch_to_compiled_code():
-                   return self.forward(*args, **kwargs)
+                   return self._call_with_optional_nvtx_range(
+                       self.forward, *args, **kwargs
+                   )
        else:
            with _compilation_context():
-               return self._compiled_callable(*args, **kwargs)
+               return self._call_with_optional_nvtx_range(
+                   self._compiled_callable, *args, **kwargs
+               )

    @abstractmethod
    def forward(self, *args, **kwargs): ...
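
The wrapper above funnels every call through layerwise_nvtx_marker_context (added in vllm/utils/nvtx_pytorch_hooks.py later in this diff). The ctx.result assignment is what lets the context manager read the output on exit. A minimal self-contained sketch of that result-holder pattern, with illustrative names that are not vLLM API:

# Minimal sketch: the wrapped call's return value is stashed on the yielded
# object so the context manager can still read it (e.g. to record output
# shapes) in its finally block. Names are illustrative.
from contextlib import contextmanager

class _Holder:
    result = None

@contextmanager
def _traced(name):
    holder = _Holder()
    print(f"push input marker: {name}")  # stands in for nvtx.range_push(...)
    try:
        yield holder
    finally:
        print(f"pop; output marker sees result={holder.result!r}")

with _traced("layer0") as ctx:
    ctx.result = 40 + 2
print(ctx.result)  # 42 -- the return value survives the context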

View File

@@ -59,6 +59,11 @@ class ObservabilityConfig:
    """Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
    dispatch modes, and their observed frequencies at every logging interval)."""

    enable_layerwise_nvtx_tracing: bool = False
    """Enable layerwise NVTX tracing. This traces the execution of each layer or
    module in the model and attaches information such as input/output shapes to
    NVTX range markers. Note that this does not work with CUDA graphs enabled."""

    @cached_property
    def collect_model_forward_time(self) -> bool:
        """Whether to collect model forward time for the request."""

View File

@@ -519,6 +519,9 @@ class EngineArgs:
        ObservabilityConfig, "kv_cache_metrics_sample"
    )
    cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
    enable_layerwise_nvtx_tracing: bool = (
        ObservabilityConfig.enable_layerwise_nvtx_tracing
    )

    scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
@@ -1026,6 +1029,10 @@ class EngineArgs:
            "--cudagraph-metrics",
            **observability_kwargs["cudagraph_metrics"],
        )
        observability_group.add_argument(
            "--enable-layerwise-nvtx-tracing",
            **observability_kwargs["enable_layerwise_nvtx_tracing"],
        )

        # Scheduler arguments
        scheduler_kwargs = get_kwargs(SchedulerConfig)
@@ -1704,6 +1711,7 @@ class EngineArgs:
            kv_cache_metrics=self.kv_cache_metrics,
            kv_cache_metrics_sample=self.kv_cache_metrics_sample,
            cudagraph_metrics=self.cudagraph_metrics,
            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
        )

        # Compilation config overrides
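
A hedged sketch of how the new flag flows from EngineArgs into ObservabilityConfig, equivalent to passing --enable-layerwise-nvtx-tracing on the CLI. The model name is an arbitrary example:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", enable_layerwise_nvtx_tracing=True)
vllm_config = args.create_engine_config()
assert vllm_config.observability_config.enable_layerwise_nvtx_tracing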

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager

import torch
import torch.cuda.nvtx as nvtx


def print_tensor(tensor_obj, prefix, tensor_list=None):
    """Descend iterables that contain Tensors and collect their shapes.

    Recursive function that descends list/tuple arguments until it finds a
    Tensor object, then records that tensor's dimensions in tensor_list.
    """
    if tensor_list is None:
        tensor_list = []
    if isinstance(tensor_obj, (list, tuple)):
        for ten in tensor_obj:
            tensor_list = print_tensor(ten, prefix, tensor_list)
    elif isinstance(tensor_obj, torch.Tensor):
        tensor_dims = list(tensor_obj.size())
        tensor_list.append(tensor_dims)
    return tensor_list
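
A quick sketch of what print_tensor collects; despite its name it returns shapes rather than printing, and the prefix argument goes unused:

import torch

x = torch.zeros(2, 8, 128)
shapes = print_tensor([x, (x, torch.zeros(4))], prefix="Input")
assert shapes == [[2, 8, 128], [2, 8, 128], [4]]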
def process_layer_params(module_obj):
    """Extract the static parameters from LLM and VLM relevant layer types."""
    param_info = {}
    # Extract parameters for layers commonly used in LLMs and VLMs
    if isinstance(module_obj, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
        conv_params = {}
        conv_params["in_chan"] = module_obj.in_channels
        conv_params["out_chan"] = module_obj.out_channels
        conv_params["filter_dim"] = module_obj.kernel_size
        conv_params["stride"] = module_obj.stride
        conv_params["padding"] = module_obj.padding
        conv_params["dilation"] = module_obj.dilation
        conv_params["transposed"] = module_obj.transposed
        conv_params["output_padding"] = module_obj.output_padding
        conv_params["groups"] = module_obj.groups
        conv_params["padding_mode"] = module_obj.padding_mode
        param_info = conv_params
    elif isinstance(
        module_obj,
        (
            torch.nn.ConvTranspose1d,
            torch.nn.ConvTranspose2d,
            torch.nn.ConvTranspose3d,
        ),
    ):
        convtranspose_params = {}
        convtranspose_params["in_chan"] = module_obj.in_channels
        convtranspose_params["out_chan"] = module_obj.out_channels
        convtranspose_params["filter_dim"] = module_obj.kernel_size
        convtranspose_params["stride"] = module_obj.stride
        convtranspose_params["padding"] = module_obj.padding
        convtranspose_params["dilation"] = module_obj.dilation
        convtranspose_params["transposed"] = module_obj.transposed
        convtranspose_params["output_padding"] = module_obj.output_padding
        convtranspose_params["groups"] = module_obj.groups
        convtranspose_params["padding_mode"] = module_obj.padding_mode
        param_info = convtranspose_params
    elif isinstance(
        module_obj, (torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)
    ):

        def _handle_int_or_tuple(parameter):
            if isinstance(parameter, tuple):
                return list(parameter)
            elif isinstance(parameter, int):
                return [parameter, parameter]

        pooling_params = {}
        pooling_params["filter_dim"] = _handle_int_or_tuple(module_obj.kernel_size)
        pooling_params["stride"] = _handle_int_or_tuple(module_obj.stride)
        pooling_params["padding"] = _handle_int_or_tuple(module_obj.padding)
        pooling_params["dilation"] = _handle_int_or_tuple(module_obj.dilation)
        param_info = pooling_params
    elif isinstance(
        module_obj, (torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d)
    ):
        pooling_params = {}
        pooling_params["filter_dim"] = [
            module_obj.kernel_size,
            module_obj.kernel_size,
        ]
        pooling_params["stride"] = [module_obj.stride, module_obj.stride]
        pooling_params["padding"] = [module_obj.padding, module_obj.padding]
        pooling_params["ceil_mode"] = module_obj.ceil_mode
        pooling_params["count_include_pad"] = module_obj.count_include_pad
        param_info = pooling_params
    elif isinstance(
        module_obj,
        (
            torch.nn.AdaptiveAvgPool1d,
            torch.nn.AdaptiveAvgPool2d,
            torch.nn.AdaptiveAvgPool3d,
        ),
    ):
        pooling_params = {}
        pooling_params["output_size"] = [
            module_obj.output_size,
            module_obj.output_size,
        ]
        param_info = pooling_params
    elif isinstance(module_obj, torch.nn.Linear):
        param_info["in_features"] = module_obj.in_features
        param_info["out_features"] = module_obj.out_features
    elif isinstance(
        module_obj,
        (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d),
    ):
        param_info["num_features"] = module_obj.num_features
        param_info["epsilon"] = module_obj.eps
        param_info["momentum"] = module_obj.momentum
    elif isinstance(module_obj, torch.nn.ReLU):
        param_info["in_place"] = module_obj.inplace
    elif isinstance(module_obj, torch.nn.Dropout):
        param_info["p"] = module_obj.p
        param_info["in_place"] = module_obj.inplace
    elif isinstance(module_obj, torch.nn.Embedding):
        param_info["num_embeddings"] = module_obj.num_embeddings
        param_info["embedding_dim"] = module_obj.embedding_dim
    elif isinstance(
        module_obj,
        (
            torch.nn.Upsample,
            torch.nn.UpsamplingNearest2d,
            torch.nn.UpsamplingBilinear2d,
        ),
    ):
        param_info["scale_factor"] = module_obj.scale_factor
    return param_info
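
For illustration, what the function returns for two common layer types; the values follow directly from the branches above:

import torch

linear = torch.nn.Linear(in_features=4096, out_features=11008)
assert process_layer_params(linear) == {"in_features": 4096, "out_features": 11008}

conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
info = process_layer_params(conv)
assert info["filter_dim"] == (3, 3) and info["stride"] == (2, 2)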
def construct_marker_dict_and_push(
    module_name, module_obj, in_tensor, kwargs=None, out_tensor=None
):
    marker_dict = {}
    marker_dict["Module"] = module_name
    ## Get trainable parameters like weights and bias
    module_params = module_obj.named_parameters(recurse=False)
    for idx, (param_name, param_obj) in enumerate(module_params):
        if idx == 0:
            marker_dict["TrainableParams"] = {}
        marker_dict["TrainableParams"][param_name] = list(param_obj.size())
    in_tensor_list = print_tensor(in_tensor, "Input")
    if in_tensor_list:
        marker_dict["Inputs"] = in_tensor_list
    out_tensor_list = print_tensor(out_tensor, "Output")
    if out_tensor_list:
        marker_dict["Outputs"] = out_tensor_list
    ## Get kwargs like input_ids and positions for the top module
    if kwargs:
        for key, value in kwargs.items():
            if isinstance(value, (torch.Tensor, list, tuple)):
                tensor_list = print_tensor(value, key)
                if tensor_list:
                    marker_dict[key] = tensor_list
    param_info = process_layer_params(module_obj)
    if param_info:
        marker_dict["StaticParams"] = param_info
    nvtx.range_push("{}".format(marker_dict))
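
A sketch of the resulting marker: the whole dict is stringified into a single NVTX range message, and the push is unbalanced by design, so the caller owns the matching range_pop. Assumes a CUDA build of torch, since nvtx lives under torch.cuda:

import torch
import torch.cuda.nvtx as nvtx

layer = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)
construct_marker_dict_and_push("Module:top.proj", layer, in_tensor=(x,))
# Range message is roughly:
# {'Module': 'Module:top.proj', 'TrainableParams': {'weight': [4, 8],
#  'bias': [4]}, 'Inputs': [[2, 8]], 'StaticParams': {'in_features': 8,
#  'out_features': 4}}
layer(x)          # work attributed to the range
nvtx.range_pop()  # balance the push ourselves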
class ResultHolder:
    """Holder for storing results from within a context manager."""

    result = None

@contextmanager
def layerwise_nvtx_marker_context(module_name, module_obj, in_tensor=None, kwargs=None):
    """Context manager for NVTX markers that automatically pushes on enter
    and pops on exit.

    Example:
        with layerwise_nvtx_marker_context("Module:MyModule", module,
                                           in_tensor=args, kwargs=kwargs) as ctx:
            ctx.result = module(*args, **kwargs)
        return ctx.result
    """
    holder = ResultHolder()
    # Push input marker
    construct_marker_dict_and_push(
        module_name,
        module_obj,
        in_tensor=in_tensor,
        kwargs=kwargs,
    )
    try:
        yield holder
    finally:
        # Pop input marker
        nvtx.range_pop()
        # Push and pop output marker
        output_name = module_name.replace("(input)", "(output)")
        construct_marker_dict_and_push(
            output_name,
            module_obj,
            in_tensor=None,
            kwargs=None,
            out_tensor=holder.result,
        )
        nvtx.range_pop()
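
The resulting NVTX timeline for one wrapped call is an input range spanning the execution, followed by a near-zero-width output marker. A self-contained sketch (CUDA build of torch assumed):

import torch

module = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)
with layerwise_nvtx_marker_context("Module (input):top", module, in_tensor=(x,)) as ctx:
    ctx.result = module(x)
# Timeline (roughly):
#   [ "Module (input):top {... Inputs: [[2, 8]] ...}" ]   spans the forward call
#   [ "Module (output):top {... Outputs: [[2, 4]] ...}" ] pushed/popped on exit
out = ctx.result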
class PytHooks:
    """This class contains all the code needed to enable forward hooks
    in a PyTorch network.

    To register the hooks for a given network, the user needs to instantiate
    a PytHooks object, then call the register_hooks method.

    Example:
        my_hook = PytHooks()
        my_hook.register_hooks(my_network_model)
    """

    def __init__(self):
        """Initialize module variables."""
        super().__init__()
        self.module_to_name_map = {}

    def _process_layer_params(self, module_obj):
        return process_layer_params(module_obj)

    def module_fwd_hook(self, module_obj, in_tensor, out_tensor):
        """Callback function that ends the NVTX marker.

        Records the module name and tensor information.
        Called after the module executes the forward method.
        """
        nvtx.range_pop()
        module_name = self.module_to_name_map.get(module_obj, "unknown")
        construct_marker_dict_and_push(
            module_name, module_obj, in_tensor=None, kwargs=None, out_tensor=out_tensor
        )
        nvtx.range_pop()
        return

    def module_fwd_pre_hook(self, module_obj, in_tensor, kwargs):
        """Creates an NVTX marker with the module name in it.

        This function is called before the module executes.
        """
        module_name = self.module_to_name_map.get(module_obj, "unknown")
        construct_marker_dict_and_push(
            module_name, module_obj, in_tensor=in_tensor, kwargs=kwargs, out_tensor=None
        )
        return

    def register_hooks(self, network_model, module_prefix="top"):
        """User level function that activates all the hooks.

        The user needs to call this method from the network source code.
        The code descends all the modules in the network and registers their
        respective hooks.
        """
        # Module types to skip (simple operations that don't need detailed profiling)
        skip_types = (
            torch.nn.Identity,
            torch.nn.Dropout,
            torch.nn.Dropout1d,
            torch.nn.Dropout2d,
            torch.nn.Dropout3d,
        )
        for name, module in network_model.named_modules(prefix=module_prefix):
            # Skip certain module types to reduce profiling overhead
            if isinstance(module, skip_types):
                continue
            module.register_forward_pre_hook(self.module_fwd_pre_hook, with_kwargs=True)
            module.register_forward_hook(self.module_fwd_hook)
            if module not in self.module_to_name_map:
                self.module_to_name_map[module] = name
            else:
                raise ValueError("Module instance {} is not unique".format(module))
        return
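
An end-to-end sketch on a toy eager-mode model. The nsys command is an assumption, not part of this PR, and a CUDA build of torch is assumed:

# Capture with, e.g.: nsys profile -t cuda,nvtx python toy_script.py
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 16),
)
hooks = PytHooks()
hooks.register_hooks(model, module_prefix="toy")
model(torch.randn(4, 16))  # every non-skipped module emits paired input/output markers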

View File

@@ -88,6 +88,7 @@ from vllm.utils.jsontree import json_map_leaves
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import (
    get_dtype_size,
@@ -599,6 +600,7 @@ class GPUModelRunner(
        # Ephemeral state transferred between execute_model() and sample_tokens().
        self.execute_model_state: ExecuteModelState | None = None
        self.kv_connector_output: KVConnectorOutput | None = None
        self.layerwise_nvtx_hooks_registered = False

    def reset_mm_cache(self) -> None:
        if self.mm_budget:
@@ -2828,6 +2830,42 @@ class GPUModelRunner(
            cudagraph_stats,
        )

    def _register_layerwise_nvtx_hooks(self) -> None:
        """
        Register layerwise NVTX hooks if --enable-layerwise-nvtx-tracing is enabled
        to trace detailed information of each layer or module in the model.
        """
        if (
            self.vllm_config.observability_config.enable_layerwise_nvtx_tracing
            and not self.layerwise_nvtx_hooks_registered
        ):
            if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
                logger.debug_once(
                    "layerwise NVTX tracing is not supported when CUDA graphs "
                    "are enabled; you may observe part or all of the model "
                    "missing NVTX markers"
                )
            # In STOCK_TORCH_COMPILE mode, after registering hooks here,
            # the __call__ function of nn.Module will be recompiled with
            # fullgraph=True. Since nvtx.range_push/pop are not traceable
            # by torch dynamo, we can't register hook functions here
            # because hook functions will also be traced by torch dynamo.
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                logger.debug_once(
                    "layerwise NVTX tracing is not supported when "
                    "CompilationMode is STOCK_TORCH_COMPILE, skipping "
                    "function hooks registration"
                )
            else:
                pyt_hooks = PytHooks()
                pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
                self.layerwise_nvtx_hooks_registered = True
    @torch.inference_mode()
    def execute_model(
        self,

@@ -4122,6 +4160,17 @@ class GPUModelRunner(
            is_graph_capturing=is_graph_capturing,
        )
        # We register layerwise NVTX hooks here, after the first dynamo tracing
        # is done, to avoid nvtx operations in the hook functions being traced
        # by torch dynamo and causing graph breaks.
        # Note that for DYNAMO_ONCE and VLLM_COMPILE modes, the compiled
        # model's dynamo tracing is only done once and the compiled model's
        # __call__ function is replaced by a call to the compiled function,
        # so it is safe to register hooks here. Hooks are registered on both
        # compiled and uncompiled models, but they are never called on the
        # compiled model's execution path.
        self._register_layerwise_nvtx_hooks()

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
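
To make the constraint above concrete, a hedged illustration of why an NVTX call inside a hook breaks fullgraph dynamo tracing; the exact error text varies by torch version:

import torch

def pre_hook(module, args):
    torch.cuda.nvtx.range_push("layer")  # not traceable by torch dynamo

lin = torch.nn.Linear(4, 4)
lin.register_forward_pre_hook(pre_hook)
compiled = torch.compile(lin, fullgraph=True)
# compiled(torch.randn(1, 4))  # expected to raise: dynamo can't trace nvtx ops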