diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index b120c85bf232e..69e1ed37a5beb 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -14,6 +14,7 @@ import torch._C._dynamo.guards
 import vllm.envs as envs
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
 
 logger = init_logger(__name__)
 
@@ -92,12 +93,29 @@ class TorchCompileWithNoGuardsWrapper:
 
         return self.forward(*args, **kwargs)
 
+    def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
+        if self.layerwise_nvtx_tracing_enabled:
+            args_list = list(args)
+            kwargs_dict = dict(kwargs)
+            with layerwise_nvtx_marker_context(
+                "Torch Compiled Module (input):{}".format(self.__class__.__name__),
+                self,
+                in_tensor=args_list,
+                kwargs=kwargs_dict,
+            ) as ctx:
+                ctx.result = callable_fn(*args, **kwargs)
+                return ctx.result
+        return callable_fn(*args, **kwargs)
+
     def __init__(self):
         self.compiled = False
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config
         mode = vllm_config.compilation_config.mode
+        self.layerwise_nvtx_tracing_enabled = (
+            vllm_config.observability_config.enable_layerwise_nvtx_tracing
+        )
 
         if mode is None:
             raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
@@ -168,13 +186,19 @@ class TorchCompileWithNoGuardsWrapper:
                 # Make sure a compilation is triggered by clearing dynamo
                 # cache.
                 torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
-                return self._compiled_callable(*args, **kwargs)
+                return self._call_with_optional_nvtx_range(
+                    self._compiled_callable, *args, **kwargs
+                )
             else:
                 with self._dispatch_to_compiled_code():
-                    return self.forward(*args, **kwargs)
+                    return self._call_with_optional_nvtx_range(
+                        self.forward, *args, **kwargs
+                    )
         else:
             with _compilation_context():
-                return self._compiled_callable(*args, **kwargs)
+                return self._call_with_optional_nvtx_range(
+                    self._compiled_callable, *args, **kwargs
+                )
 
     @abstractmethod
     def forward(self, *args, **kwargs): ...
diff --git a/vllm/config/observability.py b/vllm/config/observability.py
index fdc27aee380ef..e40bf18a00ce2 100644
--- a/vllm/config/observability.py
+++ b/vllm/config/observability.py
@@ -59,6 +59,11 @@ class ObservabilityConfig:
     """Enable CUDA graph metrics (number of padded/unpadded tokens, runtime
     cudagraph dispatch modes, and their observed frequencies at every logging
     interval)."""
 
+    enable_layerwise_nvtx_tracing: bool = False
+    """Enable layerwise NVTX tracing. This traces the execution of each layer or
+    module in the model and attaches information such as input/output shapes to
+    NVTX range markers. Note that this does not work with CUDA graphs enabled."""
+
     @cached_property
     def collect_model_forward_time(self) -> bool:
         """Whether to collect model forward time for the request."""
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index fd07cded7bc51..883ae370f9e74 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -519,6 +519,9 @@ class EngineArgs:
         ObservabilityConfig, "kv_cache_metrics_sample"
     )
     cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
+    enable_layerwise_nvtx_tracing: bool = (
+        ObservabilityConfig.enable_layerwise_nvtx_tracing
+    )
 
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
@@ -1026,6 +1029,10 @@ class EngineArgs:
             "--cudagraph-metrics",
             **observability_kwargs["cudagraph_metrics"],
         )
+        observability_group.add_argument(
+            "--enable-layerwise-nvtx-tracing",
+            **observability_kwargs["enable_layerwise_nvtx_tracing"],
+        )
 
         # Scheduler arguments
         scheduler_kwargs = get_kwargs(SchedulerConfig)
@@ -1704,6 +1711,7 @@ class EngineArgs:
             kv_cache_metrics=self.kv_cache_metrics,
             kv_cache_metrics_sample=self.kv_cache_metrics_sample,
             cudagraph_metrics=self.cudagraph_metrics,
+            enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
         )
 
         # Compilation config overrides
diff --git a/vllm/utils/nvtx_pytorch_hooks.py b/vllm/utils/nvtx_pytorch_hooks.py
new file mode 100644
index 0000000000000..39e2a9a136e63
--- /dev/null
+++ b/vllm/utils/nvtx_pytorch_hooks.py
@@ -0,0 +1,286 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from contextlib import contextmanager
+
+import torch
+import torch.cuda.nvtx as nvtx
+
+
+def print_tensor(tensor_obj, prefix, tensor_list=None):
+    """Collects the shapes of any Tensors nested inside the given object.
+    Recursive function that descends list/tuple arguments until it finds a
+    Tensor object, then records that tensor's dimensions in tensor_list.
+    """
+    if tensor_list is None:
+        tensor_list = []
+
+    if isinstance(tensor_obj, (list, tuple)):
+        for ten in tensor_obj:
+            tensor_list = print_tensor(ten, prefix, tensor_list)
+    elif isinstance(tensor_obj, torch.Tensor):
+        tensor_dims = list(tensor_obj.size())
+        tensor_list.append(tensor_dims)
+    return tensor_list
+
+
+def process_layer_params(module_obj):
+    """Extract the static parameters of layer types relevant to LLMs and VLMs."""
+    param_info = {}
+    # Extract parameters for layers commonly used in LLMs and VLMs
+    if isinstance(module_obj, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
+        conv_params = {}
+        conv_params["in_chan"] = module_obj.in_channels
+        conv_params["out_chan"] = module_obj.out_channels
+        conv_params["filter_dim"] = module_obj.kernel_size
+        conv_params["stride"] = module_obj.stride
+        conv_params["padding"] = module_obj.padding
+        conv_params["dilation"] = module_obj.dilation
+        conv_params["transposed"] = module_obj.transposed
+        conv_params["output_padding"] = module_obj.output_padding
+        conv_params["groups"] = module_obj.groups
+        conv_params["padding_mode"] = module_obj.padding_mode
+        param_info = conv_params
+    elif isinstance(
+        module_obj,
+        (
+            torch.nn.ConvTranspose1d,
+            torch.nn.ConvTranspose2d,
+            torch.nn.ConvTranspose3d,
+        ),
+    ):
+        convtranspose_params = {}
+        convtranspose_params["in_chan"] = module_obj.in_channels
+        convtranspose_params["out_chan"] = module_obj.out_channels
+        convtranspose_params["filter_dim"] = module_obj.kernel_size
+        convtranspose_params["stride"] = module_obj.stride
+        convtranspose_params["padding"] = module_obj.padding
+        convtranspose_params["dilation"] = module_obj.dilation
+        convtranspose_params["transposed"] = module_obj.transposed
+        convtranspose_params["output_padding"] = module_obj.output_padding
+        convtranspose_params["groups"] = module_obj.groups
+        convtranspose_params["padding_mode"] = module_obj.padding_mode
+        param_info = convtranspose_params
+    elif isinstance(
+        module_obj, (torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)
+    ):
+
+        def _handle_int_or_tuple(parameter):
+            if isinstance(parameter, tuple):
+                return list(parameter)
+            elif isinstance(parameter, int):
+                return [parameter, parameter]
+
+        pooling_params = {}
+        pooling_params["filter_dim"] = _handle_int_or_tuple(module_obj.kernel_size)
+        pooling_params["stride"] = _handle_int_or_tuple(module_obj.stride)
+        pooling_params["padding"] = _handle_int_or_tuple(module_obj.padding)
+        pooling_params["dilation"] = _handle_int_or_tuple(module_obj.dilation)
+        param_info = pooling_params
+    elif isinstance(
+        module_obj, (torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d)
+    ):
+        pooling_params = {}
+        pooling_params["filter_dim"] = [
+            module_obj.kernel_size,
+            module_obj.kernel_size,
+        ]
+        pooling_params["stride"] = [module_obj.stride, module_obj.stride]
+        pooling_params["padding"] = [module_obj.padding, module_obj.padding]
+        pooling_params["ceil_mode"] = module_obj.ceil_mode
+        pooling_params["count_include_pad"] = module_obj.count_include_pad
+        param_info = pooling_params
+    elif isinstance(
+        module_obj,
+        (
+            torch.nn.AdaptiveAvgPool1d,
+            torch.nn.AdaptiveAvgPool2d,
+            torch.nn.AdaptiveAvgPool3d,
+        ),
+    ):
+        pooling_params = {}
+        pooling_params["output_size"] = [
+            module_obj.output_size,
+            module_obj.output_size,
+        ]
+        param_info = pooling_params
+    elif isinstance(module_obj, torch.nn.Linear):
+        param_info["in_features"] = module_obj.in_features
+        param_info["out_features"] = module_obj.out_features
+    elif isinstance(
+        module_obj,
+        (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d),
+    ):
+        param_info["num_features"] = module_obj.num_features
+        param_info["epsilon"] = module_obj.eps
+        param_info["momentum"] = module_obj.momentum
+    elif isinstance(module_obj, torch.nn.ReLU):
+        param_info["in_place"] = module_obj.inplace
+    elif isinstance(module_obj, torch.nn.Dropout):
+        param_info["p"] = module_obj.p
+        param_info["in_place"] = module_obj.inplace
+    elif isinstance(module_obj, torch.nn.Embedding):
+        param_info["num_embeddings"] = module_obj.num_embeddings
+        param_info["embedding_dim"] = module_obj.embedding_dim
+    elif isinstance(
+        module_obj,
+        (
+            torch.nn.Upsample,
+            torch.nn.UpsamplingNearest2d,
+            torch.nn.UpsamplingBilinear2d,
+        ),
+    ):
+        param_info["scale_factor"] = module_obj.scale_factor
+
+    return param_info
+
+
+def construct_marker_dict_and_push(
+    module_name, module_obj, in_tensor, kwargs=None, out_tensor=None
+):
+    marker_dict = {}
+    marker_dict["Module"] = module_name
+
+    ## Get trainable parameters like weights and bias
+    module_params = module_obj.named_parameters(recurse=False)
+    for idx, (param_name, param_obj) in enumerate(module_params):
+        if idx == 0:
+            marker_dict["TrainableParams"] = {}
+        marker_dict["TrainableParams"][param_name] = list(param_obj.size())
+
+    in_tensor_list = print_tensor(in_tensor, "Input")
+    if in_tensor_list:
+        marker_dict["Inputs"] = in_tensor_list
+
+    out_tensor_list = print_tensor(out_tensor, "Output")
+    if out_tensor_list:
+        marker_dict["Outputs"] = out_tensor_list
+
+    ## Get kwargs like input_ids and positions for the top module
+    if kwargs:
+        for key, value in kwargs.items():
+            if isinstance(value, (torch.Tensor, list, tuple)):
+                tensor_list = print_tensor(value, key)
+                if tensor_list:
+                    marker_dict[key] = tensor_list
+
+    param_info = process_layer_params(module_obj)
+    if param_info:
+        marker_dict["StaticParams"] = param_info
+    nvtx.range_push("{}".format(marker_dict))
+
+
+class ResultHolder:
+    """Holder for storing results from within a context manager."""
+
+    result = None
+
+
+@contextmanager
+def layerwise_nvtx_marker_context(module_name, module_obj, in_tensor=None, kwargs=None):
+    """Context manager for NVTX markers that automatically pushes on enter
+    and pops on exit.
+
+    Example:
+        with layerwise_nvtx_marker_context("Module:MyModule", module, in_tensor=args,
+                                           kwargs=kwargs) as ctx:
+            ctx.result = module(*args, **kwargs)
+        return ctx.result
+    """
+    holder = ResultHolder()
+
+    # Push input marker
+    construct_marker_dict_and_push(
+        module_name,
+        module_obj,
+        in_tensor=in_tensor,
+        kwargs=kwargs,
+    )
+    try:
+        yield holder
+    finally:
+        # Pop input marker
+        nvtx.range_pop()
+        # Push and pop output marker
+        output_name = module_name.replace("(input)", "(output)")
+        construct_marker_dict_and_push(
+            output_name,
+            module_obj,
+            in_tensor=None,
+            kwargs=None,
+            out_tensor=holder.result,
+        )
+        nvtx.range_pop()
+
+
+class PytHooks:
+    """This class contains all the code needed to enable forward hooks
+    in a PyTorch network.
+
+    To register the hooks for a given network, the user needs to instantiate
+    a PytHooks object and then call the register_hooks method.
+
+    Example:
+
+        my_hook = PytHooks()
+        my_hook.register_hooks(my_network_model)
+    """
+
+    def __init__(self):
+        """Initialize module variables."""
+        super().__init__()
+        self.module_to_name_map = {}
+
+    def _process_layer_params(self, module_obj):
+        return process_layer_params(module_obj)
+
+    def module_fwd_hook(self, module_obj, in_tensor, out_tensor):
+        """Callback function that ends the NVTX marker.
+        Records the module name and tensor information.
+        Called after the module executes the forward method.
+        """
+        nvtx.range_pop()
+        module_name = self.module_to_name_map.get(module_obj, "unknown")
+        construct_marker_dict_and_push(
+            module_name, module_obj, in_tensor=None, kwargs=None, out_tensor=out_tensor
+        )
+        nvtx.range_pop()
+        return
+
+    def module_fwd_pre_hook(self, module_obj, in_tensor, kwargs):
+        """Creates an NVTX marker with the module name in it.
+        This function is called before the module executes.
+        """
+        module_name = self.module_to_name_map.get(module_obj, "unknown")
+        construct_marker_dict_and_push(
+            module_name, module_obj, in_tensor=in_tensor, kwargs=kwargs, out_tensor=None
+        )
+        return
+
+    def register_hooks(self, network_model, module_prefix="top"):
+        """User-level function that activates all the hooks.
+        The user needs to call this method from the network source code.
+        The code descends all the modules in the network and registers their
+        respective hooks.
+        """
+        # Module types to skip (simple operations that don't need detailed profiling)
+        skip_types = (
+            torch.nn.Identity,
+            torch.nn.Dropout,
+            torch.nn.Dropout1d,
+            torch.nn.Dropout2d,
+            torch.nn.Dropout3d,
+        )
+
+        for name, module in network_model.named_modules(prefix=module_prefix):
+            # Skip certain module types to reduce profiling overhead
+            if isinstance(module, skip_types):
+                continue
+
+            module.register_forward_pre_hook(self.module_fwd_pre_hook, with_kwargs=True)
+            module.register_forward_hook(self.module_fwd_hook)
+            if module not in self.module_to_name_map:
+                self.module_to_name_map[module] = name
+            else:
+                raise ValueError("Module instance {} is not unique".format(module))
+        return
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 152bea2c0975c..b6a8145226b3f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -88,6 +88,7 @@ from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.mem_utils import DeviceMemoryProfiler
+from vllm.utils.nvtx_pytorch_hooks import PytHooks
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.torch_utils import (
     get_dtype_size,
@@ -599,6 +600,7 @@ class GPUModelRunner(
         # Ephemeral state transferred between execute_model() and sample_tokens().
         self.execute_model_state: ExecuteModelState | None = None
         self.kv_connector_output: KVConnectorOutput | None = None
+        self.layerwise_nvtx_hooks_registered = False
 
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
@@ -2828,6 +2830,42 @@
             cudagraph_stats,
         )
 
+    def _register_layerwise_nvtx_hooks(self) -> None:
+        """
+        Register layerwise NVTX hooks if --enable-layerwise-nvtx-tracing is enabled
+        to trace detailed information of each layer or module in the model.
+        """
+
+        if (
+            self.vllm_config.observability_config.enable_layerwise_nvtx_tracing
+            and not self.layerwise_nvtx_hooks_registered
+        ):
+            if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+                logger.debug_once(
+                    "layerwise NVTX tracing is not supported when CUDA graphs "
+                    "are enabled; you may observe part or all of the model "
+                    "missing NVTX markers"
+                )
+
+            # In STOCK_TORCH_COMPILE mode, after registering hooks here,
+            # the __call__ function of nn.Module will be recompiled with
+            # fullgraph=True.
+            # Since nvtx.range_push/pop are not traceable by torch dynamo, we
+            # can't register hook functions here: the hooks themselves would
+            # also be traced by torch dynamo.
+            if (
+                self.vllm_config.compilation_config.mode
+                == CompilationMode.STOCK_TORCH_COMPILE
+            ):
+                logger.debug_once(
+                    "layerwise NVTX tracing is not supported when "
+                    "CompilationMode is STOCK_TORCH_COMPILE, skipping "
+                    "hook registration"
+                )
+            else:
+                pyt_hooks = PytHooks()
+                pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
+                self.layerwise_nvtx_hooks_registered = True
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -4122,6 +4160,17 @@
             is_graph_capturing=is_graph_capturing,
         )
 
+        # We register layerwise NVTX hooks here, after the first dynamo tracing
+        # is done, to avoid the nvtx operations in the hook functions being
+        # traced by torch dynamo and causing graph breaks.
+        # Note that for the DYNAMO_ONCE and VLLM_COMPILE modes, the compiled
+        # model's dynamo tracing is done only once and the compiled model's
+        # __call__ function is replaced by a call to the compiled function.
+        # So it's safe to register hooks here. Hooks are registered on both the
+        # compiled and uncompiled models, but they will never be called on the
+        # compiled model's execution path.
+        self._register_layerwise_nvtx_hooks()
+
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
         # requests to process.
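
For reviewers who want to see what the new hooks emit without standing up a full server, here is a minimal, illustrative sketch that drives `vllm.utils.nvtx_pytorch_hooks` on a toy module. The `TinyMLP` model and tensor shapes are made up for this example, and it assumes a CUDA-enabled PyTorch build so that `torch.cuda.nvtx` is functional; the markers only become visible when the script runs under an NVTX-aware profiler such as Nsight Systems (e.g. `nsys profile -t nvtx,cuda python example.py`), otherwise the NVTX calls are effectively no-ops.

```python
import torch

from vllm.utils.nvtx_pytorch_hooks import PytHooks, layerwise_nvtx_marker_context


class TinyMLP(torch.nn.Module):
    """A made-up module, just to give the hooks something to annotate."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.act = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(32, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


model = TinyMLP()

# Register per-module forward hooks: each submodule call is bracketed by an
# NVTX range whose payload records input/output shapes, trainable-parameter
# shapes, and static layer parameters (e.g. in_features/out_features).
hooks = PytHooks()
hooks.register_hooks(model, module_prefix=model.__class__.__name__)

x = torch.randn(4, 16)

# Wrap the top-level call the same way the compile wrapper does, so the whole
# forward pass also shows up as an enclosing "(input)"/"(output)" range pair.
with layerwise_nvtx_marker_context(
    "Torch Compiled Module (input):TinyMLP", model, in_tensor=[x]
) as ctx:
    ctx.result = model(x)
```

Within vLLM itself, the equivalent is launching with `--enable-layerwise-nvtx-tracing` under the profiler; the model runner registers the hooks after the first Dynamo trace, as described in the comments above.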