Squashed commit history:

- add il tool
- more changes
- Apply suggestions from code review
- fix tp
- add comparison tool
- tmp
- add unit test and fix format
- add comparison script and documentation
- provide default intermediate logging
- optional register il

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Lu Fang <fanglu@fb.com>

add input reload and improve intermediate compare
This commit is contained in:
Lu Fang 2025-07-17 18:20:31 -07:00 committed by Lucia Fang
parent c6c9122d50
commit d8bff253d7
11 changed files with 1982 additions and 6 deletions


@@ -0,0 +1,136 @@
# Intermediate Tensor Logging
This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.
## Overview
The intermediate tensor logging feature enables you to:
- Log input and output tensors from a configured set of filters
- Filter modules by name using regex patterns
- Filter by module forward-call index (e.g. dump only the 2nd forward call on the same module)
- Filter tensors by device
- Filter by whole-model forward step ID

This is mainly useful for debugging model accuracy gaps between two runs.
## Usage
### Enabling via parameters or config file
**Offline Inference example**
Dump all modules on all devices for step 0 (the default behavior):
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
```
Dump only the first layer's modules, on all devices, for step 0:
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```
Dump a custom set of modules, devices, and steps through a config file. The configuration file should be a JSON file with the following structure:
```json
{
"output_dir": "/tmp/vllm_intermediates",
"module_call_match": ["layers\\.0\\.(?!.*rotary_emb).*", "rotary_emb:0", "embed_tokens", "model\\.norm"],
"log_step_ids": [0, 1],
"device_names": ["cuda:0"]
}
```
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json
```
#### Configuration Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `output_dir` | string | Directory where to save the intermediate tensors | `/tmp/vllm_intermediates` |
| `module_call_match` | array | Regex patterns to filter module names; to limit logging to the i-th call of a module, append `:i` to the pattern | `null` (log all modules) |
| `log_step_ids` | array | List of step IDs to log | `[0]` |
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
| `device_names` | array | List of device names to log | `[]` (log all devices) |
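The `:i` call-index suffix on `module_call_match` patterns can be illustrated with a short sketch. Note this is a simplified re-implementation for illustration only, not vLLM's actual parsing code; `parse_pattern` and `matches` are hypothetical helpers:

```python
import re


def parse_pattern(pattern: str):
    """Split an optional ':i' call-index suffix off a module-name regex.

    Hypothetical helper mirroring the documented `module_call_match`
    syntax; the real parsing lives inside vLLM.
    """
    name_re, sep, idx = pattern.rpartition(":")
    if sep and idx.isdigit():
        return re.compile(name_re), int(idx)
    return re.compile(pattern), None  # no call-index restriction


def matches(pattern: str, module_name: str, call_index: int) -> bool:
    """Return True if this forward call of the module should be logged."""
    regex, only_call = parse_pattern(pattern)
    if only_call is not None and call_index != only_call:
        return False
    return regex.search(module_name) is not None


# "rotary_emb:0" matches only the first call on a rotary_emb module
print(matches("rotary_emb:0", "model.layers.0.self_attn.rotary_emb", 0))  # True
print(matches("rotary_emb:0", "model.layers.0.self_attn.rotary_emb", 1))  # False
print(matches(r"layers\.0\.", "model.layers.0.mlp", 3))                   # True
```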
### Output Directory Structure
When you enable intermediate logging, the system creates a timestamped directory under your specified `output_dir`. This helps organize multiple logging sessions:
```
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
├── step_0/
│   ├── model.embed_tokens/
│   │   ├── inputs_0_cuda_0.pt
│   │   ├── inputs.json
│   │   ├── outputs_cuda_0.pt
│   │   └── outputs.json
│   └── model.layers.0.input_layernorm/
│       ├── inputs_0_cuda_0.pt
│       ├── inputs.json
│       ├── outputs_cuda_0.pt
│       └── outputs.json
└── step_1/
    └── ...
```
Each tensor is saved in two formats:
1. `.json` files containing metadata and small tensor values
2. `.pt` files containing the full PyTorch tensors (can be loaded with `torch.load()`)
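A dump directory with this layout can be walked with plain `pathlib`. The sketch below builds a miniature dump by hand (the metadata keys other than `"type"` are illustrative assumptions, not a documented schema) and then iterates over it the way you would over a real run:

```python
import json
import tempfile
from pathlib import Path

# Build a miniature dump mimicking the documented layout
# (contents are illustrative; a real dump is produced by vLLM).
root = Path(tempfile.mkdtemp()) / "step_0" / "model.embed_tokens"
root.mkdir(parents=True)
(root / "inputs.json").write_text(json.dumps({"type": "tuple", "shape": [2, 10]}))
(root / "outputs.json").write_text(json.dumps({"type": "Tensor", "shape": [2, 4096]}))

# Walk the dump: one directory per module under each step_* directory
for module_dir in sorted(root.parent.glob("*")):
    meta = json.loads((module_dir / "outputs.json").read_text())
    print(module_dir.name, meta["shape"])  # model.embed_tokens [2, 4096]
    # The full tensors would be loaded with:
    # tensor = torch.load(module_dir / "outputs_cuda_0.pt", map_location="cpu")
```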
## Comparing Intermediate Logging Results
vLLM provides a tool called `compare_intermediate.py` to compare intermediate tensors between two different runs. This is particularly useful for debugging accuracy differences or verifying that code changes don't affect model outputs.
### Usage
```bash
python tools/compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]
```
### Options
| Option | Description | Default |
|--------|-------------|---------|
| `--dir1` | First intermediate logging directory | (required) |
| `--dir2` | Second intermediate logging directory | (required) |
| `--output` | Output file for the report | stdout |
| `--rtol` | Relative tolerance for tensor comparison | 1e-5 |
| `--atol` | Absolute tolerance for tensor comparison | 1e-8 |
| `--steps` | Comma-separated list of steps to compare | all |
| `--modules` | Comma-separated list of module name patterns to compare | all |
| `--verbose` | Include detailed information about each tensor | false |
### Example
```bash
# Compare all tensors from two different runs
python tools/compare_intermediate.py --dir1 /tmp/vllm_intermediates/run1 --dir2 /tmp/vllm_intermediates/run2
# Compare only specific modules and steps with custom tolerance
python tools/compare_intermediate.py \
--dir1 /tmp/vllm_intermediates/run1 \
--dir2 /tmp/vllm_intermediates/run2 \
--steps 0,1 \
--modules ".*attention.*,.*mlp.*" \
--rtol 1e-4 \
--atol 1e-7 \
--output comparison_report.md
```
### Output
The tool generates a detailed markdown report that includes:
- Overall summary of matching and mismatched tensors
- Per-module comparison results
- Detailed tensor differences (when using `--verbose`)
This makes it easy to identify which specific tensors differ between runs and by how much.
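The `--rtol`/`--atol` options feed into `torch.testing.assert_close`, which treats two elements as close when `|actual - expected| <= atol + rtol * |expected|`. A minimal pure-Python sketch of that formula shows why loosening `rtol` can turn a mismatch into a match:

```python
def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    """Elementwise closeness check: |a - e| <= atol + rtol * |e|."""
    return all(
        abs(a - e) <= atol + rtol * abs(e)
        for a, e in zip(actual, expected)
    )


run1 = [1.0, 0.5, -2.0]
run2 = [1.0001, 0.5, -2.0]  # small drift in the first element

print(allclose(run1, run2))              # strict default tolerances: False
print(allclose(run1, run2, rtol=1e-3))   # looser rtol: True
```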


@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the intermediate tensor logging functionality.
"""
import json
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
import pytest
import torch
import torch.nn as nn
from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (get_current_il_config,
get_step, increment_step,
intermediate_logging,
register_intermediate_hooks,
reset_step,
should_log_device,
should_log_module,
should_log_step)
class SimpleModel(nn.Module):
"""A simple model for testing."""
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
@pytest.fixture
def temp_output_dir():
"""Create a temporary directory for test outputs."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up after the test
shutil.rmtree(temp_dir)
@pytest.fixture
def simple_model():
"""Create a simple model for testing."""
return SimpleModel()
@pytest.fixture
def il_config(temp_output_dir):
"""Create a basic IntermediateLoggingConfig for testing."""
return IntermediateLoggingConfig(output_dir=temp_output_dir,
enabled=True,
log_step_ids=[0, 1],
module_call_match=[".*linear.*"])
def test_step_counter():
"""Test the step counter functionality."""
# Reset the step counter
reset_step()
assert get_step() == 0
# Increment the step counter
increment_step()
assert get_step() == 1
# Increment again
increment_step()
assert get_step() == 2
# Reset again
reset_step()
assert get_step() == 0
def test_intermediate_logging_context_manager():
"""Test the intermediate_logging context manager."""
# Create a config
config = IntermediateLoggingConfig(enabled=True)
# Initially, there should be no global config
assert get_current_il_config() is None
# Use the context manager
with intermediate_logging(config):
# Inside the context, the global config should be set
assert get_current_il_config() is not None
assert get_current_il_config().enabled is True
# After the context, the global config should be None again
assert get_current_il_config() is None
# Test with a different config
config2 = IntermediateLoggingConfig(enabled=False)
with intermediate_logging(config2):
assert get_current_il_config() is not None
assert get_current_il_config().enabled is False
def test_should_log_step():
"""Test the should_log_step function."""
# Reset step counter
reset_step()
# Create configs with different step IDs
config_all_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[] # Empty list means log all steps
)
config_specific_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4
)
config_disabled = IntermediateLoggingConfig(enabled=False,
log_step_ids=[0, 1, 2])
# Test with all steps config
with intermediate_logging(config_all_steps):
assert should_log_step(config_all_steps) is True # Step 0
increment_step()
assert should_log_step(config_all_steps) is True # Step 1
# Reset step counter
reset_step()
# Test with specific steps config
with intermediate_logging(config_specific_steps):
assert should_log_step(config_specific_steps) is True # Step 0
increment_step()
assert should_log_step(config_specific_steps) is False # Step 1
increment_step()
assert should_log_step(config_specific_steps) is True # Step 2
increment_step()
assert should_log_step(config_specific_steps) is False # Step 3
increment_step()
assert should_log_step(config_specific_steps) is True # Step 4
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_step(config_disabled) is False # Disabled
def test_should_log_device():
"""Test the should_log_device function."""
# Create configs with different device filters
config_all_devices = IntermediateLoggingConfig(
enabled=True,
device_names=[] # Empty list means log all devices
)
config_specific_devices = IntermediateLoggingConfig(
enabled=True,
device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu
)
config_disabled = IntermediateLoggingConfig(enabled=False,
device_names=["cuda:0", "cpu"])
# Test with all devices config
with intermediate_logging(config_all_devices):
assert should_log_device(config_all_devices, "cuda:0") is True
assert should_log_device(config_all_devices, "cuda:1") is True
assert should_log_device(config_all_devices, "cpu") is True
# Test with specific devices config
with intermediate_logging(config_specific_devices):
assert should_log_device(config_specific_devices, "cuda:0") is True
assert should_log_device(config_specific_devices, "cuda:1") is False
assert should_log_device(config_specific_devices, "cpu") is True
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_device(config_disabled, "cuda:0") is False
assert should_log_device(config_disabled, "cpu") is False
def test_should_log_module(simple_model):
"""Test the should_log_module function."""
# Create configs with different module name filters
config_all_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=None # None means log all modules
)
config_specific_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=[".*linear.*"
] # Only log modules with "linear" in the name
)
config_disabled = IntermediateLoggingConfig(enabled=False,
module_call_match=[".*"])
# Test with all modules config
with intermediate_logging(config_all_modules):
assert should_log_module(config_all_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_all_modules, "relu",
simple_model.relu) is True
# Test with specific modules config
with intermediate_logging(config_specific_modules):
assert should_log_module(config_specific_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_specific_modules, "relu",
simple_model.relu) is False
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_module(config_disabled, "linear1",
simple_model.linear1) is False
assert should_log_module(config_disabled, "relu",
simple_model.relu) is False
def test_register_hooks(simple_model, il_config):
"""Test registering hooks on a model."""
# Register hooks
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Check that hooks were registered
assert len(logger_instance.hooks) > 0
# Remove hooks
logger_instance.remove_hooks()
# Check that hooks were removed
assert len(logger_instance.hooks) == 0
@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir):
"""Test that forward hooks are called during model execution."""
mock_save_tensors.return_value = None
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that the step counter was incremented
assert get_step() == 1
# Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called
assert mock_save_tensors.called
# Remove hooks
logger_instance.remove_hooks()
def test_end_to_end(simple_model, il_config, temp_output_dir):
"""Test the entire intermediate logging workflow end-to-end."""
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that output directories were created
root_dir = Path(il_config._output_run_dir)
assert root_dir.exists()
step_dir = root_dir / "step_0"
assert step_dir.exists()
module_dirs = list(step_dir.glob("*"))
print(f"{module_dirs=}")
assert len(module_dirs) > 0
# Check that input and output files were created
for module_dir in module_dirs:
print(f"{module_dir=}")
if os.path.isdir(module_dir):
inputs_json = module_dir / "inputs.json"
outputs_json = module_dir / "outputs.json"
# Check that JSON files exist
assert inputs_json.exists()
assert outputs_json.exists()
# Check that JSON files contain valid data
with open(inputs_json) as f:
inputs_data = json.load(f)
assert "type" in inputs_data
with open(outputs_json) as f:
outputs_data = json.load(f)
assert "type" in outputs_data
# Check that tensor files exist
tensor_files = list(module_dir.glob("*.pt"))
assert len(tensor_files) > 0
# Remove hooks
logger_instance.remove_hooks()
if __name__ == "__main__":
pytest.main(["-xvs", __file__])

tools/compare_intermediate.py (executable file, 706 lines)

@@ -0,0 +1,706 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Script to compare intermediate logging outputs from two different runs.
This script compares the tensor outputs from two different intermediate logging
directories and generates a report of the differences.
Usage:
python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]
Options:
--dir1 DIR First intermediate logging directory
--dir2 DIR Second intermediate logging directory
--output FILE Output file for the report (default: stdout)
--format {md,json} Output format (default: md)
--rtol FLOAT Relative tolerance for tensor comparison (default: 1e-5)
--atol FLOAT Absolute tolerance for tensor comparison (default: 1e-8)
--steps STEPS Comma-separated list of steps to compare (default: all)
--modules MODULES Comma-separated list of module name patterns to compare (default: all)
--verbose Include detailed information about each tensor
"""
import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
def load_tensor(path: Path) -> Optional[torch.Tensor]:
    """Load a tensor from a .pt file, returning None on failure."""
    try:
        return torch.load(path, map_location="cpu")
    except Exception as e:
        print(f"Error loading tensor from {path}: {e}")
        return None
def load_json(path: Path) -> Dict:
"""Load a JSON file."""
try:
with open(path, "r") as f:
return json.load(f)
except Exception as e:
print(f"Error loading JSON from {path}: {e}")
return {}
def extract_diff_metatada(exception_str: str) -> Dict:
try:
num_diff_elements = int(
re.search(r"Mismatched elements: (\d+) /", exception_str).group(1)
)
total_elements = int(
re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1)
)
    max_abs_diff = float(
        re.search(
            r"Greatest absolute difference: ([\d.e+-]+)", exception_str
        ).group(1)
    )
    max_rel_diff = float(
        re.search(
            r"Greatest relative difference: ([\d.e+-]+)", exception_str
        ).group(1)
    )
return {
"num_diff_elements": num_diff_elements,
"total_elements": total_elements,
"max_abs_diff": max_abs_diff,
"max_rel_diff": max_rel_diff,
}
except Exception:
return {"error": exception_str}
def compare_tensors(
tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float
) -> Dict:
"""Compare two tensors and return a dictionary with comparison results."""
if tensor1 is None or tensor2 is None:
return {"match": False, "error": "One or both tensors are None"}
if tensor1.shape != tensor2.shape:
return {
"match": False,
"error": f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}",
}
if tensor1.dtype != tensor2.dtype:
return {
"match": False,
"error": f"Dtype mismatch: {tensor1.dtype} vs {tensor2.dtype}",
}
# Check if tensors are close using PyTorch's assert_close
try:
torch.testing.assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
except Exception as e:
return {"match": False, **extract_diff_metatada(str(e))}
return {"match": True}
def compare_json_values(value1: Any, value2: Any) -> Dict:
"""Compare two JSON values and return a dictionary with comparison results."""
if type(value1) is not type(value2):
return {
"match": False,
"error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}",
}
if isinstance(value1, dict):
# Compare dictionaries
all_keys = set(value1.keys()) | set(value2.keys())
mismatches = {}
for key in all_keys:
if key not in value1:
mismatches[key] = {"error": "Missing in first dict"}
elif key not in value2:
mismatches[key] = {"error": "Missing in second dict"}
else:
comparison = compare_json_values(value1[key], value2[key])
if not comparison["match"]:
mismatches[key] = comparison
if mismatches:
return {"match": False, "mismatches": mismatches}
return {"match": True}
elif isinstance(value1, list):
# Compare lists
if len(value1) != len(value2):
return {
"match": False,
"error": f"Length mismatch: {len(value1)} vs {len(value2)}",
}
mismatches = {}
for i, (item1, item2) in enumerate(zip(value1, value2)):
comparison = compare_json_values(item1, item2)
if not comparison["match"]:
mismatches[i] = comparison
if mismatches:
return {"match": False, "mismatches": mismatches}
return {"match": True}
else:
# Compare primitive values
if value1 == value2:
return {"match": True}
else:
return {"match": False, "value1": value1, "value2": value2}
def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
"""
Find all tensor files in the given directory.
Returns a dictionary with the structure:
{
"step_0": {
"module_name_123456": {
"inputs": [Path("inputs_0_cuda_0.pt"), ...],
"outputs": [Path("output_cuda_0.pt"), ...]
},
...
},
...
}
"""
result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
# Find all step directories
step_dirs = [d for d in directory.glob("step_*") if d.is_dir()]
for step_dir in step_dirs:
step_name = step_dir.name
# Find all module directories
module_dirs = [d for d in step_dir.glob("*") if d.is_dir()]
for module_dir in module_dirs:
module_name = module_dir.name
# Find input tensor files
input_tensors = list(module_dir.glob("inputs_*.pt"))
if input_tensors:
result[step_name][module_name]["inputs"] = input_tensors
# Find output tensor files
output_tensors = list(module_dir.glob("output*.pt"))
if output_tensors:
result[step_name][module_name]["outputs"] = output_tensors
# Find JSON metadata files
inputs_json = module_dir / "inputs.json"
if inputs_json.exists():
result[step_name][module_name]["inputs_json"] = [inputs_json]
outputs_json = module_dir / "outputs.json"
if outputs_json.exists():
result[step_name][module_name]["outputs_json"] = [outputs_json]
return result
def filter_steps_and_modules(
tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]],
steps: Optional[List[str]] = None,
module_patterns: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
"""Filter tensor files by steps and module patterns."""
result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
# Filter steps
if steps:
step_names = [f"step_{step}" for step in steps]
steps_to_include = {step: True for step in step_names}
else:
steps_to_include = {step: True for step in tensor_files.keys()}
# Compile module patterns
if module_patterns:
compiled_patterns = [re.compile(pattern) for pattern in module_patterns]
else:
compiled_patterns = None
for step_name, modules in tensor_files.items():
if step_name not in steps_to_include:
continue
for module_name, file_types in modules.items():
# Check if module matches any pattern
if compiled_patterns:
if not any(
pattern.search(module_name) for pattern in compiled_patterns
):
continue
result[step_name][module_name] = file_types
return result
def compare_directories(
dir1: Path,
dir2: Path,
rtol: Optional[float] = None,
atol: Optional[float] = None,
steps: Optional[List[str]] = None,
module_patterns: Optional[List[str]] = None,
) -> Dict:
"""Compare two intermediate logging directories and return a report."""
# Find tensor files in both directories
tensor_files1 = find_tensor_files(dir1)
tensor_files2 = find_tensor_files(dir2)
# Filter by steps and modules
if steps or module_patterns:
tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns)
tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns)
# Get all steps and modules from both directories
all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys())
report = {
"dir1": str(dir1),
"dir2": str(dir2),
"rtol": rtol,
"atol": atol,
"steps": {},
}
# Compare each step
for step in sorted(all_steps):
step_report = {
"modules": {},
"summary": {
"total_modules": 0,
"matching_modules": 0,
"mismatched_modules": 0,
"missing_modules": 0,
},
}
# Get all modules from both directories for this step
modules1 = tensor_files1.get(step, {})
modules2 = tensor_files2.get(step, {})
        # Read module_calls.txt (if present) to get the full module list
        # in forward-call order.
        dir1_module_call_file = dir1 / step / "module_calls.txt"
        if dir1_module_call_file.exists():
            with open(dir1_module_call_file, "r") as f:
                all_modules = f.read().splitlines()
        else:
            print(
                "Warning: module call order file is missing; "
                "ordering modules alphabetically"
            )
all_modules = sorted(set(modules1.keys()) | set(modules2.keys()))
step_report["module_call_list"] = []
for module in all_modules:
module_report = {
"inputs": {},
"outputs": {},
"summary": {
"total_tensors": 0,
"matching_tensors": 0,
"mismatched_tensors": 0,
"missing_tensors": 0,
},
}
# Check if module exists in both directories
if module not in modules1:
module_report["error"] = f"Module missing in {dir1}"
step_report["summary"]["missing_modules"] += 1
step_report["modules"][module] = module_report
continue
if module not in modules2:
module_report["error"] = f"Module missing in {dir2}"
step_report["summary"]["missing_modules"] += 1
step_report["modules"][module] = module_report
continue
# Compare JSON metadata
for json_type in ["inputs_json", "outputs_json"]:
json_files1 = modules1[module].get(json_type, [])
json_files2 = modules2[module].get(json_type, [])
if json_files1 and json_files2:
json1 = load_json(json_files1[0])
json2 = load_json(json_files2[0])
json_comparison = compare_json_values(json1, json2)
json_name = json_type.replace("_json", "")
module_report[f"{json_name}_metadata"] = json_comparison
# Add file paths for manual checking when there's a mismatch
if not json_comparison.get("match", True):
module_report[f"{json_name}_metadata"]["file1"] = str(
json_files1[0]
)
module_report[f"{json_name}_metadata"]["file2"] = str(
json_files2[0]
)
# Compare input tensors
input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])}
input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])}
all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys())
for tensor_name in sorted(all_input_names):
if tensor_name not in input_tensors1:
module_report["inputs"][tensor_name] = {
"match": False,
"error": f"Tensor missing in {dir1}",
}
module_report["summary"]["missing_tensors"] += 1
elif tensor_name not in input_tensors2:
module_report["inputs"][tensor_name] = {
"match": False,
"error": f"Tensor missing in {dir2}",
}
module_report["summary"]["missing_tensors"] += 1
else:
tensor1 = load_tensor(input_tensors1[tensor_name])
tensor2 = load_tensor(input_tensors2[tensor_name])
comparison = compare_tensors(tensor1, tensor2, rtol, atol)
# Add file paths for manual checking when there's a mismatch
if not comparison.get("match", False):
comparison["file1"] = str(input_tensors1[tensor_name])
comparison["file2"] = str(input_tensors2[tensor_name])
module_report["inputs"][tensor_name] = comparison
if comparison.get("match", False):
module_report["summary"]["matching_tensors"] += 1
else:
module_report["summary"]["mismatched_tensors"] += 1
module_report["summary"]["total_tensors"] += 1
# Compare output tensors
output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])}
output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])}
all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys())
for tensor_name in sorted(all_output_names):
if tensor_name not in output_tensors1:
module_report["outputs"][tensor_name] = {
"match": False,
"error": f"Tensor missing in {dir1}",
}
module_report["summary"]["missing_tensors"] += 1
elif tensor_name not in output_tensors2:
module_report["outputs"][tensor_name] = {
"match": False,
"error": f"Tensor missing in {dir2}",
}
module_report["summary"]["missing_tensors"] += 1
else:
tensor1 = load_tensor(output_tensors1[tensor_name])
tensor2 = load_tensor(output_tensors2[tensor_name])
comparison = compare_tensors(tensor1, tensor2, rtol, atol)
# Add file paths for manual checking when there's a mismatch
if not comparison.get("match", False):
comparison["file1"] = str(output_tensors1[tensor_name])
comparison["file2"] = str(output_tensors2[tensor_name])
module_report["outputs"][tensor_name] = comparison
if comparison.get("match", False):
module_report["summary"]["matching_tensors"] += 1
else:
module_report["summary"]["mismatched_tensors"] += 1
module_report["summary"]["total_tensors"] += 1
# Update module status
if module_report["summary"]["mismatched_tensors"] > 0:
step_report["summary"]["mismatched_modules"] += 1
else:
step_report["summary"]["matching_modules"] += 1
step_report["summary"]["total_modules"] += 1
step_report["modules"][module] = module_report
step_report["module_call_list"].append(module)
report["steps"][step] = step_report
# Add overall summary
report["summary"] = {
"total_steps": len(all_steps),
"total_modules": sum(
step_report["summary"]["total_modules"]
for step_report in report["steps"].values()
),
"matching_modules": sum(
step_report["summary"]["matching_modules"]
for step_report in report["steps"].values()
),
"mismatched_modules": sum(
step_report["summary"]["mismatched_modules"]
for step_report in report["steps"].values()
),
"missing_modules": sum(
step_report["summary"]["missing_modules"]
for step_report in report["steps"].values()
),
"total_tensors": sum(
module_report["summary"]["total_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
"matching_tensors": sum(
module_report["summary"]["matching_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
"mismatched_tensors": sum(
module_report["summary"]["mismatched_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
"missing_tensors": sum(
module_report["summary"]["missing_tensors"]
for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items()
if "summary" in module_report
),
}
return report
def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
"""Generate a markdown report from the comparison results."""
lines = []
# Add header
lines.append("# Intermediate Logging Comparison Report")
lines.append("")
lines.append("Comparing intermediate logging outputs between:")
lines.append(f"- **Directory 1**: `{report['dir1']}`")
lines.append(f"- **Directory 2**: `{report['dir2']}`")
lines.append("")
lines.append("Comparison parameters:")
lines.append(f"- Relative tolerance (rtol): {report['rtol']}")
lines.append(f"- Absolute tolerance (atol): {report['atol']}")
lines.append("")
# Add overall summary
summary = report["summary"]
lines.append("## Overall Summary")
lines.append("")
lines.append("| Category | Total | Matching | Mismatched | Missing |")
lines.append("|----------|-------|----------|------------|---------|")
lines.append(f"| Steps | {summary['total_steps']} | - | - | - |")
lines.append(
f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |"
)
lines.append(
f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |"
)
lines.append("")
# Add step details
for step_name, step_report in sorted(report["steps"].items()):
step_summary = step_report["summary"]
lines.append(f"## {step_name}")
lines.append("")
lines.append(
f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules"
)
lines.append("")
# Add module details
for module_name in step_report["module_call_list"]:
module_report = step_report["modules"][module_name]
if "error" in module_report:
lines.append(f"### ❌ {module_name}")
lines.append("")
lines.append(f"**Error**: {module_report['error']}")
lines.append("")
continue
module_summary = module_report["summary"]
# Determine module status
            if module_summary["mismatched_tensors"] > 0:
                status = "❌"
            else:
                status = "✅"
lines.append(f"### {status} {module_name}")
lines.append("")
lines.append(
f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors"
)
lines.append("")
# Add metadata comparison results if available
for metadata_type in ["inputs_metadata", "outputs_metadata"]:
if metadata_type in module_report:
metadata_comparison = module_report[metadata_type]
if not metadata_comparison.get("match", True):
file_paths = ""
if (
"file1" in metadata_comparison
and "file2" in metadata_comparison
):
file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`"
lines.append(
f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}"
)
if verbose and "mismatches" in metadata_comparison:
lines.append("```json")
lines.append(
json.dumps(metadata_comparison["mismatches"], indent=2)
)
lines.append("```")
lines.append("")
# Add tensor comparison details
if module_summary["mismatched_tensors"] > 0 or verbose:
# Add input tensor details
if module_report["inputs"]:
lines.append("#### Input Tensors")
lines.append("")
lines.append("| Tensor | Status | Details |")
lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted(
module_report["inputs"].items()
):
                    if comparison.get("match", False):
                        status = "✅"
                        details = "Tensors match"
                    elif "error" in comparison:
                        status = "❌"
                        details = comparison["error"]
                    else:
                        status = "❌"
                        details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
                        details += f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
                        details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
if "file1" in comparison and "file2" in comparison:
details += f"<br>Files: `{comparison['file1']}` vs `{comparison['file2']}`"
lines.append(f"| {tensor_name} | {status} | {details} |")
lines.append("")
# Add output tensor details
if module_report["outputs"]:
lines.append("#### Output Tensors")
lines.append("")
lines.append("| Tensor | Status | Details |")
lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted(
module_report["outputs"].items()
):
if comparison.get("match", False):
status = "✅"
details = "Tensors match"
elif "error" in comparison:
status = "⚠️"
details = comparison["error"]
else:
status = "❌"
details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
details += f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
lines.append(f"| {tensor_name} | {status} | {details} |")
lines.append("")
return "\n".join(lines)
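The per-tensor numbers reported in the tables above (max absolute diff, max relative diff, count of differing elements) follow the usual allclose-style criterion `|a - b| <= atol + rtol * |b|`. A minimal pure-Python sketch of that comparison (the function name and flat-list interface here are illustrative, not the script's internals):

```python
def compare_flat(a, b, rtol=1e-5, atol=1e-8):
    """Allclose-style comparison of two equal-length lists of floats."""
    abs_diffs = [abs(x - y) for x, y in zip(a, b)]
    # Relative diff is undefined against a zero reference; report inf
    # when the values actually differ, 0.0 when both are zero.
    rel_diffs = [d / abs(y) if y != 0 else (float("inf") if d > 0 else 0.0)
                 for d, y in zip(abs_diffs, b)]
    mismatched = [d > atol + rtol * abs(y) for d, y in zip(abs_diffs, b)]
    return {
        "match": not any(mismatched),
        "max_abs_diff": max(abs_diffs),
        "max_rel_diff": max(rel_diffs),
        "num_diff_elements": sum(mismatched),
        "total_elements": len(a),
    }

result = compare_flat([1.0, 2.0, 3.0], [1.0, 2.0, 3.5])
```

With the inputs above, only the last element exceeds the tolerance, so `num_diff_elements` is 1 and `match` is `False`.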
def main():
parser = argparse.ArgumentParser(
description="Compare intermediate logging outputs from two different runs."
)
parser.add_argument(
"--dir1", required=True, help="First intermediate logging directory"
)
parser.add_argument(
"--dir2", required=True, help="Second intermediate logging directory"
)
parser.add_argument("--output", help="Output file for the report (default: stdout)")
parser.add_argument(
"--rtol",
type=float,
default=None,
help="Relative tolerance for tensor comparison (default: 1e-5)",
)
parser.add_argument(
"--atol",
type=float,
default=None,
help="Absolute tolerance for tensor comparison (default: 1e-8)",
)
parser.add_argument(
"--steps", help="Comma-separated list of steps to compare (default: all)"
)
parser.add_argument(
"--modules",
help="Comma-separated list of module name patterns to compare (default: all)",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Include detailed information about each tensor",
)
args = parser.parse_args()
# Parse steps and modules
steps = args.steps.split(",") if args.steps else None
module_patterns = args.modules.split(",") if args.modules else None
# Compare directories
report = compare_directories(
Path(args.dir1),
Path(args.dir2),
rtol=args.rtol,
atol=args.atol,
steps=steps,
module_patterns=module_patterns,
)
# Generate report
output = generate_markdown_report(report, verbose=args.verbose)
# Write report
if args.output:
with open(args.output, "w") as f:
f.write(output)
print(f"Report written to {args.output}")
else:
print(output)
if __name__ == "__main__":
main()
def invoke_main() -> None:
main()

View File

@ -17,7 +17,8 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
from functools import cached_property
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args)
Protocol, TypeVar, Union, cast, get_args, List, Set)
from re import Pattern
import regex as re
import torch
@ -4024,6 +4025,122 @@ class KVEventsConfig:
"""
@config
@dataclass
class IntermediateLoggingConfig:
"""Configuration for intermediate tensor logging."""
output_dir: str = "/tmp/vllm_intermediates"
"""Directory where to save the intermediate tensors."""
reload_input_dir: Optional[str] = None
"""Directory where to load the inputs for the steps/modules.
This is used when we want to check per module numerical gaps instead
of accumulated gap to further dive into the actual numerical issues."""
module_call_match: Optional[List[str]] = None
"""Match modules by name regex and call index (
a module can be called multiple times in a step)
List of regex:call_idx, call_idx is -1 for default for all calls """
log_step_ids: List[int] = field(default_factory=lambda: [0])
"""List of step IDs to log (empty list means log all steps)."""
log_post_fwd_inputs: bool = False
"""Whether logging inputs after forwards for each module"""
max_tensor_size: Optional[int] = None
"""Maximum number of elements in tensors to log (None = no limit)."""
enabled: bool = True
"""Whether logging is enabled."""
device_names: List[str] = field(default_factory=list)
"""List of device names to log (empty list means log all devices)."""
_compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False)
"""Compiled regex patterns for module filtering."""
_module_call: dict[str, int] = field(default_factory=dict, init=False)
_step_id_set: Set[int] = field(default_factory=set, init=False)
"""Set of step IDs for faster lookup."""
_output_run_dir: str = "/tmp/vllm_intermediates"
"""Unique directory to save single run/serve logging result."""
def __post_init__(self):
"""Initialize derived fields after instance creation."""
self._compile_regex_patterns()
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
self._step_id_set = set(self.log_step_ids)
def _compile_regex_patterns(self):
"""Compile regex patterns for module name filtering."""
from vllm.logger import init_logger
logger = init_logger(__name__)
self._compiled_module_calls = {}
if self.module_call_match is None:
logger.info("No module name regex patterns provided, will log all modules")
return
# Compile all patterns
for regex_pattern_call_idx in self.module_call_match:
try:
splits = regex_pattern_call_idx.split(":", 2)
regex_pattern = splits[0]
call_idx = -1
if len(splits) > 1:
call_idx = int(splits[1])
compiled_pattern: Pattern[str] = re.compile(regex_pattern)
self._compiled_module_calls[compiled_pattern] = call_idx
logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'")
except Exception as e:
logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}")
raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e
logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns")
def to_dict(self) -> dict:
"""Convert the config to a dictionary for serialization."""
return {
"output_dir": self.output_dir,
"reload_input_dir": self.reload_input_dir,
"module_call_match": self.module_call_match,
"log_step_ids": self.log_step_ids,
"log_post_fwd_inputs": self.log_post_fwd_inputs,
"max_tensor_size": self.max_tensor_size,
"enabled": self.enabled,
"device_names": self.device_names
}
@classmethod
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
@property
def output_run_dir(self) -> str:
return self._output_run_dir
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# Intermediate logging doesn't affect the computation graph
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
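The `regex:call_idx` parsing performed by `_compile_regex_patterns` can be sketched standalone (a simplified sketch: the real code uses the `regex` package, vLLM's logger, and `split(":", 2)`; the function name below is illustrative):

```python
import re

def parse_module_call_match(entries):
    """Parse 'regex' or 'regex:call_idx' strings into {pattern: call_idx}."""
    compiled = {}
    for entry in entries:
        pattern_str, _, idx_str = entry.partition(":")
        # call_idx of -1 means "log every call of the matched module"
        call_idx = int(idx_str) if idx_str else -1
        compiled[re.compile(pattern_str)] = call_idx
    return compiled

patterns = parse_module_call_match([r"layers\.0\.", r"self_attn:1"])
```

Here `layers\.0\.` gets the default call index -1 (all calls), while `self_attn:1` restricts logging to the second forward call of any matching module within a step.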
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
@ -4480,6 +4597,8 @@ class VllmConfig:
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
"""The configurations for event publishing."""
intermediate_log_config: Optional[IntermediateLoggingConfig] = None
"""Configuration for intermediate tensor logging."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@ -4564,6 +4683,10 @@ class VllmConfig:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.intermediate_log_config:
vllm_factors.append(self.intermediate_log_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(

View File

@ -26,7 +26,8 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DecodingConfig, DetailedTraceModules, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVEventsConfig, KVTransferConfig,
HfOverrides, IntermediateLoggingConfig,
KVEventsConfig, KVTransferConfig,
LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
@ -399,6 +400,7 @@ class EngineArgs:
str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None
show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
@ -444,6 +446,9 @@ class EngineArgs:
async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
intermediate_log_config_path: Optional[str] = None
intermediate_log_config: Optional[dict[str, Any]] = None
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@ -758,6 +763,20 @@ class EngineArgs:
help="The configurations for speculative decoding. Should be a "
"JSON string.")
intermediate_log_group = parser.add_argument_group(
title="IntermediateLoggingConfig",
description=IntermediateLoggingConfig.__doc__,
)
intermediate_log_group.add_argument(
"--intermediate-log-config",
type=json.loads,
default=None,
help="The configurations for intermediate loggings. Should be a "
"JSON string.")
intermediate_log_group.add_argument("--intermediate-log-config-path", type=str,
help="The path to the configurations for intermediate loggings. Should be a string.")
# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
observability_group = parser.add_argument_group(
@ -846,6 +865,9 @@ class EngineArgs:
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--disable-log-stats',
action='store_true',
@ -957,6 +979,21 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location,
)
def create_intermediate_log_config(
self,
) -> Optional[IntermediateLoggingConfig]:
"""Initializes and returns an IntermediateLoggingConfig object based on
`intermediate_log_config` or `intermediate_log_config_path`.
"""
if self.intermediate_log_config is not None:
return IntermediateLoggingConfig.from_dict(
self.intermediate_log_config)
if self.intermediate_log_config_path is not None:
with open(self.intermediate_log_config_path, "r") as f:
return IntermediateLoggingConfig.from_dict(json.load(f))
return None
def create_speculative_config(
self,
@ -1198,6 +1235,9 @@ class EngineArgs:
disable_log_stats=self.disable_log_stats,
)
intermediate_log_config = self.create_intermediate_log_config(
)
# Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid
if self.num_scheduler_steps > 1:
@ -1284,7 +1324,6 @@ class EngineArgs:
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
)
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@ -1299,6 +1338,7 @@ class EngineArgs:
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
intermediate_log_config=intermediate_log_config,
additional_config=self.additional_config,
)

View File

@ -77,6 +77,10 @@ class EngineCore:
# Setup Model.
self.model_executor = executor_class(vllm_config)
if vllm_config.intermediate_log_config is not None:
self.collective_rpc("register_intermediate_hooks",
args=(vllm_config.intermediate_log_config, ))
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(
executor_fail_callback)

View File

View File

@ -0,0 +1,599 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.
This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""
import json
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
import torch
from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig
# Import logger from vllm
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global step counter
_CURRENT_STEP = 0
_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}
IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"
# Utility functions for intermediate logging
def should_log_step(config):
"""Check if the current step should be logged based on the step IDs.
Args:
config: The IntermediateLoggingConfig instance.
Returns:
True if the current step should be logged, False otherwise.
"""
if not is_log_enabled(config):
return False
# If log_step_ids is empty, log all steps
if not config.log_step_ids:
return True
# Otherwise, check if current step is in the set of step IDs to log
return get_step() in config._step_id_set
def should_log_device(config, device_name):
"""Check if a device should be logged based on the device names.
Args:
config: The IntermediateLoggingConfig instance.
device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').
Returns:
True if the device should be logged, False otherwise.
If device_names is empty, all devices are logged.
"""
if not is_log_enabled(config):
return False
# If device_names is empty, log all devices
if not config.device_names:
return True
# Otherwise, check if device_name is in the list of device names to log
return device_name in config.device_names
def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
"""Check if a module should be logged based on the name regex patterns.
Args:
config: The IntermediateLoggingConfig instance.
module_name: The name of the module to check.
Returns:
True if the module should be logged, False otherwise.
If no patterns are defined, all modules are logged.
If patterns are defined, the module is logged if it matches ANY pattern.
"""
if not is_log_enabled(config):
return False
# If no patterns are defined, log all modules
if not config._compiled_module_calls:
logger.debug("No patterns defined, will log module: %s", module_name)
set_il_module_name(module, module_name)
set_il_module_call_idx(module, -1)
return True
# Check if the module name matches any of the patterns
for pattern, call_idx in config._compiled_module_calls.items():
match = pattern.search(module_name)
if match:
logger.debug(
"Module %s, %s matches pattern: '%s', call_idx=%s",
module_name,
module.__class__.__name__,
pattern.pattern,
call_idx,
)
set_il_module_name(module, module_name)
set_il_module_call_idx(module, call_idx)
return True
return False
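Module filtering is a plain regex search over the dotted module names produced by `named_modules()`. A small sketch (the module paths below are typical of transformer models and used only for illustration):

```python
import re

def matches_any(module_name, patterns):
    """Return True if any compiled pattern is found in the module name."""
    return any(p.search(module_name) for p in patterns)

pats = [re.compile(r"layers\.0\.")]
# "model.layers.0.self_attn" matches; "model.layers.10.mlp" does not,
# because the trailing escaped dot in the pattern excludes "layers.10".
```

This is why the documentation example uses `"layers\\.0\\."` rather than `"layers.0"`: without the escaped dots, `.` matches any character and `layers.0` would also match `layers10` or `layers.0x` prefixes.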
def is_log_enabled(config):
if not config or not config.enabled:
logger.debug("Not logging because config not enabled")
return False
if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress")
return False
return True
def get_il_module_name(module: torch.nn.Module) -> str:
return getattr(module, IL_MODULE_NAME, module.__class__.__name__)
def get_il_module_call_idx(module: torch.nn.Module) -> int:
return getattr(module, IL_MODULE_CALL_IDX, -1)
def set_il_module_name(module: torch.nn.Module, name: str) -> None:
setattr(module, IL_MODULE_NAME, name)
def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
setattr(module, IL_MODULE_CALL_IDX, idx)
_global_config: Optional[IntermediateLoggingConfig] = None
@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
"""
Temporarily sets the global config for the duration of the context.
:param config: Keyword arguments to set as global config
"""
global _global_config
old_config = _global_config
try:
_global_config = config
yield
finally:
_global_config = old_config
def get_current_il_config():
return _global_config
def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any:
try:
# Convert inputs to JSON-serializable format
intermediates_json = convert_intermediates_to_json(intermediates)
with open(path, "w") as f:
json.dump(intermediates_json, f, indent=2)
logger.debug("Saved all intermediates as JSON to %s", path)
except Exception as e:
logger.warning("Failed to save intermediates as JSON: %s", e)
import traceback
logger.warning(traceback.format_exc())
def convert_intermediates_to_json(tensor: Any) -> Any:
"""Convert intermediates (including tensors) to a JSON-serializable
representation.
Args:
tensor: The intermediates to convert.
Returns:
A JSON-serializable representation of the intermediates.
"""
if isinstance(tensor, torch.Tensor):
try:
result = {
"type": "tensor",
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"numel": tensor.numel(),
}
return result
except Exception as e:
# Handle any errors in tensor conversion
return {
"type": "tensor_error",
"error": str(e),
"tensor_type": str(type(tensor)),
}
elif isinstance(tensor, (list, tuple)):
# For lists/tuples, recursively convert each element
container_type = "list" if isinstance(tensor, list) else "tuple"
# If it's a large list, only include a sample
if len(tensor) > 20:
return {
"type": container_type,
"length": len(tensor),
"sample": [
convert_intermediates_to_json(item) for item in tensor[:20]
],
"note": f"Showing only first 20 of {len(tensor)} items",
}
else:
return {
"type": container_type,
"items": [convert_intermediates_to_json(item) for item in tensor],
}
elif isinstance(tensor, dict):
# For dictionaries, recursively convert each value
if len(tensor) > 20:
# For large dicts, only include keys and a sample of values
keys = list(tensor.keys())
sample_keys = keys[:20]
return {
"type": "dict",
"length": len(tensor),
"keys": keys,
"sample": {
k: convert_intermediates_to_json(tensor[k]) for k in sample_keys
},
"note": f"Showing only first 20 of {len(tensor)} items",
}
else:
return {
"type": "dict",
"items": {
k: convert_intermediates_to_json(v) for k, v in tensor.items()
},
}
elif tensor is None:
return None
elif isinstance(tensor, (int, float, bool, str)):
# Primitive types can be directly serialized
return tensor
else:
# For other types, use string representation
return {"type": str(type(tensor).__name__), "string_repr": str(tensor)}
def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool:
"""Utility function to dump tensor metadata to a file.
Args:
tensor: The tensor to dump.
file_path: Base path where to save the tensor (without extension).
"""
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None:
return False
if (
intermediate_log_config.max_tensor_size is not None
and tensor.numel() > intermediate_log_config.max_tensor_size
):
# Save tensor metadata instead of full tensor
tensor_info = {
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"device": str(tensor.device),
"numel": tensor.numel(),
"skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size "
f"{intermediate_log_config.max_tensor_size}",
}
os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True)
with open(f"{file_path}.json", "w") as f:
json.dump(tensor_info, f, indent=2)
return True
return False
def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any:
if reload_dir is None:
return None
try:
intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None
replace_dir = str(intermediate_log_config.output_run_dir)
reload_path = save_path.replace(replace_dir, reload_dir)
logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path)
return torch.load(reload_path, map_location=tensor.device)
except Exception as e:
logger.warning("Failed to load tensor from %s: %s", reload_dir, e)
return tensor
def save_tensors(
tensor: Any, file_path: str, reload_input_dir: Optional[str] = None
) -> Any:
"""Utility function to dump tensor to a file.
Args:
tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
tensors, or a dictionary containing tensors.
file_path: Base path where to save the tensor (without extension).
"""
# Also save the actual tensor data for tensors
if isinstance(tensor, torch.Tensor):
# Check if tensor is too large
if save_tensors_metadata_if_too_large(tensor, file_path):
return
# Get device name
device_name = str(tensor.device)
# Skip if device filtering is enabled and this device should not be
# logged
intermediate_log_config = get_current_il_config()
if not should_log_device(intermediate_log_config, device_name):
logger.debug(
"Skipping tensor on device %s due to device filter", device_name
)
return tensor
# Append device name to file path
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
try:
# Save tensor directly without detaching or moving to CPU
torch.save(tensor, pt_path)
reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir)
if reloaded_tensor is not None:
return reloaded_tensor
logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
except Exception as e:
logger.warning("Failed to save tensor to %s: %s", pt_path, e)
return tensor
if isinstance(tensor, (list, tuple)):
# For collections, also save each item individually
reloaded_inputs = []
for i, item in enumerate(tensor):
reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir)
reloaded_inputs.append(reloaded)
return tuple(reloaded_inputs) if reloaded_inputs else tensor
if isinstance(tensor, dict):
reloaded_inputs = {}
# For dictionaries, also save each value individually
for k, v in tensor.items():
reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir)
reloaded_inputs[k] = reloaded
return reloaded_inputs if reloaded_inputs else tensor
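The recursion above flattens nested inputs into suffixed file paths (index suffixes for lists/tuples, key suffixes for dicts, plus a device suffix and `.pt` extension for each leaf tensor). A pure-Python sketch of just the path-stem generation, treating any non-container as a leaf (the function name is illustrative):

```python
def tensor_file_paths(obj, base):
    """Yield the file-path stems save_tensors would use for each leaf."""
    if isinstance(obj, (list, tuple)):
        for i, item in enumerate(obj):
            yield from tensor_file_paths(item, f"{base}_{i}")
    elif isinstance(obj, dict):
        for k, v in obj.items():
            yield from tensor_file_paths(v, f"{base}_{k}")
    else:
        yield base

paths = list(tensor_file_paths(({"q": 1, "k": 2}, 3), "inputs"))
# e.g. a dict of tensors inside a tuple yields
# inputs_0_q, inputs_0_k, inputs_1
```

This naming is what makes the comparison script able to pair up files from two runs: identical module structure produces identical stems under each `step_N/module_name` directory.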
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None:
"""Hook to increment the global step counter after a forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
outputs: The outputs from the module's forward function.
"""
if get_current_il_config() is None:
return
# Increment the global step counter
increment_step()
global _CURRENT_STEP_MODULE_CALL_STEP
_CURRENT_STEP_MODULE_CALL_STEP = {}
def _prepare_module_log_dir(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
is_pre_fwd: bool = False,
) -> Path:
# Create a unique directory for this step if it does not exist yet
dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir.mkdir(exist_ok=True, parents=True)
# Create module directory
suffix = ""
module_call_idx = get_current_step_module_call(module_name)
if module_call_idx > 0:
suffix = f"_{module_call_idx}"
module_dir = dump_dir / (module_name + suffix)
if is_pre_fwd:
_log_module_call(intermediate_log_config, module_name + suffix)
module_dir.mkdir(exist_ok=True, parents=True)
logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir)
return module_dir
def _log_module_call(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
) -> None:
logger.debug("Logging module call for %s", module_name)
# write module name and call to step:
file = (
Path(intermediate_log_config.output_run_dir)
/ f"step_{get_step()}"
/ "module_calls.txt"
)
with open(file, "a") as f:
f.write(f"{module_name}\n")
def update_current_step_module_call(module_name: str) -> None:
logger.debug("Updating current step module call for %s", module_name)
global _CURRENT_STEP_MODULE_CALL_STEP
if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
_CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
else:
_CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1
def get_current_step_module_call(module_name: str) -> int:
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
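A module can run several times within one engine step, so call indices are tracked per step and reset by `step_fwd`. Stripped to its essentials, the counting is just a per-name counter (a standalone sketch; the real code keeps the dict at module scope):

```python
call_counts = {}

def record_call(name):
    """Pre-forward: bump and return the per-step call index for a module."""
    call_counts[name] = call_counts.get(name, -1) + 1
    return call_counts[name]

first = record_call("layers.0.self_attn")   # first call this step
second = record_call("layers.0.self_attn")  # same module, second call
call_counts.clear()  # step_fwd resets the counters between steps
```

With a `module_call_match` entry like `"self_attn:1"`, only the call whose index equals 1 (the second call in the step) would be logged.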
def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None or not intermediate_log_config.enabled:
return None
if not should_log_step(intermediate_log_config):
return None
module_name = get_il_module_name(module)
log_call_idx = get_il_module_call_idx(module)
current_call_idx = get_current_step_module_call(module_name)
should_log = True
if log_call_idx >= 0 and current_call_idx != log_call_idx:
should_log = False
log_dir = None
if is_pre_fwd:
update_current_step_module_call(module_name)
if should_log:
log_dir = _prepare_module_log_dir(
intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd
)
return log_dir
def log_pre_fwd_hook(
module: torch.nn.Module, inputs: tuple[Any, ...]
) -> tuple[Any, ...]:
"""Hook to capture module inputs before forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
Returns:
The unchanged inputs.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
dump_intermediates_to_json(inputs, log_dir / "inputs.json")
intermediate_log_config = get_current_il_config()
if intermediate_log_config is not None:
reload_input_dir = intermediate_log_config.reload_input_dir
else:
reload_input_dir = None
reloaded_inputs = save_tensors(
inputs, str(log_dir / "inputs"), reload_input_dir
)
if reloaded_inputs is not None:
return reloaded_inputs
return inputs
def log_post_fwd_hook(
module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any
) -> None:
"""Hook to capture module outputs after forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
outputs: The outputs from the module's forward function.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
dump_intermediates_to_json(outputs, log_dir / "outputs.json")
save_tensors(outputs, str(log_dir / "outputs"))
intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None, "IL config should not be None"
if intermediate_log_config.log_post_fwd_inputs:
dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json")
save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
def get_step() -> int:
"""Get the current global step counter.
Returns:
The current global step counter.
"""
return _CURRENT_STEP
def increment_step() -> int:
"""Increment the global step counter.
Returns:
The new step counter value.
"""
global _CURRENT_STEP
_CURRENT_STEP += 1
return _CURRENT_STEP
def reset_step() -> None:
"""Reset the global step counter to zero."""
global _CURRENT_STEP
_CURRENT_STEP = 0
class IntermediatesLogger:
"""Class to manage logging of intermediate tensors during model
execution."""
def __init__(self, config: IntermediateLoggingConfig):
self.config = config
self.hooks: list[
tuple[str, torch.nn.Module, Optional[RemovableHandle], Optional[RemovableHandle]]
] = []
logger.debug("Created IntermediatesLogger with config: %s", config)
path = Path(config.output_run_dir)
path.mkdir(exist_ok=True, parents=True)
# Log configuration
logger.info("Intermediates will be logged in %s", config.output_run_dir)
def register_hooks(self, model: torch.nn.Module) -> None:
"""Register hooks for the model.
Args:
model: The PyTorch model to register hooks for.
"""
for name, module in model.named_modules():
if name and should_log_module(self.config, name, module):
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
logger.debug(
"Registered pre_fwd hook for %s", module.__class__.__name__
)
post_hook = module.register_forward_hook(log_post_fwd_hook)
logger.debug(
"Registered post_fwd hook for %s", module.__class__.__name__
)
self.hooks.append((name, module, pre_hook, post_hook))
# Register a step counter hook for the root model
step_hook = model.register_forward_hook(step_fwd)
self.hooks.append(("", model, None, step_hook))
logger.info("Registered hooks for %s modules", len(self.hooks))
def remove_hooks(self) -> None:
"""Remove all registered hooks."""
for _, _, pre_hook, post_hook in self.hooks:
if pre_hook is not None:
pre_hook.remove()
if post_hook is not None:
post_hook.remove()
logger.info("Removed %s hooks", len(self.hooks))
self.hooks = []
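The pre/post hook pair follows PyTorch's `register_forward_pre_hook` / `register_forward_hook` contract: a pre-hook may return replacement inputs (or None to keep the originals), while a post-hook only observes. Stripped of torch, the dispatch control flow is roughly (a sketch with illustrative names, not torch internals):

```python
def run_with_hooks(forward, inputs, pre_hook=None, post_hook=None):
    """Mimic torch's hook dispatch around a forward call."""
    if pre_hook is not None:
        new_inputs = pre_hook(inputs)
        if new_inputs is not None:  # returning None keeps original inputs
            inputs = new_inputs
    outputs = forward(*inputs)
    if post_hook is not None:
        post_hook(inputs, outputs)  # observe only; return value ignored
    return outputs

seen = []
result = run_with_hooks(
    lambda x: x * 2, (3,),
    pre_hook=lambda inp: (inp[0] + 1,),      # rewrite inputs before forward
    post_hook=lambda inp, out: seen.append(out),
)
```

This input-rewriting path is exactly what `log_pre_fwd_hook` exploits for `reload_input_dir`: by returning reloaded tensors from the pre-hook, it swaps a prior run's inputs into the current forward pass.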
def register_intermediate_hooks(
model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs
) -> IntermediatesLogger:
"""Register hooks to log intermediate tensors for a model.
Args:
model: The PyTorch model to log intermediates for.
config: Configuration for intermediate logging. If provided, this takes
precedence over kwargs.
Returns:
An IntermediatesLogger instance that can be used to manage the hooks.
"""
if config is None:
# Create config from kwargs
config = IntermediateLoggingConfig.from_dict(kwargs)
logger_instance = IntermediatesLogger(config)
logger_instance.register_hooks(model)
return logger_instance

View File

@ -32,6 +32,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
logger = init_logger(__name__)
@ -344,8 +345,9 @@ class Worker(WorkerBase):
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
with intermediate_logging(self.vllm_config.intermediate_log_config):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \

View File

@ -6,9 +6,10 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config import VllmConfig, IntermediateLoggingConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__)
@ -63,3 +64,27 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
def register_intermediate_hooks(self,
config: Optional[IntermediateLoggingConfig] = None,
**kwargs) -> None:
"""Register hooks for intermediate tensor logging.
This method is called via collective_rpc from the engine core.
It registers hooks on the model to dump intermediate tensors during execution.
Args:
config: Configuration for intermediate logging. If provided, this takes precedence over kwargs.
"""
if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None:
logger.error("Could not register intermediate hooks: model_runner.model is not accessible")
return
model = self.model_runner.model
try:
# Register hooks and store the logger instance for
# potential later hook removal
self._intermediates_logger = register_intermediate_hooks(
model, config, **kwargs)
logger.info("Successfully registered intermediate hooks")
except Exception:
logger.error("Error registering intermediate hooks", exc_info=True)

View File

@ -128,6 +128,22 @@ class WorkerBase:
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def register_intermediate_hooks(self, config=None, **kwargs) -> None:
"""Register hooks for intermediate tensor logging.
This method is a stub for v0 workers. The actual implementation is in v1 workers.
It's included here for compatibility with the collective_rpc mechanism.
Args:
config: Configuration for intermediate logging.
**kwargs: Configuration parameters for intermediate logging.
These are ignored in v0 workers.
"""
logger.warning(
"register_intermediate_hooks is not implemented in v0 workers. "
"This is only available in v1 workers. No hooks will be registered.")
return None
class DelegateWorkerBase(WorkerBase):