From d8bff253d76c765c4a53515c262d8ea64a2b8b70 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Thu, 17 Jul 2025 18:20:31 -0700 Subject: [PATCH] add il tool more changes Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> fix tp Signed-off-by: Lu Fang add comparison tool tmp add unit test and fix format Signed-off-by: Lu Fang add comparison script and documentation Signed-off-by: Lu Fang provide default intermediate logging Signed-off-by: Lu Fang optional register il Signed-off-by: Lu Fang add input reload and improve intermediate compare --- docs/contributing/intermediate_logging.md | 136 ++++ tests/v1/test_intermediates_logging.py | 325 ++++++++ tools/compare_intermediate.py | 706 ++++++++++++++++++ vllm/config.py | 125 +++- vllm/engine/arg_utils.py | 44 +- vllm/v1/engine/core.py | 4 + vllm/v1/intermediates/__init__.py | 0 .../v1/intermediates/intermediates_logging.py | 599 +++++++++++++++ vllm/v1/worker/gpu_worker.py | 6 +- vllm/v1/worker/worker_base.py | 27 +- vllm/worker/worker_base.py | 16 + 11 files changed, 1982 insertions(+), 6 deletions(-) create mode 100644 docs/contributing/intermediate_logging.md create mode 100644 tests/v1/test_intermediates_logging.py create mode 100755 tools/compare_intermediate.py create mode 100644 vllm/v1/intermediates/__init__.py create mode 100644 vllm/v1/intermediates/intermediates_logging.py diff --git a/docs/contributing/intermediate_logging.md b/docs/contributing/intermediate_logging.md new file mode 100644 index 0000000000000..4b1dc2aca8797 --- /dev/null +++ b/docs/contributing/intermediate_logging.md @@ -0,0 +1,136 @@ +# Intermediate Tensor Logging + +This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution. 
+ +## Overview + +The intermediate tensor logging feature enables you to: + +- Log input and output tensors from a configured set of filters +- Filter modules by name using regex patterns +- Filter module fwd call index (e.g. dump 2nd call of forward pass on same module) +- Filter tensors by device +- Filter whole model fwd step id + +This is mainly useful for debugging model accuracy gaps between 2 runs + +## Usage + +### Enabling via parameters or config file + +**Offline Inference example** + +Dump all modules, all devices for step 0 (default behavior) + +```bash +python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}' +``` + +Dump the first layer's modules, all devices for step 0 + +```bash +python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}' +``` + +Dump customized layers, devices, steps through a config file + +The configuration file should be a JSON file with the following structure: + +```json +{ + "output_dir": "/tmp/vllm_intermediates", + "module_call_match": ["layers\\.0\\.(?!.*rotary_emb).*", "rotary_emb:0", "embed_tokens", "model\\.norm"], + "log_step_ids": [0, 1], + "device_names": ["cuda:0"] +} +``` + +```bash +python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json +``` + + +#### Configuration Parameters + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `output_dir` | string | Directory where to save the intermediate tensors | `/tmp/vllm_intermediates` | +| `module_call_match` | array | Regex patterns to filter module names; to limit to the i-th call only, add `:i` | `null` (log all modules) | +| `log_step_ids` | array | List of step
IDs to log | `[0]` | +| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) | +| `device_names` | array | List of device names to log | `[]` (log all devices) | + +### Output Directory Structure + +When you enable intermediate logging, the system creates a timestamped directory under your specified `output_dir`. This helps organize multiple logging sessions: + +``` +/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/ +└── step_0 + ├── model.embed_tokens + │ ├── inputs_0_cuda_0.pt + │ ├── inputs.json + │ ├── outputs_cuda_0.pt + │ └── outputs.json + ├── model.layers.0.input_layernorm + │ ├── inputs_0_cuda_0.pt + │ ├── inputs.json + │ ├── outputs_cuda_0.pt + │ └── outputs.json + └── step_1/ + └── ... +``` + +Each tensor is saved in two formats: +1. `.json` files containing metadata and small tensor values +2. `.pt` files containing the full PyTorch tensors (can be loaded with `torch.load()`) + +## Comparing Intermediate Logging Results + +vLLM provides a tool called `compare_intermediate.py` to compare intermediate tensors between two different runs. This is particularly useful for debugging accuracy differences or verifying that code changes don't affect model outputs. 
+ +### Usage + +```bash +python tools/compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options] +``` + +### Options + +| Option | Description | Default | +|--------|-------------|---------| +| `--dir1` | First intermediate logging directory | (required) | +| `--dir2` | Second intermediate logging directory | (required) | +| `--output` | Output file for the report | stdout | +| `--rtol` | Relative tolerance for tensor comparison | 1e-5 | +| `--atol` | Absolute tolerance for tensor comparison | 1e-8 | +| `--steps` | Comma-separated list of steps to compare | all | +| `--modules` | Comma-separated list of module name patterns to compare | all | +| `--verbose` | Include detailed information about each tensor | false | + +### Example + +```bash +# Compare all tensors from two different runs +python tools/compare_intermediate.py --dir1 /tmp/vllm_intermediates/run1 --dir2 /tmp/vllm_intermediates/run2 + +# Compare only specific modules and steps with custom tolerance +python tools/compare_intermediate.py \ + --dir1 /tmp/vllm_intermediates/run1 \ + --dir2 /tmp/vllm_intermediates/run2 \ + --steps 0,1 \ + --modules ".*attention.*,.*mlp.*" \ + --rtol 1e-4 \ + --atol 1e-7 \ + --output comparison_report.md +``` + +### Output + +The tool generates a detailed markdown report that includes: + +- Overall summary of matching and mismatched tensors +- Per-module comparison results +- Detailed tensor differences (when using `--verbose`) + +This makes it easy to identify which specific tensors differ between runs and by how much. diff --git a/tests/v1/test_intermediates_logging.py b/tests/v1/test_intermediates_logging.py new file mode 100644 index 0000000000000..a25a70e9106a1 --- /dev/null +++ b/tests/v1/test_intermediates_logging.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for the intermediate tensor logging functionality. 
+""" + +import json +from os.path import isdir +import shutil +import os +import tempfile +from pathlib import Path +from unittest import mock + +import pytest +import torch +import torch.nn as nn + +from vllm.config import IntermediateLoggingConfig +from vllm.v1.intermediates.intermediates_logging import (get_current_il_config, + get_step, increment_step, + intermediate_logging, + register_intermediate_hooks, + reset_step, + should_log_device, + should_log_module, + should_log_step) + + +class SimpleModel(nn.Module): + """A simple model for testing.""" + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 20) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(20, 5) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + +@pytest.fixture +def temp_output_dir(): + """Create a temporary directory for test outputs.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Clean up after the test + shutil.rmtree(temp_dir) + + +@pytest.fixture +def simple_model(): + """Create a simple model for testing.""" + return SimpleModel() + + +@pytest.fixture +def il_config(temp_output_dir): + """Create a basic IntermediateLoggingConfig for testing.""" + return IntermediateLoggingConfig(output_dir=temp_output_dir, + enabled=True, + log_step_ids=[0, 1], + module_call_match=[".*linear.*"]) + + +def test_step_counter(): + """Test the step counter functionality.""" + # Reset the step counter + reset_step() + assert get_step() == 0 + + # Increment the step counter + increment_step() + assert get_step() == 1 + + # Increment again + increment_step() + assert get_step() == 2 + + # Reset again + reset_step() + assert get_step() == 0 + + +def test_intermediate_logging_context_manager(): + """Test the intermediate_logging context manager.""" + # Create a config + config = IntermediateLoggingConfig(enabled=True) + + # Initially, there should be no global config + assert get_current_il_config() is None + + # Use the context 
manager + with intermediate_logging(config): + # Inside the context, the global config should be set + assert get_current_il_config() is not None + assert get_current_il_config().enabled is True + + # After the context, the global config should be None again + assert get_current_il_config() is None + + # Test with a different config + config2 = IntermediateLoggingConfig(enabled=False) + with intermediate_logging(config2): + assert get_current_il_config() is not None + assert get_current_il_config().enabled is False + + +def test_should_log_step(): + """Test the should_log_step function.""" + # Reset step counter + reset_step() + + # Create configs with different step IDs + config_all_steps = IntermediateLoggingConfig( + enabled=True, + log_step_ids=[] # Empty list means log all steps + ) + config_specific_steps = IntermediateLoggingConfig( + enabled=True, + log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4 + ) + config_disabled = IntermediateLoggingConfig(enabled=False, + log_step_ids=[0, 1, 2]) + + # Test with all steps config + with intermediate_logging(config_all_steps): + assert should_log_step(config_all_steps) is True # Step 0 + increment_step() + assert should_log_step(config_all_steps) is True # Step 1 + + # Reset step counter + reset_step() + + # Test with specific steps config + with intermediate_logging(config_specific_steps): + assert should_log_step(config_specific_steps) is True # Step 0 + increment_step() + assert should_log_step(config_specific_steps) is False # Step 1 + increment_step() + assert should_log_step(config_specific_steps) is True # Step 2 + increment_step() + assert should_log_step(config_specific_steps) is False # Step 3 + increment_step() + assert should_log_step(config_specific_steps) is True # Step 4 + + # Test with disabled config + with intermediate_logging(config_disabled): + assert should_log_step(config_disabled) is False # Disabled + + +def test_should_log_device(): + """Test the should_log_device function.""" + # Create 
configs with different device filters + config_all_devices = IntermediateLoggingConfig( + enabled=True, + device_names=[] # Empty list means log all devices + ) + config_specific_devices = IntermediateLoggingConfig( + enabled=True, + device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu + ) + config_disabled = IntermediateLoggingConfig(enabled=False, + device_names=["cuda:0", "cpu"]) + + # Test with all devices config + with intermediate_logging(config_all_devices): + assert should_log_device(config_all_devices, "cuda:0") is True + assert should_log_device(config_all_devices, "cuda:1") is True + assert should_log_device(config_all_devices, "cpu") is True + + # Test with specific devices config + with intermediate_logging(config_specific_devices): + assert should_log_device(config_specific_devices, "cuda:0") is True + assert should_log_device(config_specific_devices, "cuda:1") is False + assert should_log_device(config_specific_devices, "cpu") is True + + # Test with disabled config + with intermediate_logging(config_disabled): + assert should_log_device(config_disabled, "cuda:0") is False + assert should_log_device(config_disabled, "cpu") is False + + +def test_should_log_module(simple_model): + """Test the should_log_module function.""" + # Create configs with different module name filters + config_all_modules = IntermediateLoggingConfig( + enabled=True, + module_call_match=None # None means log all modules + ) + config_specific_modules = IntermediateLoggingConfig( + enabled=True, + module_call_match=[".*linear.*" + ] # Only log modules with "linear" in the name + ) + config_disabled = IntermediateLoggingConfig(enabled=False, + module_call_match=[".*"]) + + # Test with all modules config + with intermediate_logging(config_all_modules): + assert should_log_module(config_all_modules, "linear1", + simple_model.linear1) is True + assert should_log_module(config_all_modules, "relu", + simple_model.relu) is True + + # Test with specific modules config + with 
intermediate_logging(config_specific_modules): + assert should_log_module(config_specific_modules, "linear1", + simple_model.linear1) is True + assert should_log_module(config_specific_modules, "relu", + simple_model.relu) is False + + # Test with disabled config + with intermediate_logging(config_disabled): + assert should_log_module(config_disabled, "linear1", + simple_model.linear1) is False + assert should_log_module(config_disabled, "relu", + simple_model.relu) is False + + +def test_register_hooks(simple_model, il_config): + """Test registering hooks on a model.""" + # Register hooks + logger_instance = register_intermediate_hooks(simple_model, il_config) + + # Check that hooks were registered + assert len(logger_instance.hooks) > 0 + + # Remove hooks + logger_instance.remove_hooks() + + # Check that hooks were removed + assert len(logger_instance.hooks) == 0 + + +@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json') +@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors') +def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model, + il_config, temp_output_dir): + """Test that forward hooks are called during model execution.""" + mock_save_tensors.return_value = None + # Register hooks + with intermediate_logging(il_config): + logger_instance = register_intermediate_hooks(simple_model, il_config) + + # Create input tensor + input_tensor = torch.randn(2, 10) + + # Reset step counter + reset_step() + + # Forward pass + simple_model(input_tensor) + + # Check that the step counter was incremented + assert get_step() == 1 + + # Check that dump_intermediates_to_json and save_tensors were called + assert mock_dump_json.called + assert mock_save_tensors.called + + + # Remove hooks + logger_instance.remove_hooks() + + +def test_end_to_end(simple_model, il_config, temp_output_dir): + """Test the entire intermediate logging workflow end-to-end.""" + # Register hooks + with intermediate_logging(il_config): + 
logger_instance = register_intermediate_hooks(simple_model, il_config) + + # Create input tensor + input_tensor = torch.randn(2, 10) + + # Reset step counter + reset_step() + + # Forward pass + simple_model(input_tensor) + + # Check that output directories were created + root_dir = Path(il_config._output_run_dir) + assert root_dir.exists() + step_dir = root_dir / "step_0" + assert step_dir.exists() + + module_dirs = list(step_dir.glob("*")) + print(f"{module_dirs=}") + assert len(module_dirs) > 0 + + # Check that input and output files were created + for module_dir in module_dirs: + print(f"{module_dir=}") + if os.path.isdir(module_dir): + inputs_json = module_dir / "inputs.json" + outputs_json = module_dir / "outputs.json" + + # Check that JSON files exist + assert inputs_json.exists() + assert outputs_json.exists() + + # Check that JSON files contain valid data + with open(inputs_json) as f: + inputs_data = json.load(f) + assert "type" in inputs_data + + with open(outputs_json) as f: + outputs_data = json.load(f) + assert "type" in outputs_data + + # Check that tensor files exist + tensor_files = list(module_dir.glob("*.pt")) + assert len(tensor_files) > 0 + + # Remove hooks + logger_instance.remove_hooks() + + +if __name__ == "__main__": + pytest.main(["-xvs", __file__]) diff --git a/tools/compare_intermediate.py b/tools/compare_intermediate.py new file mode 100755 index 0000000000000..984c60450354c --- /dev/null +++ b/tools/compare_intermediate.py @@ -0,0 +1,706 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Script to compare intermediate logging outputs from two different runs. + +This script compares the tensor outputs from two different intermediate logging +directories and generates a report of the differences. 
+ +Usage: + python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options] + +Options: + --dir1 DIR First intermediate logging directory + --dir2 DIR Second intermediate logging directory + --output FILE Output file for the report (default: stdout) + --format {md,json} Output format (default: md) + --rtol FLOAT Relative tolerance for tensor comparison (default: 1e-5) + --atol FLOAT Absolute tolerance for tensor comparison (default: 1e-8) + --steps STEPS Comma-separated list of steps to compare (default: all) + --modules MODULES Comma-separated list of module name patterns to compare (default: all) + --verbose Include detailed information about each tensor +""" + +import argparse +import json +import re +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch + + +def load_tensor(path: Path) -> torch.Tensor: + """Load a tensor from a .pt file.""" + try: + return torch.load(path, map_location="cpu") + except Exception as e: + print(f"Error loading tensor from {path}: {e}") + return None + + +def load_json(path: Path) -> Dict: + """Load a JSON file.""" + try: + with open(path, "r") as f: + return json.load(f) + except Exception as e: + print(f"Error loading JSON from {path}: {e}") + return {} + + +def extract_diff_metatada(exception_str: str) -> Dict: + try: + num_diff_elements = int( + re.search(r"Mismatched elements: (\d+) /", exception_str).group(1) + ) + total_elements = int( + re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1) + ) + max_abs_diff = float( + re.search( + r"Greatest absolute difference: ([\d\.e-]+)", exception_str + ).group(1) + ) + max_rel_diff = float( + re.search( + r"Greatest relative difference: ([\d\.e-]+)", exception_str + ).group(1) + ) + return { + "num_diff_elements": num_diff_elements, + "total_elements": total_elements, + "max_abs_diff": max_abs_diff, + "max_rel_diff": max_rel_diff, + } + except Exception: + 
return {"error": exception_str} + + +def compare_tensors( + tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float +) -> Dict: + """Compare two tensors and return a dictionary with comparison results.""" + if tensor1 is None or tensor2 is None: + return {"match": False, "error": "One or both tensors are None"} + + if tensor1.shape != tensor2.shape: + return { + "match": False, + "error": f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}", + } + + if tensor1.dtype != tensor2.dtype: + return { + "match": False, + "error": f"Dtype mismatch: {tensor1.dtype} vs {tensor2.dtype}", + } + + # Check if tensors are close using PyTorch's assert_close + try: + torch.testing.assert_close(tensor1, tensor2, rtol=rtol, atol=atol) + except Exception as e: + return {"match": False, **extract_diff_metatada(str(e))} + return {"match": True} + + +def compare_json_values(value1: Any, value2: Any) -> Dict: + """Compare two JSON values and return a dictionary with comparison results.""" + if type(value1) is not type(value2): + return { + "match": False, + "error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}", + } + + if isinstance(value1, dict): + # Compare dictionaries + all_keys = set(value1.keys()) | set(value2.keys()) + mismatches = {} + + for key in all_keys: + if key not in value1: + mismatches[key] = {"error": "Missing in first dict"} + elif key not in value2: + mismatches[key] = {"error": "Missing in second dict"} + else: + comparison = compare_json_values(value1[key], value2[key]) + if not comparison["match"]: + mismatches[key] = comparison + + if mismatches: + return {"match": False, "mismatches": mismatches} + return {"match": True} + + elif isinstance(value1, list): + # Compare lists + if len(value1) != len(value2): + return { + "match": False, + "error": f"Length mismatch: {len(value1)} vs {len(value2)}", + } + + mismatches = {} + for i, (item1, item2) in enumerate(zip(value1, value2)): + comparison = compare_json_values(item1, item2) 
+ if not comparison["match"]: + mismatches[i] = comparison + + if mismatches: + return {"match": False, "mismatches": mismatches} + return {"match": True} + + else: + # Compare primitive values + if value1 == value2: + return {"match": True} + else: + return {"match": False, "value1": value1, "value2": value2} + + +def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]: + """ + Find all tensor files in the given directory. + + Returns a dictionary with the structure: + { + "step_0": { + "module_name_123456": { + "inputs": [Path("inputs_0_cuda_0.pt"), ...], + "outputs": [Path("output_cuda_0.pt"), ...] + }, + ... + }, + ... + } + """ + result = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + + # Find all step directories + step_dirs = [d for d in directory.glob("step_*") if d.is_dir()] + + for step_dir in step_dirs: + step_name = step_dir.name + + # Find all module directories + module_dirs = [d for d in step_dir.glob("*") if d.is_dir()] + + for module_dir in module_dirs: + module_name = module_dir.name + + # Find input tensor files + input_tensors = list(module_dir.glob("inputs_*.pt")) + if input_tensors: + result[step_name][module_name]["inputs"] = input_tensors + + # Find output tensor files + output_tensors = list(module_dir.glob("output*.pt")) + if output_tensors: + result[step_name][module_name]["outputs"] = output_tensors + + # Find JSON metadata files + inputs_json = module_dir / "inputs.json" + if inputs_json.exists(): + result[step_name][module_name]["inputs_json"] = [inputs_json] + + outputs_json = module_dir / "outputs.json" + if outputs_json.exists(): + result[step_name][module_name]["outputs_json"] = [outputs_json] + + return result + + +def filter_steps_and_modules( + tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]], + steps: Optional[List[str]] = None, + module_patterns: Optional[List[str]] = None, +) -> Dict[str, Dict[str, Dict[str, List[Path]]]]: + """Filter tensor files by steps and module 
patterns.""" + result = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + + # Filter steps + if steps: + step_names = [f"step_{step}" for step in steps] + steps_to_include = {step: True for step in step_names} + else: + steps_to_include = {step: True for step in tensor_files.keys()} + + # Compile module patterns + if module_patterns: + compiled_patterns = [re.compile(pattern) for pattern in module_patterns] + else: + compiled_patterns = None + + for step_name, modules in tensor_files.items(): + if step_name not in steps_to_include: + continue + + for module_name, file_types in modules.items(): + # Check if module matches any pattern + if compiled_patterns: + if not any( + pattern.search(module_name) for pattern in compiled_patterns + ): + continue + + result[step_name][module_name] = file_types + + return result + + +def compare_directories( + dir1: Path, + dir2: Path, + rtol: Optional[float] = None, + atol: Optional[float] = None, + steps: Optional[List[str]] = None, + module_patterns: Optional[List[str]] = None, +) -> Dict: + """Compare two intermediate logging directories and return a report.""" + # Find tensor files in both directories + tensor_files1 = find_tensor_files(dir1) + tensor_files2 = find_tensor_files(dir2) + + # Filter by steps and modules + if steps or module_patterns: + tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns) + tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns) + + # Get all steps and modules from both directories + all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys()) + + report = { + "dir1": str(dir1), + "dir2": str(dir2), + "rtol": rtol, + "atol": atol, + "steps": {}, + } + + # Compare each step + for step in sorted(all_steps): + step_report = { + "modules": {}, + "summary": { + "total_modules": 0, + "matching_modules": 0, + "mismatched_modules": 0, + "missing_modules": 0, + }, + } + + # Get all modules from both directories for this step + modules1 = 
tensor_files1.get(step, {}) + modules2 = tensor_files2.get(step, {}) + # TODO: read from module calls.txt to get the full module list + # TODO: check if module calls txt exsits + dir1_module_call_file = dir1 / step / "module_calls.txt" + if dir1_module_call_file.exists(): + with open(dir1 / step / "module_calls.txt", "r") as f: + all_modules = f.read().splitlines() + else: + print( + "Warnings: the module call orders are missed, ordering using module alphbetics" + ) + all_modules = sorted(set(modules1.keys()) | set(modules2.keys())) + step_report["module_call_list"] = [] + for module in all_modules: + module_report = { + "inputs": {}, + "outputs": {}, + "summary": { + "total_tensors": 0, + "matching_tensors": 0, + "mismatched_tensors": 0, + "missing_tensors": 0, + }, + } + + # Check if module exists in both directories + if module not in modules1: + module_report["error"] = f"Module missing in {dir1}" + step_report["summary"]["missing_modules"] += 1 + step_report["modules"][module] = module_report + continue + + if module not in modules2: + module_report["error"] = f"Module missing in {dir2}" + step_report["summary"]["missing_modules"] += 1 + step_report["modules"][module] = module_report + continue + + # Compare JSON metadata + for json_type in ["inputs_json", "outputs_json"]: + json_files1 = modules1[module].get(json_type, []) + json_files2 = modules2[module].get(json_type, []) + + if json_files1 and json_files2: + json1 = load_json(json_files1[0]) + json2 = load_json(json_files2[0]) + + json_comparison = compare_json_values(json1, json2) + json_name = json_type.replace("_json", "") + module_report[f"{json_name}_metadata"] = json_comparison + + # Add file paths for manual checking when there's a mismatch + if not json_comparison.get("match", True): + module_report[f"{json_name}_metadata"]["file1"] = str( + json_files1[0] + ) + module_report[f"{json_name}_metadata"]["file2"] = str( + json_files2[0] + ) + + # Compare input tensors + input_tensors1 = {p.name: p for 
p in modules1[module].get("inputs", [])} + input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])} + all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys()) + + for tensor_name in sorted(all_input_names): + if tensor_name not in input_tensors1: + module_report["inputs"][tensor_name] = { + "match": False, + "error": f"Tensor missing in {dir1}", + } + module_report["summary"]["missing_tensors"] += 1 + elif tensor_name not in input_tensors2: + module_report["inputs"][tensor_name] = { + "match": False, + "error": f"Tensor missing in {dir2}", + } + module_report["summary"]["missing_tensors"] += 1 + else: + tensor1 = load_tensor(input_tensors1[tensor_name]) + tensor2 = load_tensor(input_tensors2[tensor_name]) + + comparison = compare_tensors(tensor1, tensor2, rtol, atol) + # Add file paths for manual checking when there's a mismatch + if not comparison.get("match", False): + comparison["file1"] = str(input_tensors1[tensor_name]) + comparison["file2"] = str(input_tensors2[tensor_name]) + + module_report["inputs"][tensor_name] = comparison + + if comparison.get("match", False): + module_report["summary"]["matching_tensors"] += 1 + else: + module_report["summary"]["mismatched_tensors"] += 1 + + module_report["summary"]["total_tensors"] += 1 + + # Compare output tensors + output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])} + output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])} + all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys()) + + for tensor_name in sorted(all_output_names): + if tensor_name not in output_tensors1: + module_report["outputs"][tensor_name] = { + "match": False, + "error": f"Tensor missing in {dir1}", + } + module_report["summary"]["missing_tensors"] += 1 + elif tensor_name not in output_tensors2: + module_report["outputs"][tensor_name] = { + "match": False, + "error": f"Tensor missing in {dir2}", + } + module_report["summary"]["missing_tensors"] += 1 + 
else: + tensor1 = load_tensor(output_tensors1[tensor_name]) + tensor2 = load_tensor(output_tensors2[tensor_name]) + + comparison = compare_tensors(tensor1, tensor2, rtol, atol) + # Add file paths for manual checking when there's a mismatch + if not comparison.get("match", False): + comparison["file1"] = str(output_tensors1[tensor_name]) + comparison["file2"] = str(output_tensors2[tensor_name]) + + module_report["outputs"][tensor_name] = comparison + + if comparison.get("match", False): + module_report["summary"]["matching_tensors"] += 1 + else: + module_report["summary"]["mismatched_tensors"] += 1 + + module_report["summary"]["total_tensors"] += 1 + + # Update module status + if module_report["summary"]["mismatched_tensors"] > 0: + step_report["summary"]["mismatched_modules"] += 1 + else: + step_report["summary"]["matching_modules"] += 1 + + step_report["summary"]["total_modules"] += 1 + step_report["modules"][module] = module_report + step_report["module_call_list"].append(module) + + report["steps"][step] = step_report + + # Add overall summary + report["summary"] = { + "total_steps": len(all_steps), + "total_modules": sum( + step_report["summary"]["total_modules"] + for step_report in report["steps"].values() + ), + "matching_modules": sum( + step_report["summary"]["matching_modules"] + for step_report in report["steps"].values() + ), + "mismatched_modules": sum( + step_report["summary"]["mismatched_modules"] + for step_report in report["steps"].values() + ), + "missing_modules": sum( + step_report["summary"]["missing_modules"] + for step_report in report["steps"].values() + ), + "total_tensors": sum( + module_report["summary"]["total_tensors"] + for step_report in report["steps"].values() + for module_name, module_report in step_report["modules"].items() + if "summary" in module_report + ), + "matching_tensors": sum( + module_report["summary"]["matching_tensors"] + for step_report in report["steps"].values() + for module_name, module_report in 
step_report["modules"].items() + if "summary" in module_report + ), + "mismatched_tensors": sum( + module_report["summary"]["mismatched_tensors"] + for step_report in report["steps"].values() + for module_name, module_report in step_report["modules"].items() + if "summary" in module_report + ), + "missing_tensors": sum( + module_report["summary"]["missing_tensors"] + for step_report in report["steps"].values() + for module_name, module_report in step_report["modules"].items() + if "summary" in module_report + ), + } + + return report + + +def generate_markdown_report(report: Dict, verbose: bool = False) -> str: + """Generate a markdown report from the comparison results.""" + lines = [] + + # Add header + lines.append("# Intermediate Logging Comparison Report") + lines.append("") + lines.append("Comparing intermediate logging outputs between:") + lines.append(f"- **Directory 1**: `{report['dir1']}`") + lines.append(f"- **Directory 2**: `{report['dir2']}`") + lines.append("") + lines.append(f"Comparison parameters:") + lines.append(f"- Relative tolerance (rtol): {report['rtol']}") + lines.append(f"- Absolute tolerance (atol): {report['atol']}") + lines.append("") + + # Add overall summary + summary = report["summary"] + lines.append("## Overall Summary") + lines.append("") + lines.append("| Category | Total | Matching | Mismatched | Missing |") + lines.append("|----------|-------|----------|------------|---------|") + lines.append(f"| Steps | {summary['total_steps']} | - | - | - |") + lines.append( + f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |" + ) + lines.append( + f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |" + ) + lines.append("") + + # Add step details + for step_name, step_report in sorted(report["steps"].items()): + step_summary = step_report["summary"] + + 
lines.append(f"## {step_name}") + lines.append("") + lines.append( + f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules" + ) + lines.append("") + + # Add module details + for module_name in step_report["module_call_list"]: + module_report = step_report["modules"][module_name] + if "error" in module_report: + lines.append(f"### ❌ {module_name}") + lines.append("") + lines.append(f"**Error**: {module_report['error']}") + lines.append("") + continue + + module_summary = module_report["summary"] + + # Determine module status + if module_summary["mismatched_tensors"] > 0: + status = "❌" + else: + status = "✅" + + lines.append(f"### {status} {module_name}") + lines.append("") + lines.append( + f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors" + ) + lines.append("") + + # Add metadata comparison results if available + for metadata_type in ["inputs_metadata", "outputs_metadata"]: + if metadata_type in module_report: + metadata_comparison = module_report[metadata_type] + if not metadata_comparison.get("match", True): + file_paths = "" + if ( + "file1" in metadata_comparison + and "file2" in metadata_comparison + ): + file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`" + + lines.append( + f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}" + ) + if verbose and "mismatches" in metadata_comparison: + lines.append("```json") + lines.append( + json.dumps(metadata_comparison["mismatches"], indent=2) + ) + lines.append("```") + lines.append("") + + # Add tensor comparison details + if module_summary["mismatched_tensors"] > 0 or verbose: + # Add input tensor details + if module_report["inputs"]: + lines.append("#### Input Tensors") + lines.append("") + lines.append("| 
Tensor | Status | Details |") + lines.append("|--------|--------|---------|") + + for tensor_name, comparison in sorted( + module_report["inputs"].items() + ): + if comparison.get("match", False): + status = "✅" + details = "Tensors match" + elif "error" in comparison: + status = "❌" + details = comparison["error"] + else: + status = "❌" + details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A'):.2e}, " + details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A'):.2e}, " + details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}" + if "file1" in comparison and "file2" in comparison: + details += f"
Files: `{comparison['file1']}` vs `{comparison['file2']}`" + + lines.append(f"| {tensor_name} | {status} | {details} |") + + lines.append("") + + # Add output tensor details + if module_report["outputs"]: + lines.append("#### Output Tensors") + lines.append("") + lines.append("| Tensor | Status | Details |") + lines.append("|--------|--------|---------|") + + for tensor_name, comparison in sorted( + module_report["outputs"].items() + ): + if comparison.get("match", False): + status = "✅" + details = "Tensors match" + elif "error" in comparison: + status = "❌" + details = comparison["error"] + else: + status = "❌" + details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, " + details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, " + details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}" + + lines.append(f"| {tensor_name} | {status} | {details} |") + + lines.append("") + + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser( + description="Compare intermediate logging outputs from two different runs." 
+ ) + parser.add_argument( + "--dir1", required=True, help="First intermediate logging directory" + ) + parser.add_argument( + "--dir2", required=True, help="Second intermediate logging directory" + ) + parser.add_argument("--output", help="Output file for the report (default: stdout)") + parser.add_argument( + "--rtol", + type=float, + default=None, + help="Relative tolerance for tensor comparison (default: 1e-5)", + ) + parser.add_argument( + "--atol", + type=float, + default=None, + help="Absolute tolerance for tensor comparison (default: 1e-8)", + ) + parser.add_argument( + "--steps", help="Comma-separated list of steps to compare (default: all)" + ) + parser.add_argument( + "--modules", + help="Comma-separated list of module name patterns to compare (default: all)", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Include detailed information about each tensor", + ) + + args = parser.parse_args() + + # Parse steps and modules + steps = args.steps.split(",") if args.steps else None + module_patterns = args.modules.split(",") if args.modules else None + + # Compare directories + report = compare_directories( + Path(args.dir1), + Path(args.dir2), + rtol=args.rtol, + atol=args.atol, + steps=steps, + module_patterns=module_patterns, + ) + + # Generate report + output = generate_markdown_report(report, verbose=args.verbose) + + # Write report + if args.output: + with open(args.output, "w") as f: + f.write(output) + print(f"Report written to {args.output}") + else: + print(output) + + +if __name__ == "__main__": + main() + + +def invoke_main() -> None: + main() diff --git a/vllm/config.py b/vllm/config.py index 3bcbbe60652b7..ab5676686dd48 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,8 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, from functools import cached_property from importlib.util import find_spec from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - 
Protocol, TypeVar, Union, cast, get_args) + Protocol, TypeVar, Union, cast, get_args, List, Set) +from re import Pattern import regex as re import torch @@ -4024,6 +4025,122 @@ class KVEventsConfig: """ +@config +@dataclass +class IntermediateLoggingConfig: + """Configuration for intermediate tensor logging.""" + + output_dir: str = "/tmp/vllm_intermediates" + """Directory where to save the intermediate tensors.""" + + reload_input_dir: Optional[str] = None + """Directory where to load the inputs for the steps/modules. + This is used when we want to check per module numerical gaps instead + of accumulated gap to further dive into the actual numerical issues.""" + + module_call_match: Optional[List[str]] = None + """Match modules by name regex and call index ( + a module can be called multiple times in a step) + List of regex:call_idx, call_idx is -1 for default for all calls """ + + log_step_ids: List[int] = field(default_factory=lambda: [0]) + """List of step IDs to log (empty list means log all steps).""" + + log_post_fwd_inputs: bool = False + """Whether logging inputs after forwards for each module""" + + max_tensor_size: Optional[int] = None + """Maximum number of elements in tensors to log (None = no limit).""" + + enabled: bool = True + """Whether logging is enabled.""" + device_names: List[str] = field(default_factory=list) + """List of device names to log (empty list means log all devices).""" + + _compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False) + """Compiled regex patterns for module filtering.""" + + _module_call: dict[str, int] = field(default_factory=dict, init=False) + _step_id_set: Set[int] = field(default_factory=set, init=False) + """Set of step IDs for faster lookup.""" + _output_run_dir: str = "/tmp/vllm_intermediates" + """Unique directory to save single run/serve logging result.""" + + def __post_init__(self): + """Initialize derived fields after instance creation.""" + self._compile_regex_patterns() + 
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4()) + self._step_id_set = set(self.log_step_ids) + + def _compile_regex_patterns(self): + """Compile regex patterns for module name filtering.""" + from vllm.logger import init_logger + logger = init_logger(__name__) + + self._compiled_module_matches = [] + + if self.module_call_match is None: + logger.info("No module name regex patterns provided, will log all modules") + return + + # Compile all patterns + for regex_pattern_call_idx in self.module_call_match: + try: + splits = regex_pattern_call_idx.split(":", 2) + regex_pattern = splits[0] + call_idx = -1 + if len(splits) > 1: + call_idx = int(splits[1]) + compiled_pattern: Pattern[str] = re.compile(regex_pattern) + self._compiled_module_calls[compiled_pattern] = call_idx + logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'") + except Exception as e: + logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") + raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e + + + logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns") + + def to_dict(self) -> dict: + """Convert the config to a dictionary for serialization.""" + return { + "output_run_dir": self.output_run_dir, + "module_call_match": self.module_call_match, + "log_step_ids": self.log_step_ids, + "max_tensor_size": self.max_tensor_size, + "enabled": self.enabled, + "device_names": self.device_names + } + + @classmethod + def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig": + """Parse the CLI value for the speculative config.""" + return cls(**dict_value) + + @property + def output_run_dir(self) -> str: + return self._output_run_dir + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. 
+ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # Intermediate logging doesn't affect the computation graph + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + + class CompilationLevel: # constants for the levels of the compilation process NO_COMPILATION = 0 @@ -4480,6 +4597,8 @@ class VllmConfig: """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None """The configurations for event publishing.""" + intermediate_log_config: Optional[IntermediateLoggingConfig] = None + """Configuration for intermediate tensor logging.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
@@ -4564,6 +4683,10 @@ class VllmConfig: vllm_factors.append(self.kv_transfer_config.compute_hash()) else: vllm_factors.append("None") + if self.intermediate_log_config: + vllm_factors.append(self.intermediate_log_config.compute_hash()) + else: + vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): additional_config_hash = hashlib.md5( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4d6001a428d2..49e331e183397 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -26,7 +26,8 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, DecodingConfig, DetailedTraceModules, Device, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, GuidedDecodingBackendV1, - HfOverrides, KVEventsConfig, KVTransferConfig, + HfOverrides, IntermediateLoggingConfig, + KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, ModelConfig, ModelDType, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, @@ -399,6 +400,7 @@ class EngineArgs: str] = ModelConfig.logits_processor_pattern speculative_config: Optional[Dict[str, Any]] = None + show_hidden_metrics_for_version: Optional[str] = \ ObservabilityConfig.show_hidden_metrics_for_version @@ -444,6 +446,9 @@ class EngineArgs: async_scheduling: bool = SchedulerConfig.async_scheduling # DEPRECATED enable_prompt_adapter: bool = False + intermediate_log_config_path: Optional[str] = None + + intermediate_log_config: Optional[dict[str, Any]] = None def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -758,6 +763,20 @@ class EngineArgs: help="The configurations for speculative decoding. 
Should be a " "JSON string.") + intermediate_log_group = parser.add_argument_group( + title="IntermediateLoggingConfig", + description=IntermediateLoggingConfig.__doc__, + ) + intermediate_log_group.add_argument( + "--intermediate-log-config", + type=json.loads, + default=None, + help="The configurations for intermediate loggings. Should be a " + "JSON string.") + + intermediate_log_group.add_argument("--intermediate-log-config-path", type=str, + help="The path to the configurations for intermediate loggings. Should be a string.") + # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) observability_group = parser.add_argument_group( @@ -846,6 +865,9 @@ class EngineArgs: vllm_group.add_argument("--additional-config", **vllm_kwargs["additional_config"]) + + + # Other arguments parser.add_argument('--disable-log-stats', action='store_true', @@ -957,6 +979,21 @@ class EngineArgs: use_tqdm_on_load=self.use_tqdm_on_load, pt_load_map_location=self.pt_load_map_location, ) + + + def create_intermediate_log_config( + self, + ) -> Optional[IntermediateLoggingConfig]: + """Initializes and returns an IntermediateLoggingConfig object based on + `intermediate_log_config` or `intermediate_log_config_path`. 
+ """ + if self.intermediate_log_config is not None: + return IntermediateLoggingConfig.from_dict( + self.intermediate_log_config) + if self.intermediate_log_config_path is not None: + with open(self.intermediate_log_config_path, "r") as f: + return IntermediateLoggingConfig.from_dict(json.load(f)) + return None def create_speculative_config( self, @@ -1198,6 +1235,9 @@ class EngineArgs: disable_log_stats=self.disable_log_stats, ) + intermediate_log_config = self.create_intermediate_log_config( + ) + # Reminder: Please update docs/features/compatibility_matrix.md # If the feature combo become valid if self.num_scheduler_steps > 1: @@ -1284,7 +1324,6 @@ class EngineArgs: otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) - config = VllmConfig( model_config=model_config, cache_config=cache_config, @@ -1299,6 +1338,7 @@ class EngineArgs: compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, + intermediate_log_config=intermediate_log_config, additional_config=self.additional_config, ) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 57f60c4b289bb..8a1f28a7a7715 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -77,6 +77,10 @@ class EngineCore: # Setup Model. 
self.model_executor = executor_class(vllm_config) + if vllm_config.intermediate_log_config is not None: + self.collective_rpc("register_intermediate_hooks", + args=(vllm_config.intermediate_log_config, )) + if executor_fail_callback is not None: self.model_executor.register_failure_callback( executor_fail_callback) diff --git a/vllm/v1/intermediates/__init__.py b/vllm/v1/intermediates/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/intermediates/intermediates_logging.py b/vllm/v1/intermediates/intermediates_logging.py new file mode 100644 index 0000000000000..0024523b7befb --- /dev/null +++ b/vllm/v1/intermediates/intermediates_logging.py @@ -0,0 +1,599 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Module for logging intermediate tensors during model execution. + +This module provides functionality to capture and save intermediate tensors +(inputs and outputs) from PyTorch modules during forward passes. +""" + +import json +import os +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Optional + +import torch +from torch.utils.hooks import RemovableHandle + +from vllm.config import IntermediateLoggingConfig + +# Import logger from vllm +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Global step counter +_CURRENT_STEP = 0 + +_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {} + +IL_MODULE_NAME = "_il_module_name" +IL_MODULE_CALL_IDX = "_il_module_call_idx" + +# Utility functions for intermediate logging + + +def should_log_step(config): + """Check if the current step should be logged based on the step IDs. + + Args: + config: The IntermediateLoggingConfig instance. + + Returns: + True if the current step should be logged, False otherwise. 
+ """ + if not is_log_enabled(config): + return False + + # If log_step_ids is empty, log all steps + if not config.log_step_ids: + return True + + # Otherwise, check if current step is in the set of step IDs to log + return get_step() in config._step_id_set + + +def should_log_device(config, device_name): + """Check if a device should be logged based on the device names. + + Args: + config: The IntermediateLoggingConfig instance. + device_name: The name of the device to check (e.g., 'cuda:0', 'cpu'). + + Returns: + True if the device should be logged, False otherwise. + If device_names is empty, all devices are logged. + """ + if not is_log_enabled(config): + return False + # If device_names is empty, log all devices + if not config.device_names: + return True + + # Otherwise, check if device_name is in the list of device names to log + return device_name in config.device_names + + +def should_log_module(config, module_name, module: torch.nn.Module) -> bool: + """Check if a module should be logged based on the name regex patterns. + + Args: + config: The IntermediateLoggingConfig instance. + module_name: The name of the module to check. + + Returns: + True if the module should be logged, False otherwise. + If no patterns are defined, all modules are logged. + If patterns are defined, the module is logged if it matches ANY pattern. 
+ """ + if not is_log_enabled(config): + return False + # If no patterns are defined, log all modules + if not config._compiled_module_calls: + logger.debug("No patterns defined, will log module: %s", module_name) + set_il_module_name(module, module_name) + set_il_module_call_idx(module, -1) + return True + + # Check if the module name matches any of the patterns + for pattern, call_idx in config._compiled_module_calls.items(): + match = pattern.search(module_name) + if match: + logger.debug( + "Module %s, %s matches pattern: '%s', call_idx=%s", + module_name, + module.__class__.__name__, + pattern.pattern, + call_idx, + ) + set_il_module_name(module, module_name) + set_il_module_call_idx(module, call_idx) + return True + return False + + +def is_log_enabled(config): + if not config or not config.enabled: + logger.debug("Not logging because config not enabled") + return False + if torch.compiler.is_compiling(): + logger.debug("Not logging because torch.compile is in progress") + return False + return True + + +def get_il_module_name(module: torch.nn.Module) -> str: + return getattr(module, IL_MODULE_NAME, module.__class__.__name__) + + +def get_il_module_call_idx(module: torch.nn.Module) -> int: + return getattr(module, IL_MODULE_CALL_IDX, -1) + + +def set_il_module_name(module: torch.nn.Module, name: str) -> None: + setattr(module, IL_MODULE_NAME, name) + + +def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None: + setattr(module, IL_MODULE_CALL_IDX, idx) + + +_global_config: Optional[IntermediateLoggingConfig] = None + + +@contextmanager +def intermediate_logging(config: Optional[IntermediateLoggingConfig]): + """ + Temporarily sets the global config for the duration of the context. 
+ :param config: Keyword arguments to set as global config + """ + global _global_config + old_config = _global_config + try: + _global_config = config + yield + finally: + _global_config = old_config + + +def get_current_il_config(): + return _global_config + + +def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any: + try: + # Convert inputs to JSON-serializable format + intermediates_json = convert_intermediates_to_json(intermediates) + with open(path, "w") as f: + json.dump(intermediates_json, f, indent=2) + logger.debug("Saved all intermediates as JSON to %s", path) + except Exception as e: + logger.warning("Failed to save intermediates as JSON: %s", e) + import traceback + + logger.warning(traceback.format_exc()) + + +def convert_intermediates_to_json(tensor: Any) -> Any: + """Convert a intermediates(including tensor) to a JSON-serializable + representation. + + Args: + intermediates: The intermediates to convert. + + Returns: + A JSON-serializable representation of the tensor. 
+ """ + if isinstance(tensor, torch.Tensor): + try: + result = { + "type": "tensor", + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "numel": tensor.numel(), + } + + return result + except Exception as e: + # Handle any errors in tensor conversion + return { + "type": "tensor_error", + "error": str(e), + "tensor_type": str(type(tensor)), + } + + elif isinstance(tensor, (list, tuple)): + # For lists/tuples, recursively convert each element + container_type = "list" if isinstance(tensor, list) else "tuple" + + # If it's a large list, only include a sample + if len(tensor) > 20: + return { + "type": container_type, + "length": len(tensor), + "sample": [ + convert_intermediates_to_json(item) for item in tensor[:100] + ], + "note": f"Showing only first 20 of {len(tensor)} items", + } + else: + return { + "type": container_type, + "items": [convert_intermediates_to_json(item) for item in tensor], + } + + elif isinstance(tensor, dict): + # For dictionaries, recursively convert each value + if len(tensor) > 20: + # For large dicts, only include keys and a sample of values + keys = list(tensor.keys()) + sample_keys = keys[:20] + return { + "type": "dict", + "length": len(tensor), + "keys": keys, + "sample": { + k: convert_intermediates_to_json(tensor[k]) for k in sample_keys + }, + "note": f"Showing only first 20 of {len(tensor)} items", + } + else: + return { + "type": "dict", + "items": { + k: convert_intermediates_to_json(v) for k, v in tensor.items() + }, + } + + elif tensor is None: + return None + + elif isinstance(tensor, (int, float, bool, str)): + # Primitive types can be directly serialized + return tensor + + else: + # For other types, use string representation + return {"type": str(type(tensor).__name__), "string_repr": str(tensor)} + + +def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool: + """Utility function to dump tensor metadata to a file. + + Args: + tensor: The tensor to dump. 
+ file_path: Base path where to save the tensor (without extension). + """ + intermediate_log_config = get_current_il_config() + if intermediate_log_config is None: + return False + if ( + intermediate_log_config.max_tensor_size is not None + and tensor.numel() > intermediate_log_config.max_tensor_size + ): + # Save tensor metadata instead of full tensor + tensor_info = { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "device": str(tensor.device), + "numel": tensor.numel(), + "skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size " + f"{intermediate_log_config.max_tensor_size}", + } + os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True) + with open(f"{file_path}.json", "w") as f: + json.dump(tensor_info, f, indent=2) + return True + return False + + +def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any: + if reload_dir is None: + return None + try: + intermediate_log_config = get_current_il_config() + assert intermediate_log_config is not None + replace_dir = str(intermediate_log_config.output_run_dir) + reload_path = save_path.replace(replace_dir, reload_dir) + logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path) + return torch.load(reload_path, map_location=tensor.device) + except Exception as e: + logger.warning("Failed to load tensor from %s: %s", reload_dir, e) + return tensor + + +def save_tensors( + tensor: Any, file_path: str, reload_input_dir: Optional[str] = None +) -> Any: + """Utility function to dump tensor to a file. + + Args: + tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of + tensors, or a dictionary containing tensors. + file_path: Base path where to save the tensor (without extension). 
+ """ + + # Also save the actual tensor data for tensors + if isinstance(tensor, torch.Tensor): + # Check if tensor is too large + if save_tensors_metadata_if_too_large(tensor, file_path): + return + # Get device name + device_name = str(tensor.device) + # Skip if device filtering is enabled and this device should not be + # logged + intermediate_log_config = get_current_il_config() + if not should_log_device(intermediate_log_config, device_name): + logger.debug( + "Skipping tensor on device %s due to device filter", device_name + ) + return tensor + # Append device name to file path + pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt" + try: + # Save tensor directly without detaching or moving to CPU + torch.save(tensor, pt_path) + reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir) + if reloaded_tensor is not None: + return reloaded_tensor + logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path) + except Exception as e: + logger.warning("Failed to save tensor to %s: %s", pt_path, e) + return tensor + + if isinstance(tensor, (list, tuple)): + # For collections, also save each item individually + + reloaded_inputs = [] + for i, item in enumerate(tensor): + reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir) + reloaded_inputs.append(reloaded) + return tuple(reloaded_inputs) if reloaded_inputs else tensor + if isinstance(tensor, dict): + reloaded_inputs = {} + # For dictionaries, also save each value individually + for k, v in tensor.items(): + reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir) + reloaded_inputs[k] = reloaded + return reloaded_inputs if reloaded_inputs else tensor + + +def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None: + """Hook to increment the global step counter after a forward pass. + + Args: + module: The PyTorch module being executed. + inputs: The inputs to the module's forward function. 
+ outputs: The outputs from the module's forward function. + """ + if get_current_il_config() is None: + return + # Increment the global step counter + increment_step() + global _CURRENT_STEP_MODULE_CALL_STEP + _CURRENT_STEP_MODULE_CALL_STEP = {} + + +def _prepare_module_log_dir( + intermediate_log_config: IntermediateLoggingConfig, + module_name: str, + is_pre_fwd: bool = False, +) -> Path: + # Create a unique directory for this step if not + dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}" + dump_dir.mkdir(exist_ok=True, parents=True) + + # Create module directory + suffix = "" + module_call_idx = get_current_step_module_call(module_name) + if module_call_idx > 0: + suffix = f"_{module_call_idx}" + module_dir = dump_dir / (module_name + suffix) + if is_pre_fwd: + _log_module_call(intermediate_log_config, module_name + suffix) + module_dir.mkdir(exist_ok=True, parents=True) + logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir) + return module_dir + + +def _log_module_call( + intermediate_log_config: IntermediateLoggingConfig, + module_name: str, +) -> None: + logger.debug("Logging module call for %s", module_name) + # write module name and call to step: + file = ( + Path(intermediate_log_config.output_run_dir) + / f"step_{get_step()}" + / "module_calls.txt" + ) + with open(file, "a") as f: + f.write(f"{module_name}\n") + + +def update_current_step_module_call(module_name: str) -> None: + logger.debug("Updating current step module call for %s", module_name) + global _CURRENT_STEP_MODULE_CALL_STEP + if module_name not in _CURRENT_STEP_MODULE_CALL_STEP: + _CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0 + else: + _CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1 + + +def get_current_step_module_call(module_name: str) -> int: + return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0) + + +def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]: + intermediate_log_config = 
get_current_il_config() + if intermediate_log_config is None or not intermediate_log_config.enabled: + return None + if not should_log_step(intermediate_log_config): + return None + + module_name = get_il_module_name(module) + log_call_idx = get_il_module_call_idx(module) + current_call_idx = get_current_step_module_call(module_name) + should_log = True + if log_call_idx >= 0 and current_call_idx != log_call_idx: + should_log = False + + log_dir = None + if is_pre_fwd: + update_current_step_module_call(module_name) + if should_log: + log_dir = _prepare_module_log_dir( + intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd + ) + return log_dir + + +def log_pre_fwd_hook( + module: torch.nn.Module, inputs: tuple[Any, ...] +) -> tuple[Any, ...]: + """Hook to capture module inputs before forward pass. + + Args: + module: The PyTorch module being executed. + inputs: The inputs to the module's forward function. + + Returns: + The unchanged inputs. + """ + if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True): + dump_intermediates_to_json(inputs, log_dir / "inputs.json") + intermediate_log_config = get_current_il_config() + if intermediate_log_config is not None: + reload_input_dir = getattr( + intermediate_log_config, + "reload_input_dir", + "/tmp/vllm_intermediates/57f4a3b2-9c4c-4afe-be71-0e95369d74b5", + ) + else: + reload_input_dir = None + reloaded_inputs = save_tensors( + inputs, str(log_dir / "inputs"), reload_input_dir + ) + if reloaded_inputs is not None: + return reloaded_inputs + return inputs + + +def log_post_fwd_hook( + module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any +) -> None: + """Hook to capture module outputs after forward pass. + + Args: + module: The PyTorch module being executed. + inputs: The inputs to the module's forward function. + outputs: The outputs from the module's forward function. 
+ """ + if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False): + dump_intermediates_to_json(outputs, log_dir / "outputs.json") + save_tensors(outputs, str(log_dir / "outputs")) + intermediate_log_config = get_current_il_config() + assert intermediate_log_config is not None, "IL config should not be None" + if intermediate_log_config.log_post_fwd_inputs: + dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json") + save_tensors(inputs, str(log_dir / "post_fwd_inputs")) + + +def get_step() -> int: + """Get the current global step counter. + + Returns: + The current global step counter. + """ + return _CURRENT_STEP + + +def increment_step() -> int: + """Increment the global step counter. + + Returns: + The new step counter value. + """ + global _CURRENT_STEP + _CURRENT_STEP += 1 + return _CURRENT_STEP + + +def reset_step() -> None: + """Reset the global step counter to zero.""" + global _CURRENT_STEP + _CURRENT_STEP = 0 + + +class IntermediatesLogger: + """Class to manage logging of intermediate tensors during model + execution.""" + + def __init__(self, config: IntermediateLoggingConfig): + self.config = config + self.hooks: list[ + tuple[str, str, Optional[RemovableHandle], Optional[RemovableHandle]] + ] = [] + logger.debug("Created IntermediatesLogger with config: %s", config) + path = Path(config.output_run_dir) + path.mkdir(exist_ok=True, parents=True) + # Log configuration + logger.info("Intermediates will be logged in %s", config.output_run_dir) + + def register_hooks(self, model: torch.nn.Module) -> None: + """Register hooks for the model. + + Args: + model: The PyTorch model to register hooks for. 
+ """ + + for name, module in model.named_modules(): + if name and should_log_module(self.config, name, module): + pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook) + logger.debug( + "Registered pre_fwd hook for %s", module.__class__.__name__ + ) + post_hook = module.register_forward_hook(log_post_fwd_hook) + logger.debug( + "Registered post_fwd hook for %s", module.__class__.__name__ + ) + self.hooks.append((name, module, pre_hook, post_hook)) + + # Register a step counter hook for the root model + step_hook = model.register_forward_hook(step_fwd) + self.hooks.append(("", model, None, step_hook)) + logger.info("Registered hooks for %s modules", len(self.hooks)) + + def remove_hooks(self) -> None: + """Remove all registered hooks.""" + for _, _, pre_hook, post_hook in self.hooks: + if pre_hook is not None: + pre_hook.remove() + if post_hook is not None: + post_hook.remove() + + logger.info("Removed %s hooks", len(self.hooks)) + self.hooks = [] + + +def register_intermediate_hooks( + model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs +) -> IntermediatesLogger: + """Register hooks to log intermediate tensors for a model. + + Args: + model: The PyTorch model to log intermediates for. + config: Configuration for intermediate logging. If provided, this takes + precedence over kwargs. + + Returns: + An IntermediatesLogger instance that can be used to manage the hooks. 
+ """ + if config is None: + # Create config from kwargs + config = IntermediateLoggingConfig.from_dict(kwargs) + + logger_instance = IntermediatesLogger(config) + logger_instance.register_hooks(model) + return logger_instance diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d9d1f14f0554c..a0d5cd2e40e2b 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -32,6 +32,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.intermediates.intermediates_logging import intermediate_logging logger = init_logger(__name__) @@ -344,8 +345,9 @@ class Worker(WorkerBase): get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) - output = self.model_runner.execute_model(scheduler_output, - intermediate_tensors) + with intermediate_logging(self.vllm_config.intermediate_log_config): + output = self.model_runner.execute_model(scheduler_output, + intermediate_tensors) parallel_config = self.vllm_config.parallel_config if parallel_config.distributed_executor_backend != "external_launcher" \ diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 9c93754f93f81..bbe601afa956e 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -6,9 +6,10 @@ from typing import Optional import torch import torch.nn as nn -from vllm.config import VllmConfig +from vllm.config import VllmConfig, IntermediateLoggingConfig from vllm.logger import init_logger from vllm.v1.kv_cache_interface import KVCacheSpec +from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 logger = init_logger(__name__) @@ -63,3 +64,27 @@ class WorkerBase(WorkerBaseV0): def check_health(self) -> None: """Basic health check (override for 
device-specific checks).""" return + + def register_intermediate_hooks(self, + config: Optional[IntermediateLoggingConfig] = None, + **kwargs) -> None: + """Register hooks for intermediate tensor logging. + + This method is called via collective_rpc from the engine core. + It registers hooks on the model to dump intermediate tensors during execution. + + Args: + config: Configuration for intermediate logging. If provided, this takes precedence over kwargs. + """ + if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None: + logger.error("Could not register intermediate hooks: model_runner.model is not accessible") + return + model = self.model_runner.model + try: + # Register hooks + register_intermediate_hooks(model, config, **kwargs) + # Store the logger instance for potential later hook removal + except Exception as e: + logger.info("Successfully registered intermediate hooks") + logger.error("Error registering intermediate hooks", exc_info=True) + diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index f1c9a0ab001e8..065bbd26c05c9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -128,6 +128,22 @@ class WorkerBase: def vocab_size(self) -> int: """Get vocabulary size from model configuration.""" return self.model_config.get_vocab_size() + + def register_intermediate_hooks(self, config=None, **kwargs) -> None: + """Register hooks for intermediate tensor logging. + + This method is a stub for v0 workers. The actual implementation is in v1 workers. + It's included here for compatibility with the collective_rpc mechanism. + + Args: + config: Configuration for intermediate logging. + **kwargs: Configuration parameters for intermediate logging. + These are ignored in v0 workers. + """ + logger.warning( + "register_intermediate_hooks is not implemented in v0 workers. " + "This is only available in v1 workers. 
No hooks will be registered.") + return None class DelegateWorkerBase(WorkerBase):