remove feature for metadata dump and input reload

Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang 2025-07-28 19:03:26 -07:00
parent d8bff253d7
commit 2af83ebdde
9 changed files with 278 additions and 544 deletions

View File

@ -49,7 +49,6 @@ The configuration file should be a JSON file with the following structure:
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json
``` ```
#### Configuration Parameters #### Configuration Parameters
| Parameter | Type | Description | Default | | Parameter | Type | Description | Default |

View File

@ -5,9 +5,8 @@ Tests for the intermediate tensor logging functionality.
""" """
import json import json
from os.path import isdir
import shutil
import os import os
import shutil
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
@ -17,14 +16,10 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import IntermediateLoggingConfig from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (get_current_il_config, from vllm.v1.intermediates.intermediates_logging import (
get_step, increment_step, get_current_il_config, get_step, increment_step, intermediate_logging,
intermediate_logging, register_intermediate_hooks, reset_step, should_log_device,
register_intermediate_hooks, should_log_module, should_log_step)
reset_step,
should_log_device,
should_log_module,
should_log_step)
class SimpleModel(nn.Module): class SimpleModel(nn.Module):
@ -237,7 +232,8 @@ def test_register_hooks(simple_model, il_config):
assert len(logger_instance.hooks) == 0 assert len(logger_instance.hooks) == 0
@mock.patch('vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json') @mock.patch(
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors') @mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model, def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir): il_config, temp_output_dir):
@ -262,7 +258,6 @@ def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
# Check that dump_intermediates_to_json and save_tensors were called # Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called assert mock_dump_json.called
assert mock_save_tensors.called assert mock_save_tensors.called
# Remove hooks # Remove hooks
logger_instance.remove_hooks() logger_instance.remove_hooks()

View File

@ -7,27 +7,30 @@ This script compares the tensor outputs from two different intermediate logging
directories and generates a report of the differences. directories and generates a report of the differences.
Usage: Usage:
python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options] python compare_intermediate.py --dir1 /path/to/first/log/dir \
--dir2 /path/to/second/log/dir [options]
Options: Options:
--dir1 DIR First intermediate logging directory --dir1 DIR First intermediate logging directory
--dir2 DIR Second intermediate logging directory --dir2 DIR Second intermediate logging directory
--output FILE Output file for the report (default: stdout) --output FILE Output file for the report (default: stdout)
--format {md,json} Output format (default: md) --rtol FLOAT Relative tolerance for tensor comparison
--rtol FLOAT Relative tolerance for tensor comparison (default: 1e-5) (default: 1e-5)
--atol FLOAT Absolute tolerance for tensor comparison (default: 1e-8) --atol FLOAT Absolute tolerance for tensor comparison
(default: 1e-8)
--steps STEPS Comma-separated list of steps to compare (default: all) --steps STEPS Comma-separated list of steps to compare (default: all)
--modules MODULES Comma-separated list of module name patterns to compare (default: all) --modules MODULES Comma-separated list of module name patterns to compare
(default: all)
--verbose Include detailed information about each tensor --verbose Include detailed information about each tensor
""" """
import argparse import argparse
import json import json
import re
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Optional
import regex as re
import torch import torch
@ -40,34 +43,19 @@ def load_tensor(path: Path) -> torch.Tensor:
return None return None
def load_json(path: Path) -> Dict: def extract_diff_metatada(exception_str: str) -> dict:
"""Load a JSON file."""
try:
with open(path, "r") as f:
return json.load(f)
except Exception as e:
print(f"Error loading JSON from {path}: {e}")
return {}
def extract_diff_metatada(exception_str: str) -> Dict:
try: try:
num_diff_elements = int( num_diff_elements = int(
re.search(r"Mismatched elements: (\d+) /", exception_str).group(1) re.search(r"Mismatched elements: (\d+) /", exception_str).group(1))
)
total_elements = int( total_elements = int(
re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1) re.search(r"Mismatched elements: \d+ / (\d+)",
) exception_str).group(1))
max_abs_diff = float( max_abs_diff = float(
re.search( re.search(r"Greatest absolute difference: ([\d\.e-]+)",
r"Greatest absolute difference: ([\d\.e-]+)", exception_str exception_str).group(1))
).group(1)
)
max_rel_diff = float( max_rel_diff = float(
re.search( re.search(r"Greatest relative difference: ([\d\.e-]+)",
r"Greatest relative difference: ([\d\.e-]+)", exception_str exception_str).group(1))
).group(1)
)
return { return {
"num_diff_elements": num_diff_elements, "num_diff_elements": num_diff_elements,
"total_elements": total_elements, "total_elements": total_elements,
@ -78,9 +66,8 @@ def extract_diff_metatada(exception_str: str) -> Dict:
return {"error": exception_str} return {"error": exception_str}
def compare_tensors( def compare_tensors(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float,
tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float atol: float) -> dict:
) -> Dict:
"""Compare two tensors and return a dictionary with comparison results.""" """Compare two tensors and return a dictionary with comparison results."""
if tensor1 is None or tensor2 is None: if tensor1 is None or tensor2 is None:
return {"match": False, "error": "One or both tensors are None"} return {"match": False, "error": "One or both tensors are None"}
@ -105,60 +92,8 @@ def compare_tensors(
return {"match": True} return {"match": True}
def compare_json_values(value1: Any, value2: Any) -> Dict: def find_tensor_files(
"""Compare two JSON values and return a dictionary with comparison results.""" directory: Path) -> dict[str, dict[str, dict[str, list[Path]]]]:
if type(value1) is not type(value2):
return {
"match": False,
"error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}",
}
if isinstance(value1, dict):
# Compare dictionaries
all_keys = set(value1.keys()) | set(value2.keys())
mismatches = {}
for key in all_keys:
if key not in value1:
mismatches[key] = {"error": "Missing in first dict"}
elif key not in value2:
mismatches[key] = {"error": "Missing in second dict"}
else:
comparison = compare_json_values(value1[key], value2[key])
if not comparison["match"]:
mismatches[key] = comparison
if mismatches:
return {"match": False, "mismatches": mismatches}
return {"match": True}
elif isinstance(value1, list):
# Compare lists
if len(value1) != len(value2):
return {
"match": False,
"error": f"Length mismatch: {len(value1)} vs {len(value2)}",
}
mismatches = {}
for i, (item1, item2) in enumerate(zip(value1, value2)):
comparison = compare_json_values(item1, item2)
if not comparison["match"]:
mismatches[i] = comparison
if mismatches:
return {"match": False, "mismatches": mismatches}
return {"match": True}
else:
# Compare primitive values
if value1 == value2:
return {"match": True}
else:
return {"match": False, "value1": value1, "value2": value2}
def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
""" """
Find all tensor files in the given directory. Find all tensor files in the given directory.
@ -198,23 +133,14 @@ def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Pat
if output_tensors: if output_tensors:
result[step_name][module_name]["outputs"] = output_tensors result[step_name][module_name]["outputs"] = output_tensors
# Find JSON metadata files
inputs_json = module_dir / "inputs.json"
if inputs_json.exists():
result[step_name][module_name]["inputs_json"] = [inputs_json]
outputs_json = module_dir / "outputs.json"
if outputs_json.exists():
result[step_name][module_name]["outputs_json"] = [outputs_json]
return result return result
def filter_steps_and_modules( def filter_steps_and_modules(
tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]], tensor_files: dict[str, dict[str, dict[str, list[Path]]]],
steps: Optional[List[str]] = None, steps: Optional[list[str]] = None,
module_patterns: Optional[List[str]] = None, module_patterns: Optional[list[str]] = None,
) -> Dict[str, Dict[str, Dict[str, List[Path]]]]: ) -> dict[str, dict[str, dict[str, list[Path]]]]:
"""Filter tensor files by steps and module patterns.""" """Filter tensor files by steps and module patterns."""
result = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
@ -223,11 +149,13 @@ def filter_steps_and_modules(
step_names = [f"step_{step}" for step in steps] step_names = [f"step_{step}" for step in steps]
steps_to_include = {step: True for step in step_names} steps_to_include = {step: True for step in step_names}
else: else:
steps_to_include = {step: True for step in tensor_files.keys()} steps_to_include = {step: True for step in tensor_files}
# Compile module patterns # Compile module patterns
if module_patterns: if module_patterns:
compiled_patterns = [re.compile(pattern) for pattern in module_patterns] compiled_patterns = [
re.compile(pattern) for pattern in module_patterns
]
else: else:
compiled_patterns = None compiled_patterns = None
@ -237,11 +165,10 @@ def filter_steps_and_modules(
for module_name, file_types in modules.items(): for module_name, file_types in modules.items():
# Check if module matches any pattern # Check if module matches any pattern
if compiled_patterns: if compiled_patterns and not any(
if not any( pattern.search(module_name)
pattern.search(module_name) for pattern in compiled_patterns for pattern in compiled_patterns):
): continue
continue
result[step_name][module_name] = file_types result[step_name][module_name] = file_types
@ -253,9 +180,9 @@ def compare_directories(
dir2: Path, dir2: Path,
rtol: Optional[float] = None, rtol: Optional[float] = None,
atol: Optional[float] = None, atol: Optional[float] = None,
steps: Optional[List[str]] = None, steps: Optional[list[str]] = None,
module_patterns: Optional[List[str]] = None, module_patterns: Optional[list[str]] = None,
) -> Dict: ) -> dict:
"""Compare two intermediate logging directories and return a report.""" """Compare two intermediate logging directories and return a report."""
# Find tensor files in both directories # Find tensor files in both directories
tensor_files1 = find_tensor_files(dir1) tensor_files1 = find_tensor_files(dir1)
@ -263,8 +190,10 @@ def compare_directories(
# Filter by steps and modules # Filter by steps and modules
if steps or module_patterns: if steps or module_patterns:
tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns) tensor_files1 = filter_steps_and_modules(tensor_files1, steps,
tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns) module_patterns)
tensor_files2 = filter_steps_and_modules(tensor_files2, steps,
module_patterns)
# Get all steps and modules from both directories # Get all steps and modules from both directories
all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys()) all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys())
@ -296,12 +225,12 @@ def compare_directories(
# TODO: check if module_calls.txt exists # TODO: check if module_calls.txt exists
dir1_module_call_file = dir1 / step / "module_calls.txt" dir1_module_call_file = dir1 / step / "module_calls.txt"
if dir1_module_call_file.exists(): if dir1_module_call_file.exists():
with open(dir1 / step / "module_calls.txt", "r") as f: with open(dir1 / step / "module_calls.txt") as f:
all_modules = f.read().splitlines() all_modules = f.read().splitlines()
else: else:
print( print(
"Warnings: the module call orders are missed, ordering using module alphbetics" "Warnings: the module call orders are missed, ordering using "
) "module alphbetics")
all_modules = sorted(set(modules1.keys()) | set(modules2.keys())) all_modules = sorted(set(modules1.keys()) | set(modules2.keys()))
step_report["module_call_list"] = [] step_report["module_call_list"] = []
for module in all_modules: for module in all_modules:
@ -329,32 +258,17 @@ def compare_directories(
step_report["modules"][module] = module_report step_report["modules"][module] = module_report
continue continue
# Compare JSON metadata
for json_type in ["inputs_json", "outputs_json"]:
json_files1 = modules1[module].get(json_type, [])
json_files2 = modules2[module].get(json_type, [])
if json_files1 and json_files2:
json1 = load_json(json_files1[0])
json2 = load_json(json_files2[0])
json_comparison = compare_json_values(json1, json2)
json_name = json_type.replace("_json", "")
module_report[f"{json_name}_metadata"] = json_comparison
# Add file paths for manual checking when there's a mismatch
if not json_comparison.get("match", True):
module_report[f"{json_name}_metadata"]["file1"] = str(
json_files1[0]
)
module_report[f"{json_name}_metadata"]["file2"] = str(
json_files2[0]
)
# Compare input tensors # Compare input tensors
input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])} input_tensors1 = {
input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])} p.name: p
all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys()) for p in modules1[module].get("inputs", [])
}
input_tensors2 = {
p.name: p
for p in modules2[module].get("inputs", [])
}
all_input_names = set(input_tensors1.keys()) | set(
input_tensors2.keys())
for tensor_name in sorted(all_input_names): for tensor_name in sorted(all_input_names):
if tensor_name not in input_tensors1: if tensor_name not in input_tensors1:
@ -389,9 +303,16 @@ def compare_directories(
module_report["summary"]["total_tensors"] += 1 module_report["summary"]["total_tensors"] += 1
# Compare output tensors # Compare output tensors
output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])} output_tensors1 = {
output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])} p.name: p
all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys()) for p in modules1[module].get("outputs", [])
}
output_tensors2 = {
p.name: p
for p in modules2[module].get("outputs", [])
}
all_output_names = set(output_tensors1.keys()) | set(
output_tensors2.keys())
for tensor_name in sorted(all_output_names): for tensor_name in sorted(all_output_names):
if tensor_name not in output_tensors1: if tensor_name not in output_tensors1:
@ -439,64 +360,58 @@ def compare_directories(
# Add overall summary # Add overall summary
report["summary"] = { report["summary"] = {
"total_steps": len(all_steps), "total_steps":
"total_modules": sum( len(all_steps),
step_report["summary"]["total_modules"] "total_modules":
for step_report in report["steps"].values() sum(step_report["summary"]["total_modules"]
), for step_report in report["steps"].values()),
"matching_modules": sum( "matching_modules":
step_report["summary"]["matching_modules"] sum(step_report["summary"]["matching_modules"]
for step_report in report["steps"].values() for step_report in report["steps"].values()),
), "mismatched_modules":
"mismatched_modules": sum( sum(step_report["summary"]["mismatched_modules"]
step_report["summary"]["mismatched_modules"] for step_report in report["steps"].values()),
for step_report in report["steps"].values() "missing_modules":
), sum(step_report["summary"]["missing_modules"]
"missing_modules": sum( for step_report in report["steps"].values()),
step_report["summary"]["missing_modules"] "total_tensors":
for step_report in report["steps"].values() sum(module_report["summary"]["total_tensors"]
),
"total_tensors": sum(
module_report["summary"]["total_tensors"]
for step_report in report["steps"].values() for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items() for module_name, module_report in step_report["modules"].items()
if "summary" in module_report if "summary" in module_report),
), "matching_tensors":
"matching_tensors": sum( sum(module_report["summary"]["matching_tensors"]
module_report["summary"]["matching_tensors"]
for step_report in report["steps"].values() for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items() for module_name, module_report in step_report["modules"].items()
if "summary" in module_report if "summary" in module_report),
), "mismatched_tensors":
"mismatched_tensors": sum( sum(module_report["summary"]["mismatched_tensors"]
module_report["summary"]["mismatched_tensors"]
for step_report in report["steps"].values() for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items() for module_name, module_report in step_report["modules"].items()
if "summary" in module_report if "summary" in module_report),
), "missing_tensors":
"missing_tensors": sum( sum(module_report["summary"]["missing_tensors"]
module_report["summary"]["missing_tensors"]
for step_report in report["steps"].values() for step_report in report["steps"].values()
for module_name, module_report in step_report["modules"].items() for module_name, module_report in step_report["modules"].items()
if "summary" in module_report if "summary" in module_report),
),
} }
return report return report
def generate_markdown_report(report: Dict, verbose: bool = False) -> str: def generate_markdown_report(report: dict, verbose: bool = False) -> str:
"""Generate a markdown report from the comparison results.""" """Generate a markdown report from the comparison results."""
lines = [] lines = []
# Add header # Add header
lines.append("# Intermediate Logging Comparison Report") lines.append("# Intermediate Logging Comparison Report")
lines.append("") lines.append("")
lines.append("Comparing intermediate logging outputs between:") lines.append("Comparing intermediate logging outputs "
"between:")
lines.append(f"- **Directory 1**: `{report['dir1']}`") lines.append(f"- **Directory 1**: `{report['dir1']}`")
lines.append(f"- **Directory 2**: `{report['dir2']}`") lines.append(f"- **Directory 2**: `{report['dir2']}`")
lines.append("") lines.append("")
lines.append(f"Comparison parameters:") lines.append("Comparison parameters:")
lines.append(f"- Relative tolerance (rtol): {report['rtol']}") lines.append(f"- Relative tolerance (rtol): {report['rtol']}")
lines.append(f"- Absolute tolerance (atol): {report['atol']}") lines.append(f"- Absolute tolerance (atol): {report['atol']}")
lines.append("") lines.append("")
@ -509,11 +424,13 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|----------|-------|----------|------------|---------|") lines.append("|----------|-------|----------|------------|---------|")
lines.append(f"| Steps | {summary['total_steps']} | - | - | - |") lines.append(f"| Steps | {summary['total_steps']} | - | - | - |")
lines.append( lines.append(
f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |" f"| Modules | {summary['total_modules']} | "
) f"{summary['matching_modules']} | {summary['mismatched_modules']} | "
f"{summary['missing_modules']} |")
lines.append( lines.append(
f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |" f"| Tensors | {summary['total_tensors']} | "
) f"{summary['matching_tensors']} | {summary['mismatched_tensors']} | "
f"{summary['missing_tensors']} |")
lines.append("") lines.append("")
# Add step details # Add step details
@ -523,8 +440,9 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append(f"## {step_name}") lines.append(f"## {step_name}")
lines.append("") lines.append("")
lines.append( lines.append(
f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules" f"**Summary**: {step_summary['matching_modules']} matching "
) f"modules, {step_summary['mismatched_modules']} mismatched "
f"modules, {step_summary['missing_modules']} missing modules")
lines.append("") lines.append("")
# Add module details # Add module details
@ -540,15 +458,14 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
module_summary = module_report["summary"] module_summary = module_report["summary"]
# Determine module status # Determine module status
if module_summary["mismatched_tensors"] > 0: status = "" if module_summary["mismatched_tensors"] > 0 else ""
status = ""
else:
status = ""
lines.append(f"### {status} {module_name}") lines.append(f"### {status} {module_name}")
lines.append("") lines.append("")
lines.append( lines.append(
f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors" f"**Summary**: {module_summary['matching_tensors']} matching "
f"tensors, {module_summary['mismatched_tensors']} mismatched "
f"tensors, {module_summary['missing_tensors']} missing tensors"
) )
lines.append("") lines.append("")
@ -558,20 +475,21 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
metadata_comparison = module_report[metadata_type] metadata_comparison = module_report[metadata_type]
if not metadata_comparison.get("match", True): if not metadata_comparison.get("match", True):
file_paths = "" file_paths = ""
if ( if ("file1" in metadata_comparison
"file1" in metadata_comparison and "file2" in metadata_comparison):
and "file2" in metadata_comparison file_paths = (
): f" - Files: "
file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`" f"`{metadata_comparison['file1']}` "
f"vs `{metadata_comparison['file2']}`")
lines.append( lines.append(
f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}" f"**{metadata_type.capitalize()}**: Mismatch "
) f"detected{file_paths}")
if verbose and "mismatches" in metadata_comparison: if verbose and "mismatches" in metadata_comparison:
lines.append("```json") lines.append("```json")
lines.append( lines.append(
json.dumps(metadata_comparison["mismatches"], indent=2) json.dumps(metadata_comparison["mismatches"],
) indent=2))
lines.append("```") lines.append("```")
lines.append("") lines.append("")
@ -585,8 +503,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|--------|--------|---------|") lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted( for tensor_name, comparison in sorted(
module_report["inputs"].items() module_report["inputs"].items()):
):
if comparison.get("match", False): if comparison.get("match", False):
status = "" status = ""
details = "Tensors match" details = "Tensors match"
@ -595,13 +512,23 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
details = comparison["error"] details = comparison["error"]
else: else:
status = "" status = ""
details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A'):.2e}, " details = (
details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A'):.2e}, " f"Max abs diff: "
details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}" f"{comparison.get('max_abs_diff', 'N/A')}, ")
details += (
f"Max relative diff: "
f"{comparison.get('max_rel_diff', 'N/A')}, ")
details += (
f"Diff elements: "
f"{comparison.get('num_diff_elements', 'N/A')}/"
f"{comparison.get('total_elements', 'N/A')}")
if "file1" in comparison and "file2" in comparison: if "file1" in comparison and "file2" in comparison:
details += f"<br>Files: `{comparison['file1']}` vs `{comparison['file2']}`" details += (
f"<br>Files: `{comparison['file1']}` vs "
f"`{comparison['file2']}`")
lines.append(f"| {tensor_name} | {status} | {details} |") lines.append(
f"| {tensor_name} | {status} | {details} |")
lines.append("") lines.append("")
@ -613,8 +540,7 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
lines.append("|--------|--------|---------|") lines.append("|--------|--------|---------|")
for tensor_name, comparison in sorted( for tensor_name, comparison in sorted(
module_report["outputs"].items() module_report["outputs"].items()):
):
if comparison.get("match", False): if comparison.get("match", False):
status = "" status = ""
details = "Tensors match" details = "Tensors match"
@ -623,11 +549,19 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
details = comparison["error"] details = comparison["error"]
else: else:
status = "" status = ""
details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, " details = (
details = f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, " f"Max abs diff: "
details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}" f"{comparison.get('max_abs_diff', 'N/A')}, ")
details += (
f"Max relative diff: "
f"{comparison.get('max_rel_diff', 'N/A')}, ")
details += (
f"Diff elements: "
f"{comparison.get('num_diff_elements', 'N/A')}/"
f"{comparison.get('total_elements', 'N/A')}")
lines.append(f"| {tensor_name} | {status} | {details} |") lines.append(
f"| {tensor_name} | {status} | {details} |")
lines.append("") lines.append("")
@ -636,15 +570,16 @@ def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Compare intermediate logging outputs from two different runs." description=
) "Compare intermediate logging outputs from two different runs.")
parser.add_argument( parser.add_argument("--dir1",
"--dir1", required=True, help="First intermediate logging directory" required=True,
) help="First intermediate logging directory")
parser.add_argument( parser.add_argument("--dir2",
"--dir2", required=True, help="Second intermediate logging directory" required=True,
) help="Second intermediate logging directory")
parser.add_argument("--output", help="Output file for the report (default: stdout)") parser.add_argument("--output",
help="Output file for the report (default: stdout)")
parser.add_argument( parser.add_argument(
"--rtol", "--rtol",
type=float, type=float,
@ -658,11 +593,12 @@ def main():
help="Absolute tolerance for tensor comparison (default: 1e-8)", help="Absolute tolerance for tensor comparison (default: 1e-8)",
) )
parser.add_argument( parser.add_argument(
"--steps", help="Comma-separated list of steps to compare (default: all)" "--steps",
) help="Comma-separated list of steps to compare (default: all)")
parser.add_argument( parser.add_argument(
"--modules", "--modules",
help="Comma-separated list of module name patterns to compare (default: all)", help="Comma-separated list of module name patterns to compare "
"(default: all)",
) )
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",

View File

@ -17,8 +17,7 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
from functools import cached_property from functools import cached_property
from importlib.util import find_spec from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, List, Set) Protocol, TypeVar, Union, cast, get_args)
from re import Pattern
import regex as re import regex as re
import torch import torch
@ -4026,63 +4025,65 @@ class KVEventsConfig:
@config @config
@dataclass @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig: class IntermediateLoggingConfig:
"""Configuration for intermediate tensor logging.""" """Configuration for intermediate tensor logging."""
output_dir: str = "/tmp/vllm_intermediates" output_dir: str = "/tmp/vllm_intermediates"
"""Directory where to save the intermediate tensors.""" """Directory where to save the intermediate tensors."""
reload_input_dir: Optional[str] = None reload_input_dir: Optional[str] = None
"""Directory where to load the inputs for the steps/modules. """Directory where to load the inputs for the steps/modules.
This is used when we want to check per module numerical gaps instead This is used when we want to check per module numerical gaps instead
of accumulated gap to further dive into the actual numerical issues.""" of accumulated gap to further dive into the actual numerical issues."""
module_call_match: Optional[List[str]] = None module_call_match: Optional[list[str]] = None
"""Match modules by name regex and call index ( """Match modules by name regex and call index (
a module can be called multiple times in a step) a module can be called multiple times in a step)
List of regex:call_idx, call_idx is -1 for default for all calls """ List of regex:call_idx, call_idx is -1 for default for all calls """
log_step_ids: List[int] = field(default_factory=lambda: [0]) log_step_ids: list[int] = field(default_factory=lambda: [0])
"""List of step IDs to log (empty list means log all steps).""" """List of step IDs to log (empty list means log all steps)."""
log_post_fwd_inputs: bool = False log_post_fwd_inputs: bool = False
"""Whether logging inputs after forwards for each module""" """Whether logging inputs after forwards for each module"""
max_tensor_size: Optional[int] = None max_tensor_size: Optional[int] = None
"""Maximum number of elements in tensors to log (None = no limit).""" """Maximum number of elements in tensors to log (None = no limit)."""
enabled: bool = True enabled: bool = True
"""Whether logging is enabled.""" """Whether logging is enabled."""
device_names: List[str] = field(default_factory=list) device_names: list[str] = field(default_factory=list)
"""List of device names to log (empty list means log all devices).""" """List of device names to log (empty list means log all devices)."""
_compiled_module_calls: dict[Pattern,int] = field(default_factory=dict, init=False) _compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
init=False)
"""Compiled regex patterns for module filtering.""" """Compiled regex patterns for module filtering."""
_module_call: dict[str, int] = field(default_factory=dict, init=False) _module_call: dict[str, int] = field(default_factory=dict, init=False)
_step_id_set: Set[int] = field(default_factory=set, init=False) _step_id_set: set[int] = field(default_factory=set, init=False)
"""Set of step IDs for faster lookup.""" """Set of step IDs for faster lookup."""
_output_run_dir: str = "/tmp/vllm_intermediates" _output_run_dir: str = "/tmp/vllm_intermediates"
"""Unique directory to save single run/serve logging result.""" """Unique directory to save single run/serve logging result."""
def __post_init__(self): def __post_init__(self):
"""Initialize derived fields after instance creation.""" """Initialize derived fields after instance creation."""
self._compile_regex_patterns() self._compile_regex_patterns()
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4()) self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
self._step_id_set = set(self.log_step_ids) self._step_id_set = set(self.log_step_ids)
def _compile_regex_patterns(self): def _compile_regex_patterns(self):
"""Compile regex patterns for module name filtering.""" """Compile regex patterns for module name filtering."""
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
self._compiled_module_matches = [] self._compiled_module_matches = []
if self.module_call_match is None: if self.module_call_match is None:
logger.info("No module name regex patterns provided, will log all modules") logger.info(
"No module name regex patterns provided, will log all modules")
return return
# Compile all patterns # Compile all patterns
for regex_pattern_call_idx in self.module_call_match: for regex_pattern_call_idx in self.module_call_match:
try: try:
@ -4091,15 +4092,16 @@ class IntermediateLoggingConfig:
call_idx = -1 call_idx = -1
if len(splits) > 1: if len(splits) > 1:
call_idx = int(splits[1]) call_idx = int(splits[1])
compiled_pattern: Pattern[str] = re.compile(regex_pattern) compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
self._compiled_module_calls[compiled_pattern] = call_idx self._compiled_module_calls[compiled_pattern] = call_idx
logger.info(f"Successfully compiled regex pattern: '{regex_pattern}'") logger.info("Successfully compiled regex pattern: '%s'",
regex_pattern)
except Exception as e: except Exception as e:
logger.error(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") logger.error("Failed to parse module_call_match '%s': %s",
raise ValueError(f"Failed to parse module_call_match '{regex_pattern_call_idx}': {e}") from e regex_pattern_call_idx, e)
logger.info("Compiled %d regex patterns",
logger.info(f"Compiled {len(self._compiled_module_calls)} regex patterns") len(self._compiled_module_calls))
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""Convert the config to a dictionary for serialization.""" """Convert the config to a dictionary for serialization."""
@ -4111,12 +4113,12 @@ class IntermediateLoggingConfig:
"enabled": self.enabled, "enabled": self.enabled,
"device_names": self.device_names "device_names": self.device_names
} }
@classmethod @classmethod
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig": def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
"""Parse the CLI value for the speculative config.""" """Parse the CLI value for the speculative config."""
return cls(**dict_value) return cls(**dict_value)
@property @property
def output_run_dir(self) -> str: def output_run_dir(self) -> str:
return self._output_run_dir return self._output_run_dir
@ -4138,7 +4140,6 @@ class IntermediateLoggingConfig:
hash_str = hashlib.md5(str(factors).encode(), hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest() usedforsecurity=False).hexdigest()
return hash_str return hash_str
class CompilationLevel: class CompilationLevel:

View File

@ -27,13 +27,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend, DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1, GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, IntermediateLoggingConfig, HfOverrides, IntermediateLoggingConfig,
KVEventsConfig, KVTransferConfig, KVEventsConfig, KVTransferConfig, LoadConfig,
LoadConfig, LogprobsMode, LoRAConfig, ModelConfig, LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
ModelDType, ModelImpl, MultiModalConfig, ModelImpl, MultiModalConfig, ObservabilityConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, RunnerOption, SchedulerConfig, SchedulerPolicy,
SchedulerPolicy, SpeculativeConfig, TaskOption, SpeculativeConfig, TaskOption, TokenizerMode,
TokenizerMode, VllmConfig, get_attr_docs, get_field) VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
@ -400,7 +400,7 @@ class EngineArgs:
str] = ModelConfig.logits_processor_pattern str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None speculative_config: Optional[Dict[str, Any]] = None
show_hidden_metrics_for_version: Optional[str] = \ show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version ObservabilityConfig.show_hidden_metrics_for_version
@ -773,10 +773,13 @@ class EngineArgs:
default=None, default=None,
help="The configurations for intermediate loggings. Should be a " help="The configurations for intermediate loggings. Should be a "
"JSON string.") "JSON string.")
intermediate_log_group.add_argument("--intermediate-log-config-path", type=str, intermediate_log_group.add_argument(
help="The path to the configurations for intermediate loggings. Should be a string.") "--intermediate-log-config-path",
type=str,
help="The path to the configurations for intermediate loggings. "
"Should be a string.")
# Observability arguments # Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig) observability_kwargs = get_kwargs(ObservabilityConfig)
observability_group = parser.add_argument_group( observability_group = parser.add_argument_group(
@ -865,9 +868,6 @@ class EngineArgs:
vllm_group.add_argument("--additional-config", vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"]) **vllm_kwargs["additional_config"])
# Other arguments # Other arguments
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
@ -979,11 +979,9 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load, use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location, pt_load_map_location=self.pt_load_map_location,
) )
def create_intermediate_log_config( def create_intermediate_log_config(
self, self, ) -> Optional[IntermediateLoggingConfig]:
) -> Optional[IntermediateLoggingConfig]:
"""Initializes and returns an IntermediateLoggingConfig object based on """Initializes and returns an IntermediateLoggingConfig object based on
`intermediate_log_config` or `intermediate_log_config_path`. `intermediate_log_config` or `intermediate_log_config_path`.
""" """
@ -991,7 +989,7 @@ class EngineArgs:
return IntermediateLoggingConfig.from_dict( return IntermediateLoggingConfig.from_dict(
self.intermediate_log_config) self.intermediate_log_config)
if self.intermediate_log_config_path is not None: if self.intermediate_log_config_path is not None:
with open(self.intermediate_log_config_path, "r") as f: with open(self.intermediate_log_config_path) as f:
return IntermediateLoggingConfig.from_dict(json.load(f)) return IntermediateLoggingConfig.from_dict(json.load(f))
return None return None
@ -1235,8 +1233,7 @@ class EngineArgs:
disable_log_stats=self.disable_log_stats, disable_log_stats=self.disable_log_stats,
) )
intermediate_log_config = self.create_intermediate_log_config( intermediate_log_config = self.create_intermediate_log_config()
)
# Reminder: Please update docs/features/compatibility_matrix.md # Reminder: Please update docs/features/compatibility_matrix.md
# If the feature combo become valid # If the feature combo become valid

View File

@ -7,8 +7,6 @@ This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes. (inputs and outputs) from PyTorch modules during forward passes.
""" """
import json
import os
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
@ -17,8 +15,6 @@ import torch
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig from vllm.config import IntermediateLoggingConfig
# Import logger from vllm
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
@ -91,7 +87,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
return False return False
# If no patterns are defined, log all modules # If no patterns are defined, log all modules
if not config._compiled_module_calls: if not config._compiled_module_calls:
logger.debug("No patterns defined, will log module: %s", module_name)
set_il_module_name(module, module_name) set_il_module_name(module, module_name)
set_il_module_call_idx(module, -1) set_il_module_call_idx(module, -1)
return True return True
@ -115,7 +110,6 @@ def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
def is_log_enabled(config): def is_log_enabled(config):
if not config or not config.enabled: if not config or not config.enabled:
logger.debug("Not logging because config not enabled")
return False return False
if torch.compiler.is_compiling(): if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress") logger.debug("Not logging because torch.compile is in progress")
@ -161,151 +155,7 @@ def get_current_il_config():
return _global_config return _global_config
def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any: def save_tensors(tensor: Any, file_path: str) -> Any:
try:
# Convert inputs to JSON-serializable format
intermediates_json = convert_intermediates_to_json(intermediates)
with open(path, "w") as f:
json.dump(intermediates_json, f, indent=2)
logger.debug("Saved all intermediates as JSON to %s", path)
except Exception as e:
logger.warning("Failed to save intermediates as JSON: %s", e)
import traceback
logger.warning(traceback.format_exc())
def convert_intermediates_to_json(tensor: Any) -> Any:
"""Convert a intermediates(including tensor) to a JSON-serializable
representation.
Args:
intermediates: The intermediates to convert.
Returns:
A JSON-serializable representation of the tensor.
"""
if isinstance(tensor, torch.Tensor):
try:
result = {
"type": "tensor",
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"numel": tensor.numel(),
}
return result
except Exception as e:
# Handle any errors in tensor conversion
return {
"type": "tensor_error",
"error": str(e),
"tensor_type": str(type(tensor)),
}
elif isinstance(tensor, (list, tuple)):
# For lists/tuples, recursively convert each element
container_type = "list" if isinstance(tensor, list) else "tuple"
# If it's a large list, only include a sample
if len(tensor) > 20:
return {
"type": container_type,
"length": len(tensor),
"sample": [
convert_intermediates_to_json(item) for item in tensor[:100]
],
"note": f"Showing only first 20 of {len(tensor)} items",
}
else:
return {
"type": container_type,
"items": [convert_intermediates_to_json(item) for item in tensor],
}
elif isinstance(tensor, dict):
# For dictionaries, recursively convert each value
if len(tensor) > 20:
# For large dicts, only include keys and a sample of values
keys = list(tensor.keys())
sample_keys = keys[:20]
return {
"type": "dict",
"length": len(tensor),
"keys": keys,
"sample": {
k: convert_intermediates_to_json(tensor[k]) for k in sample_keys
},
"note": f"Showing only first 20 of {len(tensor)} items",
}
else:
return {
"type": "dict",
"items": {
k: convert_intermediates_to_json(v) for k, v in tensor.items()
},
}
elif tensor is None:
return None
elif isinstance(tensor, (int, float, bool, str)):
# Primitive types can be directly serialized
return tensor
else:
# For other types, use string representation
return {"type": str(type(tensor).__name__), "string_repr": str(tensor)}
def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool:
"""Utility function to dump tensor metadata to a file.
Args:
tensor: The tensor to dump.
file_path: Base path where to save the tensor (without extension).
"""
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None:
return False
if (
intermediate_log_config.max_tensor_size is not None
and tensor.numel() > intermediate_log_config.max_tensor_size
):
# Save tensor metadata instead of full tensor
tensor_info = {
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"device": str(tensor.device),
"numel": tensor.numel(),
"skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size "
f"{intermediate_log_config.max_tensor_size}",
}
os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True)
with open(f"{file_path}.json", "w") as f:
json.dump(tensor_info, f, indent=2)
return True
return False
def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any:
if reload_dir is None:
return None
try:
intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None
replace_dir = str(intermediate_log_config.output_run_dir)
reload_path = save_path.replace(replace_dir, reload_dir)
logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path)
return torch.load(reload_path, map_location=tensor.device)
except Exception as e:
logger.warning("Failed to load tensor from %s: %s", reload_dir, e)
return tensor
def save_tensors(
tensor: Any, file_path: str, reload_input_dir: Optional[str] = None
) -> Any:
"""Utility function to dump tensor to a file. """Utility function to dump tensor to a file.
Args: Args:
@ -314,52 +164,32 @@ def save_tensors(
file_path: Base path where to save the tensor (without extension). file_path: Base path where to save the tensor (without extension).
""" """
# Also save the actual tensor data for tensors
if isinstance(tensor, torch.Tensor): if isinstance(tensor, torch.Tensor):
# Check if tensor is too large
if save_tensors_metadata_if_too_large(tensor, file_path):
return
# Get device name
device_name = str(tensor.device) device_name = str(tensor.device)
# Skip if device filtering is enabled and this device should not be
# logged
intermediate_log_config = get_current_il_config() intermediate_log_config = get_current_il_config()
if not should_log_device(intermediate_log_config, device_name): if not should_log_device(intermediate_log_config, device_name):
logger.debug(
"Skipping tensor on device %s due to device filter", device_name
)
return tensor return tensor
# Append device name to file path
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt" pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
try: try:
# Save tensor directly without detaching or moving to CPU
torch.save(tensor, pt_path) torch.save(tensor, pt_path)
reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir) logger.debug("Saved tensor of shape %s to %s", tensor.shape,
if reloaded_tensor is not None: pt_path)
return reloaded_tensor
logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
except Exception as e: except Exception as e:
logger.warning("Failed to save tensor to %s: %s", pt_path, e) logger.warning("Failed to save tensor to %s: %s", pt_path, e)
return tensor return tensor
if isinstance(tensor, (list, tuple)): if isinstance(tensor, (list, tuple)):
# For collections, also save each item individually
reloaded_inputs = []
for i, item in enumerate(tensor): for i, item in enumerate(tensor):
reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir) save_tensors(item, f"{file_path}_{i}")
reloaded_inputs.append(reloaded) return tensor
return tuple(reloaded_inputs) if reloaded_inputs else tensor
if isinstance(tensor, dict): if isinstance(tensor, dict):
reloaded_inputs = {}
# For dictionaries, also save each value individually
for k, v in tensor.items(): for k, v in tensor.items():
reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir) save_tensors(v, f"{file_path}_{k}")
reloaded_inputs[k] = reloaded return tensor
return reloaded_inputs if reloaded_inputs else tensor
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None: def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to increment the global step counter after a forward pass. """Hook to increment the global step counter after a forward pass.
Args: Args:
@ -381,7 +211,8 @@ def _prepare_module_log_dir(
is_pre_fwd: bool = False, is_pre_fwd: bool = False,
) -> Path: ) -> Path:
# Create a unique directory for this step if not # Create a unique directory for this step if not
dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}" dump_dir = Path(
intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir.mkdir(exist_ok=True, parents=True) dump_dir.mkdir(exist_ok=True, parents=True)
# Create module directory # Create module directory
@ -393,7 +224,8 @@ def _prepare_module_log_dir(
if is_pre_fwd: if is_pre_fwd:
_log_module_call(intermediate_log_config, module_name + suffix) _log_module_call(intermediate_log_config, module_name + suffix)
module_dir.mkdir(exist_ok=True, parents=True) module_dir.mkdir(exist_ok=True, parents=True)
logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir) logger.debug("Logging module %s inputs/outputs to %s", module_name,
module_dir)
return module_dir return module_dir
@ -401,13 +233,8 @@ def _log_module_call(
intermediate_log_config: IntermediateLoggingConfig, intermediate_log_config: IntermediateLoggingConfig,
module_name: str, module_name: str,
) -> None: ) -> None:
logger.debug("Logging module call for %s", module_name) file = (Path(intermediate_log_config.output_run_dir) /
# write module name and call to step: f"step_{get_step()}" / "module_calls.txt")
file = (
Path(intermediate_log_config.output_run_dir)
/ f"step_{get_step()}"
/ "module_calls.txt"
)
with open(file, "a") as f: with open(file, "a") as f:
f.write(f"{module_name}\n") f.write(f"{module_name}\n")
@ -425,7 +252,8 @@ def get_current_step_module_call(module_name: str) -> int:
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0) return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]: def prepare_log_current_fwd(module,
is_pre_fwd: bool = False) -> Optional[Path]:
intermediate_log_config = get_current_il_config() intermediate_log_config = get_current_il_config()
if intermediate_log_config is None or not intermediate_log_config.enabled: if intermediate_log_config is None or not intermediate_log_config.enabled:
return None return None
@ -443,15 +271,14 @@ def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
if is_pre_fwd: if is_pre_fwd:
update_current_step_module_call(module_name) update_current_step_module_call(module_name)
if should_log: if should_log:
log_dir = _prepare_module_log_dir( log_dir = _prepare_module_log_dir(intermediate_log_config,
intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd module_name,
) is_pre_fwd=is_pre_fwd)
return log_dir return log_dir
def log_pre_fwd_hook( def log_pre_fwd_hook(module: torch.nn.Module,
module: torch.nn.Module, inputs: tuple[Any, ...] inputs: tuple[Any, ...]) -> tuple[Any, ...]:
) -> tuple[Any, ...]:
"""Hook to capture module inputs before forward pass. """Hook to capture module inputs before forward pass.
Args: Args:
@ -462,27 +289,12 @@ def log_pre_fwd_hook(
The unchanged inputs. The unchanged inputs.
""" """
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True): if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
dump_intermediates_to_json(inputs, log_dir / "inputs.json") save_tensors(inputs, str(log_dir / "inputs"))
intermediate_log_config = get_current_il_config()
if intermediate_log_config is not None:
reload_input_dir = getattr(
intermediate_log_config,
"reload_input_dir",
"/tmp/vllm_intermediates/57f4a3b2-9c4c-4afe-be71-0e95369d74b5",
)
else:
reload_input_dir = None
reloaded_inputs = save_tensors(
inputs, str(log_dir / "inputs"), reload_input_dir
)
if reloaded_inputs is not None:
return reloaded_inputs
return inputs return inputs
def log_post_fwd_hook( def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any outputs: Any) -> None:
) -> None:
"""Hook to capture module outputs after forward pass. """Hook to capture module outputs after forward pass.
Args: Args:
@ -491,12 +303,11 @@ def log_post_fwd_hook(
outputs: The outputs from the module's forward function. outputs: The outputs from the module's forward function.
""" """
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False): if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
dump_intermediates_to_json(outputs, log_dir / "outputs.json")
save_tensors(outputs, str(log_dir / "outputs")) save_tensors(outputs, str(log_dir / "outputs"))
intermediate_log_config = get_current_il_config() intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None, "IL config should not be None" assert intermediate_log_config is not None, \
"IL config should not be None"
if intermediate_log_config.log_post_fwd_inputs: if intermediate_log_config.log_post_fwd_inputs:
dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json")
save_tensors(inputs, str(log_dir / "post_fwd_inputs")) save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
@ -532,14 +343,14 @@ class IntermediatesLogger:
def __init__(self, config: IntermediateLoggingConfig): def __init__(self, config: IntermediateLoggingConfig):
self.config = config self.config = config
self.hooks: list[ self.hooks: list[tuple[str, str, Optional[RemovableHandle],
tuple[str, str, Optional[RemovableHandle], Optional[RemovableHandle]] Optional[RemovableHandle]]] = []
] = []
logger.debug("Created IntermediatesLogger with config: %s", config) logger.debug("Created IntermediatesLogger with config: %s", config)
path = Path(config.output_run_dir) path = Path(config.output_run_dir)
path.mkdir(exist_ok=True, parents=True) path.mkdir(exist_ok=True, parents=True)
# Log configuration # Log configuration
logger.info("Intermediates will be logged in %s", config.output_run_dir) logger.info("Intermediates will be logged in %s",
config.output_run_dir)
def register_hooks(self, model: torch.nn.Module) -> None: def register_hooks(self, model: torch.nn.Module) -> None:
"""Register hooks for the model. """Register hooks for the model.
@ -551,13 +362,11 @@ class IntermediatesLogger:
for name, module in model.named_modules(): for name, module in model.named_modules():
if name and should_log_module(self.config, name, module): if name and should_log_module(self.config, name, module):
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook) pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
logger.debug( logger.debug("Registered pre_fwd hook for %s",
"Registered pre_fwd hook for %s", module.__class__.__name__ module.__class__.__name__)
)
post_hook = module.register_forward_hook(log_post_fwd_hook) post_hook = module.register_forward_hook(log_post_fwd_hook)
logger.debug( logger.debug("Registered post_fwd hook for %s",
"Registered post_fwd hook for %s", module.__class__.__name__ module.__class__.__name__)
)
self.hooks.append((name, module, pre_hook, post_hook)) self.hooks.append((name, module, pre_hook, post_hook))
# Register a step counter hook for the root model # Register a step counter hook for the root model
@ -578,7 +387,8 @@ class IntermediatesLogger:
def register_intermediate_hooks( def register_intermediate_hooks(
model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs model: torch.nn.Module,
config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger: ) -> IntermediatesLogger:
"""Register hooks to log intermediate tensors for a model. """Register hooks to log intermediate tensors for a model.
@ -590,10 +400,6 @@ def register_intermediate_hooks(
Returns: Returns:
An IntermediatesLogger instance that can be used to manage the hooks. An IntermediatesLogger instance that can be used to manage the hooks.
""" """
if config is None:
# Create config from kwargs
config = IntermediateLoggingConfig.from_dict(kwargs)
logger_instance = IntermediatesLogger(config) logger_instance = IntermediatesLogger(config)
logger_instance.register_hooks(model) logger_instance.register_hooks(model)
return logger_instance return logger_instance

View File

@ -27,12 +27,12 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
logger = init_logger(__name__) logger = init_logger(__name__)

View File

@ -6,10 +6,11 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig, IntermediateLoggingConfig from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.intermediates.intermediates_logging import (
register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0 from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__) logger = init_logger(__name__)
@ -64,27 +65,26 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None: def check_health(self) -> None:
"""Basic health check (override for device-specific checks).""" """Basic health check (override for device-specific checks)."""
return return
def register_intermediate_hooks(self, def register_intermediate_hooks(
config: Optional[IntermediateLoggingConfig] = None, self, config: Optional[IntermediateLoggingConfig] = None) -> None:
**kwargs) -> None:
"""Register hooks for intermediate tensor logging. """Register hooks for intermediate tensor logging.
This method is called via collective_rpc from the engine core. This method is called via collective_rpc from the engine core.
It registers hooks on the model to dump intermediate tensors during execution. It registers hooks on the model to dump intermediate tensors during
execution.
Args: Args:
config: Configuration for intermediate logging. If provided, this takes precedence over kwargs. config: Configuration for intermediate logging. If provided, this
takes precedence over kwargs.
""" """
if self.model_runner is None or not hasattr(self.model_runner, "model") or self.model_runner.model is None: if self.model_runner is None or not hasattr(
logger.error("Could not register intermediate hooks: model_runner.model is not accessible") self.model_runner, "model") or self.model_runner.model is None:
logger.error("Could not register intermediate hooks: "
"model_runner.model is not accessible")
return return
model = self.model_runner.model model = self.model_runner.model
try: try:
# Register hooks register_intermediate_hooks(model, config)
register_intermediate_hooks(model, config, **kwargs) except Exception:
# Store the logger instance for potential later hook removal logger.exception("Error registering intermediate hooks")
except Exception as e:
logger.info("Successfully registered intermediate hooks")
logger.error("Error registering intermediate hooks", exc_info=True)

View File

@ -128,21 +128,21 @@ class WorkerBase:
def vocab_size(self) -> int: def vocab_size(self) -> int:
"""Get vocabulary size from model configuration.""" """Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
def register_intermediate_hooks(self, config=None, **kwargs) -> None: def register_intermediate_hooks(self, config=None) -> None:
"""Register hooks for intermediate tensor logging. """Register hooks for intermediate tensor logging.
This method is a stub for v0 workers. The actual implementation is in v1 workers. This method is a stub for v0 workers. The actual implementation is
It's included here for compatibility with the collective_rpc mechanism. in v1 workers. It's included here for compatibility with the
collective_rpc mechanism.
Args: Args:
config: Configuration for intermediate logging. config: Configuration for intermediate logging.
**kwargs: Configuration parameters for intermediate logging.
These are ignored in v0 workers.
""" """
logger.warning( logger.warning(
"register_intermediate_hooks is not implemented in v0 workers. " "register_intermediate_hooks is not implemented in v0 workers. "
"This is only available in v1 workers. No hooks will be registered.") "This is only available in v1 workers. No hooks will be registered."
)
return None return None