mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 10:37:04 +08:00

add il tool

more changes

Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

fix tp

Signed-off-by: Lu Fang <fanglu@fb.com>

add comparison tool

tmp

add unit test and fix format

Signed-off-by: Lu Fang <fanglu@fb.com>

add comparison script and documentation

Signed-off-by: Lu Fang <fanglu@fb.com>

provide default intermediate logging

Signed-off-by: Lu Fang <fanglu@fb.com>

optional register il

Signed-off-by: Lu Fang <fanglu@fb.com>

add input reload and improve intermediate compare

This commit is contained in:
parent c6c9122d50
commit d8bff253d7

136 docs/contributing/intermediate_logging.md Normal file
@@ -0,0 +1,136 @@
# Intermediate Tensor Logging

This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.

## Overview

The intermediate tensor logging feature enables you to:

- Log input and output tensors from a configured set of filters
- Filter modules by name using regex patterns
- Filter by module forward-call index (e.g. dump only the 2nd forward call of a given module)
- Filter tensors by device
- Filter by whole-model forward step ID

This is mainly useful for debugging model accuracy gaps between two runs.
## Usage

### Enabling via parameters or config file

**Offline Inference example**

Dump all modules on all devices for step 0 (the default behavior):

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
```

Dump only the first layer's modules on all devices for step 0:

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```

Dump customized layers, devices, and steps through a config file.

The configuration file should be a JSON file with the following structure:

```json
{
    "output_dir": "/tmp/vllm_intermediates",
    "module_call_match": ["layers\\.0\\.(?!.*rotary_emb).*", "rotary_emb:0", "embed_tokens", "model\\.norm"],
    "log_step_ids": [0, 1],
    "device_names": ["cuda:0"]
}
```

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config-path $HOME/intermediate_logging_config.json
```
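A config file like the one above can also be generated programmatically. The sketch below uses only the standard library; the output file name is just an example:

```python
import json
from pathlib import Path

# Example intermediate-logging config mirroring the JSON structure shown above.
config = {
    "output_dir": "/tmp/vllm_intermediates",
    "module_call_match": [
        r"layers\.0\.(?!.*rotary_emb).*",  # layer 0, excluding rotary_emb
        "rotary_emb:0",                    # only the first call of rotary_emb
        "embed_tokens",
        r"model\.norm",
    ],
    "log_step_ids": [0, 1],
    "device_names": ["cuda:0"],
}

path = Path("intermediate_logging_config.json")
path.write_text(json.dumps(config, indent=4))
print(path.read_text())
```

Using raw strings (`r"..."`) keeps the regex backslashes readable; `json.dumps` escapes them into the doubled form required in the JSON file.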
#### Configuration Parameters

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `output_dir` | string | Directory in which to save the intermediate tensors | `/tmp/vllm_intermediates` |
| `module_call_match` | array | Regex patterns to filter module names; to limit logging to the i-th call only, append `:i` | `null` (log all modules) |
| `log_step_ids` | array | List of step IDs to log | `[0]` |
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
| `device_names` | array | List of device names to log | `[]` (log all devices) |
### Output Directory Structure

When you enable intermediate logging, the system creates a timestamped directory under your specified `output_dir`. This helps organize multiple logging sessions:

```
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
├── step_0/
│   ├── model.embed_tokens/
│   │   ├── inputs_0_cuda_0.pt
│   │   ├── inputs.json
│   │   ├── outputs_cuda_0.pt
│   │   └── outputs.json
│   └── model.layers.0.input_layernorm/
│       ├── inputs_0_cuda_0.pt
│       ├── inputs.json
│       ├── outputs_cuda_0.pt
│       └── outputs.json
└── step_1/
    └── ...
```

Each tensor is saved in two formats:

1. `.json` files containing metadata and small tensor values
2. `.pt` files containing the full PyTorch tensors (can be loaded with `torch.load()`)
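Given the layout above, a run directory can be walked with only the standard library to enumerate which modules were captured at each step. The directory names below are illustrative, not real output:

```python
import tempfile
from pathlib import Path

def list_captured_modules(run_dir) -> dict:
    """Map each step directory to the module directories it contains."""
    captured = {}
    for step_dir in sorted(Path(run_dir).glob("step_*")):
        if step_dir.is_dir():
            captured[step_dir.name] = sorted(
                d.name for d in step_dir.iterdir() if d.is_dir()
            )
    return captured

# Build a tiny fake run directory to demonstrate the traversal.
root = Path(tempfile.mkdtemp())
(root / "step_0" / "model.embed_tokens").mkdir(parents=True)
(root / "step_1" / "model.norm").mkdir(parents=True)
print(list_captured_modules(root))
# → {'step_0': ['model.embed_tokens'], 'step_1': ['model.norm']}
```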
## Comparing Intermediate Logging Results

vLLM provides a tool called `compare_intermediate.py` to compare intermediate tensors between two different runs. This is particularly useful for debugging accuracy differences or verifying that code changes don't affect model outputs.

### Usage

```bash
python tools/compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]
```

### Options

| Option | Description | Default |
|--------|-------------|---------|
| `--dir1` | First intermediate logging directory | (required) |
| `--dir2` | Second intermediate logging directory | (required) |
| `--output` | Output file for the report | stdout |
| `--rtol` | Relative tolerance for tensor comparison | 1e-5 |
| `--atol` | Absolute tolerance for tensor comparison | 1e-8 |
| `--steps` | Comma-separated list of steps to compare | all |
| `--modules` | Comma-separated list of module name patterns to compare | all |
| `--verbose` | Include detailed information about each tensor | false |
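The `--modules` patterns are regular expressions matched against module names. A minimal sketch of that filtering step, similar in spirit to what the tool does internally (the module names here are made up for illustration):

```python
import re

def filter_modules(module_names, patterns):
    """Keep modules whose name matches any of the given regex patterns."""
    compiled = [re.compile(p) for p in patterns]
    return [
        name for name in module_names
        if any(p.search(name) for p in compiled)
    ]

modules = [
    "model.layers.0.self_attn",
    "model.layers.0.mlp",
    "model.embed_tokens",
]
print(filter_modules(modules, [".*attention.*", ".*mlp.*"]))
# → ['model.layers.0.mlp']
```

Note that `.*attention.*` does not match `self_attn`; patterns are searched anywhere in the name, but the text itself must match.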
### Example

```bash
# Compare all tensors from two different runs
python tools/compare_intermediate.py --dir1 /tmp/vllm_intermediates/run1 --dir2 /tmp/vllm_intermediates/run2

# Compare only specific modules and steps with custom tolerances
python tools/compare_intermediate.py \
    --dir1 /tmp/vllm_intermediates/run1 \
    --dir2 /tmp/vllm_intermediates/run2 \
    --steps 0,1 \
    --modules ".*attention.*,.*mlp.*" \
    --rtol 1e-4 \
    --atol 1e-7 \
    --output comparison_report.md
```

### Output

The tool generates a detailed markdown report that includes:

- An overall summary of matching and mismatched tensors
- Per-module comparison results
- Detailed tensor differences (when using `--verbose`)

This makes it easy to identify which specific tensors differ between runs and by how much.
325 tests/v1/test_intermediates_logging.py Normal file

@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the intermediate tensor logging functionality.
"""

import json
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock

import pytest
import torch
import torch.nn as nn

from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
    get_current_il_config, get_step, increment_step, intermediate_logging,
    register_intermediate_hooks, reset_step, should_log_device,
    should_log_module, should_log_step)


class SimpleModel(nn.Module):
    """A simple model for testing."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


@pytest.fixture
def temp_output_dir():
    """Create a temporary directory for test outputs."""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    # Clean up after the test
    shutil.rmtree(temp_dir)


@pytest.fixture
def simple_model():
    """Create a simple model for testing."""
    return SimpleModel()


@pytest.fixture
def il_config(temp_output_dir):
    """Create a basic IntermediateLoggingConfig for testing."""
    return IntermediateLoggingConfig(output_dir=temp_output_dir,
                                     enabled=True,
                                     log_step_ids=[0, 1],
                                     module_call_match=[".*linear.*"])
def test_step_counter():
    """Test the step counter functionality."""
    # Reset the step counter
    reset_step()
    assert get_step() == 0

    # Increment the step counter
    increment_step()
    assert get_step() == 1

    # Increment again
    increment_step()
    assert get_step() == 2

    # Reset again
    reset_step()
    assert get_step() == 0


def test_intermediate_logging_context_manager():
    """Test the intermediate_logging context manager."""
    # Create a config
    config = IntermediateLoggingConfig(enabled=True)

    # Initially, there should be no global config
    assert get_current_il_config() is None

    # Use the context manager
    with intermediate_logging(config):
        # Inside the context, the global config should be set
        assert get_current_il_config() is not None
        assert get_current_il_config().enabled is True

    # After the context, the global config should be None again
    assert get_current_il_config() is None

    # Test with a different config
    config2 = IntermediateLoggingConfig(enabled=False)
    with intermediate_logging(config2):
        assert get_current_il_config() is not None
        assert get_current_il_config().enabled is False


def test_should_log_step():
    """Test the should_log_step function."""
    # Reset step counter
    reset_step()

    # Create configs with different step IDs
    config_all_steps = IntermediateLoggingConfig(
        enabled=True,
        log_step_ids=[]  # Empty list means log all steps
    )
    config_specific_steps = IntermediateLoggingConfig(
        enabled=True,
        log_step_ids=[0, 2, 4]  # Only log steps 0, 2, and 4
    )
    config_disabled = IntermediateLoggingConfig(enabled=False,
                                                log_step_ids=[0, 1, 2])

    # Test with all steps config
    with intermediate_logging(config_all_steps):
        assert should_log_step(config_all_steps) is True  # Step 0
        increment_step()
        assert should_log_step(config_all_steps) is True  # Step 1

    # Reset step counter
    reset_step()

    # Test with specific steps config
    with intermediate_logging(config_specific_steps):
        assert should_log_step(config_specific_steps) is True  # Step 0
        increment_step()
        assert should_log_step(config_specific_steps) is False  # Step 1
        increment_step()
        assert should_log_step(config_specific_steps) is True  # Step 2
        increment_step()
        assert should_log_step(config_specific_steps) is False  # Step 3
        increment_step()
        assert should_log_step(config_specific_steps) is True  # Step 4

    # Test with disabled config
    with intermediate_logging(config_disabled):
        assert should_log_step(config_disabled) is False  # Disabled
def test_should_log_device():
    """Test the should_log_device function."""
    # Create configs with different device filters
    config_all_devices = IntermediateLoggingConfig(
        enabled=True,
        device_names=[]  # Empty list means log all devices
    )
    config_specific_devices = IntermediateLoggingConfig(
        enabled=True,
        device_names=["cuda:0", "cpu"]  # Only log cuda:0 and cpu
    )
    config_disabled = IntermediateLoggingConfig(enabled=False,
                                                device_names=["cuda:0", "cpu"])

    # Test with all devices config
    with intermediate_logging(config_all_devices):
        assert should_log_device(config_all_devices, "cuda:0") is True
        assert should_log_device(config_all_devices, "cuda:1") is True
        assert should_log_device(config_all_devices, "cpu") is True

    # Test with specific devices config
    with intermediate_logging(config_specific_devices):
        assert should_log_device(config_specific_devices, "cuda:0") is True
        assert should_log_device(config_specific_devices, "cuda:1") is False
        assert should_log_device(config_specific_devices, "cpu") is True

    # Test with disabled config
    with intermediate_logging(config_disabled):
        assert should_log_device(config_disabled, "cuda:0") is False
        assert should_log_device(config_disabled, "cpu") is False


def test_should_log_module(simple_model):
    """Test the should_log_module function."""
    # Create configs with different module name filters
    config_all_modules = IntermediateLoggingConfig(
        enabled=True,
        module_call_match=None  # None means log all modules
    )
    config_specific_modules = IntermediateLoggingConfig(
        enabled=True,
        module_call_match=[".*linear.*"
                           ]  # Only log modules with "linear" in the name
    )
    config_disabled = IntermediateLoggingConfig(enabled=False,
                                                module_call_match=[".*"])

    # Test with all modules config
    with intermediate_logging(config_all_modules):
        assert should_log_module(config_all_modules, "linear1",
                                 simple_model.linear1) is True
        assert should_log_module(config_all_modules, "relu",
                                 simple_model.relu) is True

    # Test with specific modules config
    with intermediate_logging(config_specific_modules):
        assert should_log_module(config_specific_modules, "linear1",
                                 simple_model.linear1) is True
        assert should_log_module(config_specific_modules, "relu",
                                 simple_model.relu) is False

    # Test with disabled config
    with intermediate_logging(config_disabled):
        assert should_log_module(config_disabled, "linear1",
                                 simple_model.linear1) is False
        assert should_log_module(config_disabled, "relu",
                                 simple_model.relu) is False
def test_register_hooks(simple_model, il_config):
    """Test registering hooks on a model."""
    # Register hooks
    logger_instance = register_intermediate_hooks(simple_model, il_config)

    # Check that hooks were registered
    assert len(logger_instance.hooks) > 0

    # Remove hooks
    logger_instance.remove_hooks()

    # Check that hooks were removed
    assert len(logger_instance.hooks) == 0


@mock.patch(
    'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
                       il_config, temp_output_dir):
    """Test that forward hooks are called during model execution."""
    mock_save_tensors.return_value = None
    # Register hooks
    with intermediate_logging(il_config):
        logger_instance = register_intermediate_hooks(simple_model, il_config)

        # Create input tensor
        input_tensor = torch.randn(2, 10)

        # Reset step counter
        reset_step()

        # Forward pass
        simple_model(input_tensor)

        # Check that the step counter was incremented
        assert get_step() == 1

        # Check that dump_intermediates_to_json and save_tensors were called
        assert mock_dump_json.called
        assert mock_save_tensors.called

    # Remove hooks
    logger_instance.remove_hooks()


def test_end_to_end(simple_model, il_config, temp_output_dir):
    """Test the entire intermediate logging workflow end-to-end."""
    # Register hooks
    with intermediate_logging(il_config):
        logger_instance = register_intermediate_hooks(simple_model, il_config)

        # Create input tensor
        input_tensor = torch.randn(2, 10)

        # Reset step counter
        reset_step()

        # Forward pass
        simple_model(input_tensor)

        # Check that output directories were created
        root_dir = Path(il_config._output_run_dir)
        assert root_dir.exists()
        step_dir = root_dir / "step_0"
        assert step_dir.exists()

        module_dirs = list(step_dir.glob("*"))
        print(f"{module_dirs=}")
        assert len(module_dirs) > 0

        # Check that input and output files were created
        for module_dir in module_dirs:
            print(f"{module_dir=}")
            if os.path.isdir(module_dir):
                inputs_json = module_dir / "inputs.json"
                outputs_json = module_dir / "outputs.json"

                # Check that JSON files exist
                assert inputs_json.exists()
                assert outputs_json.exists()

                # Check that JSON files contain valid data
                with open(inputs_json) as f:
                    inputs_data = json.load(f)
                    assert "type" in inputs_data

                with open(outputs_json) as f:
                    outputs_data = json.load(f)
                    assert "type" in outputs_data

                # Check that tensor files exist
                tensor_files = list(module_dir.glob("*.pt"))
                assert len(tensor_files) > 0

    # Remove hooks
    logger_instance.remove_hooks()


if __name__ == "__main__":
    pytest.main(["-xvs", __file__])
706 tools/compare_intermediate.py Executable file

@@ -0,0 +1,706 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Script to compare intermediate logging outputs from two different runs.

This script compares the tensor outputs from two different intermediate logging
directories and generates a report of the differences.

Usage:
    python compare_intermediate.py --dir1 /path/to/first/log/dir --dir2 /path/to/second/log/dir [options]

Options:
    --dir1 DIR          First intermediate logging directory
    --dir2 DIR          Second intermediate logging directory
    --output FILE       Output file for the report (default: stdout)
    --format {md,json}  Output format (default: md)
    --rtol FLOAT        Relative tolerance for tensor comparison (default: 1e-5)
    --atol FLOAT        Absolute tolerance for tensor comparison (default: 1e-8)
    --steps STEPS       Comma-separated list of steps to compare (default: all)
    --modules MODULES   Comma-separated list of module name patterns to compare (default: all)
    --verbose           Include detailed information about each tensor
"""

import argparse
import json
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch


def load_tensor(path: Path) -> Optional[torch.Tensor]:
    """Load a tensor from a .pt file, returning None on failure."""
    try:
        return torch.load(path, map_location="cpu")
    except Exception as e:
        print(f"Error loading tensor from {path}: {e}")
        return None


def load_json(path: Path) -> Dict:
    """Load a JSON file."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading JSON from {path}: {e}")
        return {}
def extract_diff_metadata(exception_str: str) -> Dict:
    """Parse mismatch statistics out of a torch.testing.assert_close error."""
    try:
        num_diff_elements = int(
            re.search(r"Mismatched elements: (\d+) /", exception_str).group(1)
        )
        total_elements = int(
            re.search(r"Mismatched elements: \d+ / (\d+)", exception_str).group(1)
        )
        max_abs_diff = float(
            re.search(
                r"Greatest absolute difference: ([\d\.e-]+)", exception_str
            ).group(1)
        )
        max_rel_diff = float(
            re.search(
                r"Greatest relative difference: ([\d\.e-]+)", exception_str
            ).group(1)
        )
        return {
            "num_diff_elements": num_diff_elements,
            "total_elements": total_elements,
            "max_abs_diff": max_abs_diff,
            "max_rel_diff": max_rel_diff,
        }
    except Exception:
        return {"error": exception_str}


def compare_tensors(
    tensor1: torch.Tensor, tensor2: torch.Tensor, rtol: float, atol: float
) -> Dict:
    """Compare two tensors and return a dictionary with comparison results."""
    if tensor1 is None or tensor2 is None:
        return {"match": False, "error": "One or both tensors are None"}

    if tensor1.shape != tensor2.shape:
        return {
            "match": False,
            "error": f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}",
        }

    if tensor1.dtype != tensor2.dtype:
        return {
            "match": False,
            "error": f"Dtype mismatch: {tensor1.dtype} vs {tensor2.dtype}",
        }

    # Check if tensors are close using PyTorch's assert_close
    try:
        torch.testing.assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
    except Exception as e:
        return {"match": False, **extract_diff_metadata(str(e))}
    return {"match": True}


def compare_json_values(value1: Any, value2: Any) -> Dict:
    """Compare two JSON values and return a dictionary with comparison results."""
    if type(value1) is not type(value2):
        return {
            "match": False,
            "error": f"Type mismatch: {type(value1).__name__} vs {type(value2).__name__}",
        }

    if isinstance(value1, dict):
        # Compare dictionaries
        all_keys = set(value1.keys()) | set(value2.keys())
        mismatches = {}

        for key in all_keys:
            if key not in value1:
                mismatches[key] = {"error": "Missing in first dict"}
            elif key not in value2:
                mismatches[key] = {"error": "Missing in second dict"}
            else:
                comparison = compare_json_values(value1[key], value2[key])
                if not comparison["match"]:
                    mismatches[key] = comparison

        if mismatches:
            return {"match": False, "mismatches": mismatches}
        return {"match": True}

    elif isinstance(value1, list):
        # Compare lists
        if len(value1) != len(value2):
            return {
                "match": False,
                "error": f"Length mismatch: {len(value1)} vs {len(value2)}",
            }

        mismatches = {}
        for i, (item1, item2) in enumerate(zip(value1, value2)):
            comparison = compare_json_values(item1, item2)
            if not comparison["match"]:
                mismatches[i] = comparison

        if mismatches:
            return {"match": False, "mismatches": mismatches}
        return {"match": True}

    else:
        # Compare primitive values
        if value1 == value2:
            return {"match": True}
        else:
            return {"match": False, "value1": value1, "value2": value2}
def find_tensor_files(directory: Path) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
    """
    Find all tensor files in the given directory.

    Returns a dictionary with the structure:
    {
        "step_0": {
            "module_name_123456": {
                "inputs": [Path("inputs_0_cuda_0.pt"), ...],
                "outputs": [Path("output_cuda_0.pt"), ...]
            },
            ...
        },
        ...
    }
    """
    result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    # Find all step directories
    step_dirs = [d for d in directory.glob("step_*") if d.is_dir()]

    for step_dir in step_dirs:
        step_name = step_dir.name

        # Find all module directories
        module_dirs = [d for d in step_dir.glob("*") if d.is_dir()]

        for module_dir in module_dirs:
            module_name = module_dir.name

            # Find input tensor files
            input_tensors = list(module_dir.glob("inputs_*.pt"))
            if input_tensors:
                result[step_name][module_name]["inputs"] = input_tensors

            # Find output tensor files
            output_tensors = list(module_dir.glob("output*.pt"))
            if output_tensors:
                result[step_name][module_name]["outputs"] = output_tensors

            # Find JSON metadata files
            inputs_json = module_dir / "inputs.json"
            if inputs_json.exists():
                result[step_name][module_name]["inputs_json"] = [inputs_json]

            outputs_json = module_dir / "outputs.json"
            if outputs_json.exists():
                result[step_name][module_name]["outputs_json"] = [outputs_json]

    return result


def filter_steps_and_modules(
    tensor_files: Dict[str, Dict[str, Dict[str, List[Path]]]],
    steps: Optional[List[str]] = None,
    module_patterns: Optional[List[str]] = None,
) -> Dict[str, Dict[str, Dict[str, List[Path]]]]:
    """Filter tensor files by steps and module patterns."""
    result = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    # Filter steps
    if steps:
        step_names = [f"step_{step}" for step in steps]
        steps_to_include = {step: True for step in step_names}
    else:
        steps_to_include = {step: True for step in tensor_files.keys()}

    # Compile module patterns
    if module_patterns:
        compiled_patterns = [re.compile(pattern) for pattern in module_patterns]
    else:
        compiled_patterns = None

    for step_name, modules in tensor_files.items():
        if step_name not in steps_to_include:
            continue

        for module_name, file_types in modules.items():
            # Check if module matches any pattern
            if compiled_patterns:
                if not any(
                    pattern.search(module_name) for pattern in compiled_patterns
                ):
                    continue

            result[step_name][module_name] = file_types

    return result
def compare_directories(
    dir1: Path,
    dir2: Path,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    steps: Optional[List[str]] = None,
    module_patterns: Optional[List[str]] = None,
) -> Dict:
    """Compare two intermediate logging directories and return a report."""
    # Find tensor files in both directories
    tensor_files1 = find_tensor_files(dir1)
    tensor_files2 = find_tensor_files(dir2)

    # Filter by steps and modules
    if steps or module_patterns:
        tensor_files1 = filter_steps_and_modules(tensor_files1, steps, module_patterns)
        tensor_files2 = filter_steps_and_modules(tensor_files2, steps, module_patterns)

    # Get all steps and modules from both directories
    all_steps = set(tensor_files1.keys()) | set(tensor_files2.keys())

    report = {
        "dir1": str(dir1),
        "dir2": str(dir2),
        "rtol": rtol,
        "atol": atol,
        "steps": {},
    }

    # Compare each step
    for step in sorted(all_steps):
        step_report = {
            "modules": {},
            "summary": {
                "total_modules": 0,
                "matching_modules": 0,
                "mismatched_modules": 0,
                "missing_modules": 0,
            },
        }

        # Get all modules from both directories for this step
        modules1 = tensor_files1.get(step, {})
        modules2 = tensor_files2.get(step, {})
        # TODO: read from module_calls.txt to get the full module list
        # TODO: check whether the module_calls.txt file exists in dir2 as well
        dir1_module_call_file = dir1 / step / "module_calls.txt"
        if dir1_module_call_file.exists():
            with open(dir1_module_call_file, "r") as f:
                all_modules = f.read().splitlines()
        else:
            print(
                "Warning: module call order file is missing; "
                "falling back to alphabetical module order"
            )
            all_modules = sorted(set(modules1.keys()) | set(modules2.keys()))
        step_report["module_call_list"] = []
        for module in all_modules:
            module_report = {
                "inputs": {},
                "outputs": {},
                "summary": {
                    "total_tensors": 0,
                    "matching_tensors": 0,
                    "mismatched_tensors": 0,
                    "missing_tensors": 0,
                },
            }

            # Check if module exists in both directories
            if module not in modules1:
                module_report["error"] = f"Module missing in {dir1}"
                step_report["summary"]["missing_modules"] += 1
                step_report["modules"][module] = module_report
                continue

            if module not in modules2:
                module_report["error"] = f"Module missing in {dir2}"
                step_report["summary"]["missing_modules"] += 1
                step_report["modules"][module] = module_report
                continue

            # Compare JSON metadata
            for json_type in ["inputs_json", "outputs_json"]:
                json_files1 = modules1[module].get(json_type, [])
                json_files2 = modules2[module].get(json_type, [])

                if json_files1 and json_files2:
                    json1 = load_json(json_files1[0])
                    json2 = load_json(json_files2[0])

                    json_comparison = compare_json_values(json1, json2)
                    json_name = json_type.replace("_json", "")
                    module_report[f"{json_name}_metadata"] = json_comparison

                    # Add file paths for manual checking when there's a mismatch
                    if not json_comparison.get("match", True):
                        module_report[f"{json_name}_metadata"]["file1"] = str(
                            json_files1[0]
                        )
                        module_report[f"{json_name}_metadata"]["file2"] = str(
                            json_files2[0]
                        )

            # Compare input tensors
            input_tensors1 = {p.name: p for p in modules1[module].get("inputs", [])}
            input_tensors2 = {p.name: p for p in modules2[module].get("inputs", [])}
            all_input_names = set(input_tensors1.keys()) | set(input_tensors2.keys())

            for tensor_name in sorted(all_input_names):
                if tensor_name not in input_tensors1:
                    module_report["inputs"][tensor_name] = {
                        "match": False,
                        "error": f"Tensor missing in {dir1}",
                    }
                    module_report["summary"]["missing_tensors"] += 1
                elif tensor_name not in input_tensors2:
                    module_report["inputs"][tensor_name] = {
                        "match": False,
                        "error": f"Tensor missing in {dir2}",
                    }
                    module_report["summary"]["missing_tensors"] += 1
                else:
                    tensor1 = load_tensor(input_tensors1[tensor_name])
                    tensor2 = load_tensor(input_tensors2[tensor_name])

                    comparison = compare_tensors(tensor1, tensor2, rtol, atol)
                    # Add file paths for manual checking when there's a mismatch
                    if not comparison.get("match", False):
                        comparison["file1"] = str(input_tensors1[tensor_name])
                        comparison["file2"] = str(input_tensors2[tensor_name])

                    module_report["inputs"][tensor_name] = comparison

                    if comparison.get("match", False):
                        module_report["summary"]["matching_tensors"] += 1
                    else:
                        module_report["summary"]["mismatched_tensors"] += 1

                module_report["summary"]["total_tensors"] += 1

            # Compare output tensors
            output_tensors1 = {p.name: p for p in modules1[module].get("outputs", [])}
            output_tensors2 = {p.name: p for p in modules2[module].get("outputs", [])}
            all_output_names = set(output_tensors1.keys()) | set(output_tensors2.keys())

            for tensor_name in sorted(all_output_names):
                if tensor_name not in output_tensors1:
                    module_report["outputs"][tensor_name] = {
                        "match": False,
                        "error": f"Tensor missing in {dir1}",
                    }
                    module_report["summary"]["missing_tensors"] += 1
                elif tensor_name not in output_tensors2:
                    module_report["outputs"][tensor_name] = {
                        "match": False,
                        "error": f"Tensor missing in {dir2}",
                    }
                    module_report["summary"]["missing_tensors"] += 1
                else:
                    tensor1 = load_tensor(output_tensors1[tensor_name])
                    tensor2 = load_tensor(output_tensors2[tensor_name])

                    comparison = compare_tensors(tensor1, tensor2, rtol, atol)
                    # Add file paths for manual checking when there's a mismatch
                    if not comparison.get("match", False):
                        comparison["file1"] = str(output_tensors1[tensor_name])
                        comparison["file2"] = str(output_tensors2[tensor_name])

                    module_report["outputs"][tensor_name] = comparison

                    if comparison.get("match", False):
                        module_report["summary"]["matching_tensors"] += 1
                    else:
                        module_report["summary"]["mismatched_tensors"] += 1

                module_report["summary"]["total_tensors"] += 1

            # Update module status
            if module_report["summary"]["mismatched_tensors"] > 0:
                step_report["summary"]["mismatched_modules"] += 1
            else:
                step_report["summary"]["matching_modules"] += 1

            step_report["summary"]["total_modules"] += 1
            step_report["modules"][module] = module_report
            step_report["module_call_list"].append(module)

        report["steps"][step] = step_report

    # Add overall summary
    report["summary"] = {
        "total_steps": len(all_steps),
        "total_modules": sum(
            step_report["summary"]["total_modules"]
            for step_report in report["steps"].values()
        ),
        "matching_modules": sum(
            step_report["summary"]["matching_modules"]
            for step_report in report["steps"].values()
        ),
        "mismatched_modules": sum(
            step_report["summary"]["mismatched_modules"]
|
||||
for step_report in report["steps"].values()
|
||||
),
|
||||
"missing_modules": sum(
|
||||
step_report["summary"]["missing_modules"]
|
||||
for step_report in report["steps"].values()
|
||||
),
|
||||
"total_tensors": sum(
|
||||
module_report["summary"]["total_tensors"]
|
||||
for step_report in report["steps"].values()
|
||||
for module_name, module_report in step_report["modules"].items()
|
||||
if "summary" in module_report
|
||||
),
|
||||
"matching_tensors": sum(
|
||||
module_report["summary"]["matching_tensors"]
|
||||
for step_report in report["steps"].values()
|
||||
for module_name, module_report in step_report["modules"].items()
|
||||
if "summary" in module_report
|
||||
),
|
||||
"mismatched_tensors": sum(
|
||||
module_report["summary"]["mismatched_tensors"]
|
||||
for step_report in report["steps"].values()
|
||||
for module_name, module_report in step_report["modules"].items()
|
||||
if "summary" in module_report
|
||||
),
|
||||
"missing_tensors": sum(
|
||||
module_report["summary"]["missing_tensors"]
|
||||
for step_report in report["steps"].values()
|
||||
for module_name, module_report in step_report["modules"].items()
|
||||
if "summary" in module_report
|
||||
),
|
||||
}
|
||||
|
||||
return report
|
||||
|
||||
|
||||
def generate_markdown_report(report: Dict, verbose: bool = False) -> str:
    """Generate a markdown report from the comparison results."""
    lines = []

    # Add header
    lines.append("# Intermediate Logging Comparison Report")
    lines.append("")
    lines.append("Comparing intermediate logging outputs between:")
    lines.append(f"- **Directory 1**: `{report['dir1']}`")
    lines.append(f"- **Directory 2**: `{report['dir2']}`")
    lines.append("")
    lines.append("Comparison parameters:")
    lines.append(f"- Relative tolerance (rtol): {report['rtol']}")
    lines.append(f"- Absolute tolerance (atol): {report['atol']}")
    lines.append("")

    # Add overall summary
    summary = report["summary"]
    lines.append("## Overall Summary")
    lines.append("")
    lines.append("| Category | Total | Matching | Mismatched | Missing |")
    lines.append("|----------|-------|----------|------------|---------|")
    lines.append(f"| Steps | {summary['total_steps']} | - | - | - |")
    lines.append(
        f"| Modules | {summary['total_modules']} | {summary['matching_modules']} | {summary['mismatched_modules']} | {summary['missing_modules']} |"
    )
    lines.append(
        f"| Tensors | {summary['total_tensors']} | {summary['matching_tensors']} | {summary['mismatched_tensors']} | {summary['missing_tensors']} |"
    )
    lines.append("")

    # Add step details
    for step_name, step_report in sorted(report["steps"].items()):
        step_summary = step_report["summary"]

        lines.append(f"## {step_name}")
        lines.append("")
        lines.append(
            f"**Summary**: {step_summary['matching_modules']} matching modules, {step_summary['mismatched_modules']} mismatched modules, {step_summary['missing_modules']} missing modules"
        )
        lines.append("")

        # Add module details
        for module_name in step_report["module_call_list"]:
            module_report = step_report["modules"][module_name]
            if "error" in module_report:
                lines.append(f"### ❌ {module_name}")
                lines.append("")
                lines.append(f"**Error**: {module_report['error']}")
                lines.append("")
                continue

            module_summary = module_report["summary"]

            # Determine module status
            if module_summary["mismatched_tensors"] > 0:
                status = "❌"
            else:
                status = "✅"

            lines.append(f"### {status} {module_name}")
            lines.append("")
            lines.append(
                f"**Summary**: {module_summary['matching_tensors']} matching tensors, {module_summary['mismatched_tensors']} mismatched tensors, {module_summary['missing_tensors']} missing tensors"
            )
            lines.append("")

            # Add metadata comparison results if available
            for metadata_type in ["inputs_metadata", "outputs_metadata"]:
                if metadata_type in module_report:
                    metadata_comparison = module_report[metadata_type]
                    if not metadata_comparison.get("match", True):
                        file_paths = ""
                        if (
                            "file1" in metadata_comparison
                            and "file2" in metadata_comparison
                        ):
                            file_paths = f" - Files: `{metadata_comparison['file1']}` vs `{metadata_comparison['file2']}`"

                        lines.append(
                            f"**{metadata_type.capitalize()}**: Mismatch detected{file_paths}"
                        )
                        if verbose and "mismatches" in metadata_comparison:
                            lines.append("```json")
                            lines.append(
                                json.dumps(metadata_comparison["mismatches"], indent=2)
                            )
                            lines.append("```")
                        lines.append("")

            # Add tensor comparison details
            if module_summary["mismatched_tensors"] > 0 or verbose:
                # Add input tensor details
                if module_report["inputs"]:
                    lines.append("#### Input Tensors")
                    lines.append("")
                    lines.append("| Tensor | Status | Details |")
                    lines.append("|--------|--------|---------|")

                    for tensor_name, comparison in sorted(
                        module_report["inputs"].items()
                    ):
                        if comparison.get("match", False):
                            status = "✅"
                            details = "Tensors match"
                        elif "error" in comparison:
                            status = "❌"
                            details = comparison["error"]
                        else:
                            status = "❌"
                            details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
                            details += f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
                            details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"
                            if "file1" in comparison and "file2" in comparison:
                                details += f"<br>Files: `{comparison['file1']}` vs `{comparison['file2']}`"

                        lines.append(f"| {tensor_name} | {status} | {details} |")

                    lines.append("")

                # Add output tensor details
                if module_report["outputs"]:
                    lines.append("#### Output Tensors")
                    lines.append("")
                    lines.append("| Tensor | Status | Details |")
                    lines.append("|--------|--------|---------|")

                    for tensor_name, comparison in sorted(
                        module_report["outputs"].items()
                    ):
                        if comparison.get("match", False):
                            status = "✅"
                            details = "Tensors match"
                        elif "error" in comparison:
                            status = "❌"
                            details = comparison["error"]
                        else:
                            status = "❌"
                            details = f"Max abs diff: {comparison.get('max_abs_diff', 'N/A')}, "
                            details += f"Max relative diff: {comparison.get('max_rel_diff', 'N/A')}, "
                            details += f"Diff elements: {comparison.get('num_diff_elements', 'N/A')}/{comparison.get('total_elements', 'N/A')}"

                        lines.append(f"| {tensor_name} | {status} | {details} |")

                    lines.append("")

    return "\n".join(lines)

def main():
    parser = argparse.ArgumentParser(
        description="Compare intermediate logging outputs from two different runs."
    )
    parser.add_argument(
        "--dir1", required=True, help="First intermediate logging directory"
    )
    parser.add_argument(
        "--dir2", required=True, help="Second intermediate logging directory"
    )
    parser.add_argument("--output", help="Output file for the report (default: stdout)")
    parser.add_argument(
        "--rtol",
        type=float,
        default=None,
        help="Relative tolerance for tensor comparison (default: 1e-5)",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=None,
        help="Absolute tolerance for tensor comparison (default: 1e-8)",
    )
    parser.add_argument(
        "--steps", help="Comma-separated list of steps to compare (default: all)"
    )
    parser.add_argument(
        "--modules",
        help="Comma-separated list of module name patterns to compare (default: all)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Include detailed information about each tensor",
    )

    args = parser.parse_args()

    # Parse steps and modules
    steps = args.steps.split(",") if args.steps else None
    module_patterns = args.modules.split(",") if args.modules else None

    # Compare directories
    report = compare_directories(
        Path(args.dir1),
        Path(args.dir2),
        rtol=args.rtol,
        atol=args.atol,
        steps=steps,
        module_patterns=module_patterns,
    )

    # Generate report
    output = generate_markdown_report(report, verbose=args.verbose)

    # Write report
    if args.output:
        with open(args.output, "w") as f:
            f.write(output)
        print(f"Report written to {args.output}")
    else:
        print(output)


if __name__ == "__main__":
    main()


def invoke_main() -> None:
    main()
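The `compare_tensors` helper invoked above is not shown in this diff, but the report fields it fills in (`max_abs_diff`, `max_rel_diff`, `num_diff_elements`, `total_elements`) suggest an `allclose`-style rule. A minimal standalone sketch of that rule, assuming the usual `|a - b| <= atol + rtol * |b|` criterion (`compare_lists` is a hypothetical name, not vLLM API):

```python
def compare_lists(a, b, rtol=1e-5, atol=1e-8):
    """Element-wise allclose-style check: |x - y| <= atol + rtol * |y|."""
    assert len(a) == len(b)
    abs_diffs = [abs(x - y) for x, y in zip(a, b)]
    rel_diffs = [d / abs(y) if y != 0 else (float("inf") if d else 0.0)
                 for d, y in zip(abs_diffs, b)]
    mismatched = [d > atol + rtol * abs(y) for d, y in zip(abs_diffs, b)]
    return {
        "match": not any(mismatched),
        "max_abs_diff": max(abs_diffs),
        "max_rel_diff": max(rel_diffs),
        "num_diff_elements": sum(mismatched),
        "total_elements": len(a),
    }
```

With these semantics, a tiny perturbation (e.g. `1e-9` on a value of `2.0`) still "matches", while a real accuracy gap shows up in `num_diff_elements`, which is what makes the per-module report useful for localizing divergence between two runs.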
125
vllm/config.py
@@ -17,7 +17,8 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
 from functools import cached_property
 from importlib.util import find_spec
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
-                    Protocol, TypeVar, Union, cast, get_args)
+                    Protocol, TypeVar, Union, cast, get_args, List, Set)
+from re import Pattern

 import regex as re
 import torch
@@ -4024,6 +4025,122 @@ class KVEventsConfig:
    """


@config
@dataclass
class IntermediateLoggingConfig:
    """Configuration for intermediate tensor logging."""

    output_dir: str = "/tmp/vllm_intermediates"
    """Directory where to save the intermediate tensors."""

    reload_input_dir: Optional[str] = None
    """Directory from which to reload the inputs for the steps/modules.
    This is used when we want to check per-module numerical gaps instead
    of the accumulated gap, to further dive into the actual numerical
    issues."""

    module_call_match: Optional[List[str]] = None
    """Match modules by name regex and call index (a module can be called
    multiple times in a step). List of "regex:call_idx" strings; call_idx
    defaults to -1, which matches all calls."""

    log_step_ids: List[int] = field(default_factory=lambda: [0])
    """List of step IDs to log (empty list means log all steps)."""

    log_post_fwd_inputs: bool = False
    """Whether to log inputs after the forward pass of each module."""

    max_tensor_size: Optional[int] = None
    """Maximum number of elements in tensors to log (None = no limit)."""

    enabled: bool = True
    """Whether logging is enabled."""

    device_names: List[str] = field(default_factory=list)
    """List of device names to log (empty list means log all devices)."""

    _compiled_module_calls: dict[Pattern, int] = field(default_factory=dict,
                                                       init=False)
    """Compiled regex patterns (mapped to call index) for module filtering."""

    _module_call: dict[str, int] = field(default_factory=dict, init=False)

    _step_id_set: Set[int] = field(default_factory=set, init=False)
    """Set of step IDs for faster lookup."""

    _output_run_dir: str = "/tmp/vllm_intermediates"
    """Unique directory to save a single run/serve logging result."""

    def __post_init__(self):
        """Initialize derived fields after instance creation."""
        self._compile_regex_patterns()
        self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
        self._step_id_set = set(self.log_step_ids)

    def _compile_regex_patterns(self):
        """Compile regex patterns for module name filtering."""
        from vllm.logger import init_logger
        logger = init_logger(__name__)

        self._compiled_module_calls = {}

        if self.module_call_match is None:
            logger.info(
                "No module name regex patterns provided, will log all modules")
            return

        # Compile all patterns
        for regex_pattern_call_idx in self.module_call_match:
            try:
                splits = regex_pattern_call_idx.split(":", 2)
                regex_pattern = splits[0]
                call_idx = -1
                if len(splits) > 1:
                    call_idx = int(splits[1])
                compiled_pattern: Pattern[str] = re.compile(regex_pattern)
                self._compiled_module_calls[compiled_pattern] = call_idx
                logger.info("Successfully compiled regex pattern: '%s'",
                            regex_pattern)
            except Exception as e:
                logger.error("Failed to parse module_call_match '%s': %s",
                             regex_pattern_call_idx, e)
                raise ValueError(
                    f"Failed to parse module_call_match "
                    f"'{regex_pattern_call_idx}': {e}") from e

        logger.info("Compiled %d regex patterns",
                    len(self._compiled_module_calls))

    def to_dict(self) -> dict:
        """Convert the config to a dictionary for serialization."""
        return {
            "output_run_dir": self.output_run_dir,
            "module_call_match": self.module_call_match,
            "log_step_ids": self.log_step_ids,
            "max_tensor_size": self.max_tensor_size,
            "enabled": self.enabled,
            "device_names": self.device_names,
        }

    @classmethod
    def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
        """Parse the CLI value for the intermediate logging config."""
        return cls(**dict_value)

    @property
    def output_run_dir(self) -> str:
        return self._output_run_dir

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Intermediate logging doesn't affect the computation graph
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str


class CompilationLevel:
    # constants for the levels of the compilation process
    NO_COMPILATION = 0
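The `module_call_match` entries above are parsed as `"regex[:call_idx]"` pairs, where a missing index means "all calls" (`-1`). A standalone sketch of that parsing step, mirroring `_compile_regex_patterns` (the function name here is illustrative, not vLLM API):

```python
import re

def parse_module_call_match(patterns):
    """Parse 'regex[:call_idx]' strings into {compiled_pattern: call_idx}."""
    compiled = {}
    for entry in patterns:
        splits = entry.split(":", 2)
        regex_pattern = splits[0]
        # call_idx defaults to -1, meaning: log every call of the module
        call_idx = int(splits[1]) if len(splits) > 1 else -1
        compiled[re.compile(regex_pattern)] = call_idx
    return compiled
```

For example, `parse_module_call_match([r"model\.layers\.0\..*", "lm_head:1"])` would log every call of the layer-0 submodules but only the second (index 1) call of `lm_head`.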
@@ -4480,6 +4597,8 @@ class VllmConfig:
    """The configurations for distributed KV cache transfer."""
    kv_events_config: Optional[KVEventsConfig] = None
    """The configurations for event publishing."""
    intermediate_log_config: Optional[IntermediateLoggingConfig] = None
    """Configuration for intermediate tensor logging."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
@@ -4564,6 +4683,10 @@
            vllm_factors.append(self.kv_transfer_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.intermediate_log_config:
            vllm_factors.append(self.intermediate_log_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.additional_config:
            if isinstance(additional_config := self.additional_config, dict):
                additional_config_hash = hashlib.md5(
@@ -26,7 +26,8 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DecodingConfig, DetailedTraceModules, Device,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, GuidedDecodingBackendV1,
-                         HfOverrides, KVEventsConfig, KVTransferConfig,
+                         HfOverrides, IntermediateLoggingConfig,
+                         KVEventsConfig, KVTransferConfig,
                          LoadConfig, LogprobsMode, LoRAConfig, ModelConfig,
                          ModelDType, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
@@ -399,6 +400,7 @@ class EngineArgs:
                             str] = ModelConfig.logits_processor_pattern

    speculative_config: Optional[Dict[str, Any]] = None

    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
@@ -444,6 +446,9 @@ class EngineArgs:
    async_scheduling: bool = SchedulerConfig.async_scheduling
    # DEPRECATED
    enable_prompt_adapter: bool = False

    intermediate_log_config_path: Optional[str] = None
    intermediate_log_config: Optional[dict[str, Any]] = None

    def __post_init__(self):
        # support `EngineArgs(compilation_config={...})`
@@ -758,6 +763,20 @@
                            help="The configurations for speculative decoding. Should be a "
                            "JSON string.")

        intermediate_log_group = parser.add_argument_group(
            title="IntermediateLoggingConfig",
            description=IntermediateLoggingConfig.__doc__,
        )
        intermediate_log_group.add_argument(
            "--intermediate-log-config",
            type=json.loads,
            default=None,
            help="The configurations for intermediate logging. Should be a "
            "JSON string.")
        intermediate_log_group.add_argument(
            "--intermediate-log-config-path",
            type=str,
            help="The path to the configuration file for intermediate "
            "logging. Should be a string.")

        # Observability arguments
        observability_kwargs = get_kwargs(ObservabilityConfig)
        observability_group = parser.add_argument_group(
@@ -846,6 +865,9 @@
        vllm_group.add_argument("--additional-config",
                                **vllm_kwargs["additional_config"])

        # Other arguments
        parser.add_argument('--disable-log-stats',
                            action='store_true',
@@ -957,6 +979,21 @@
            use_tqdm_on_load=self.use_tqdm_on_load,
            pt_load_map_location=self.pt_load_map_location,
        )

    def create_intermediate_log_config(
        self,
    ) -> Optional[IntermediateLoggingConfig]:
        """Initializes and returns an IntermediateLoggingConfig object based on
        `intermediate_log_config` or `intermediate_log_config_path`.
        """
        if self.intermediate_log_config is not None:
            return IntermediateLoggingConfig.from_dict(
                self.intermediate_log_config)
        if self.intermediate_log_config_path is not None:
            with open(self.intermediate_log_config_path, "r") as f:
                return IntermediateLoggingConfig.from_dict(json.load(f))
        return None

    def create_speculative_config(
        self,
@@ -1198,6 +1235,9 @@
            disable_log_stats=self.disable_log_stats,
        )

        intermediate_log_config = self.create_intermediate_log_config()

        # Reminder: Please update docs/features/compatibility_matrix.md
        # If the feature combo become valid
        if self.num_scheduler_steps > 1:
@@ -1284,7 +1324,6 @@
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_detailed_traces=self.collect_detailed_traces,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
@@ -1299,6 +1338,7 @@
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
            intermediate_log_config=intermediate_log_config,
            additional_config=self.additional_config,
        )
@@ -77,6 +77,10 @@ class EngineCore:

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if vllm_config.intermediate_log_config is not None:
            self.collective_rpc("register_intermediate_hooks",
                                args=(vllm_config.intermediate_log_config, ))

        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)
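Tying the pieces together: `--intermediate-log-config` accepts a JSON object whose keys mirror the `IntermediateLoggingConfig` fields added in this commit. A sketch of building such a payload (the field values here are illustrative, not recommended defaults):

```python
import json

# Keys mirror IntermediateLoggingConfig fields from the diff above.
cfg = {
    "enabled": True,
    "output_dir": "/tmp/vllm_intermediates",
    "log_step_ids": [0, 1],                             # log only steps 0 and 1
    "module_call_match": [r"model\.layers\.0\..*:-1"],  # layer 0, all calls
    "device_names": ["cuda:0"],                         # skip other devices
    "max_tensor_size": 10_000_000,                      # metadata only above this
}
payload = json.dumps(cfg)
# Pass on the command line as: --intermediate-log-config '<payload>'
```

On the engine side, `create_intermediate_log_config` feeds this dict to `IntermediateLoggingConfig.from_dict`, with the inline JSON taking precedence over `--intermediate-log-config-path` when both are given.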
0
vllm/v1/intermediates/__init__.py
Normal file
599
vllm/v1/intermediates/intermediates_logging.py
Normal file
@@ -0,0 +1,599 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.

This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""

import json
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional

import torch
from torch.utils.hooks import RemovableHandle

from vllm.config import IntermediateLoggingConfig

# Import logger from vllm
from vllm.logger import init_logger

logger = init_logger(__name__)

# Global step counter
_CURRENT_STEP = 0

_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}

IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"

# Utility functions for intermediate logging


def should_log_step(config):
    """Check if the current step should be logged based on the step IDs.

    Args:
        config: The IntermediateLoggingConfig instance.

    Returns:
        True if the current step should be logged, False otherwise.
    """
    if not is_log_enabled(config):
        return False

    # If log_step_ids is empty, log all steps
    if not config.log_step_ids:
        return True

    # Otherwise, check if current step is in the set of step IDs to log
    return get_step() in config._step_id_set
def should_log_device(config, device_name):
    """Check if a device should be logged based on the device names.

    Args:
        config: The IntermediateLoggingConfig instance.
        device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').

    Returns:
        True if the device should be logged, False otherwise.
        If device_names is empty, all devices are logged.
    """
    if not is_log_enabled(config):
        return False
    # If device_names is empty, log all devices
    if not config.device_names:
        return True

    # Otherwise, check if device_name is in the list of device names to log
    return device_name in config.device_names


def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
    """Check if a module should be logged based on the name regex patterns.

    Args:
        config: The IntermediateLoggingConfig instance.
        module_name: The name of the module to check.
        module: The module instance, used to attach logging attributes.

    Returns:
        True if the module should be logged, False otherwise.
        If no patterns are defined, all modules are logged.
        If patterns are defined, the module is logged if it matches ANY pattern.
    """
    if not is_log_enabled(config):
        return False
    # If no patterns are defined, log all modules
    if not config._compiled_module_calls:
        logger.debug("No patterns defined, will log module: %s", module_name)
        set_il_module_name(module, module_name)
        set_il_module_call_idx(module, -1)
        return True

    # Check if the module name matches any of the patterns
    for pattern, call_idx in config._compiled_module_calls.items():
        match = pattern.search(module_name)
        if match:
            logger.debug(
                "Module %s, %s matches pattern: '%s', call_idx=%s",
                module_name,
                module.__class__.__name__,
                pattern.pattern,
                call_idx,
            )
            set_il_module_name(module, module_name)
            set_il_module_call_idx(module, call_idx)
            return True
    return False


def is_log_enabled(config):
    if not config or not config.enabled:
        logger.debug("Not logging because config not enabled")
        return False
    if torch.compiler.is_compiling():
        logger.debug("Not logging because torch.compile is in progress")
        return False
    return True


def get_il_module_name(module: torch.nn.Module) -> str:
    return getattr(module, IL_MODULE_NAME, module.__class__.__name__)


def get_il_module_call_idx(module: torch.nn.Module) -> int:
    return getattr(module, IL_MODULE_CALL_IDX, -1)


def set_il_module_name(module: torch.nn.Module, name: str) -> None:
    setattr(module, IL_MODULE_NAME, name)


def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
    setattr(module, IL_MODULE_CALL_IDX, idx)
_global_config: Optional[IntermediateLoggingConfig] = None


@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
    """
    Temporarily sets the global config for the duration of the context.

    :param config: The IntermediateLoggingConfig to set as the global config.
    """
    global _global_config
    old_config = _global_config
    try:
        _global_config = config
        yield
    finally:
        _global_config = old_config


def get_current_il_config():
    return _global_config


def dump_intermediates_to_json(intermediates: Any, path: Path) -> Any:
    try:
        # Convert intermediates to JSON-serializable format
        intermediates_json = convert_intermediates_to_json(intermediates)
        with open(path, "w") as f:
            json.dump(intermediates_json, f, indent=2)
        logger.debug("Saved all intermediates as JSON to %s", path)
    except Exception as e:
        logger.warning("Failed to save intermediates as JSON: %s", e)
        import traceback

        logger.warning(traceback.format_exc())
def convert_intermediates_to_json(tensor: Any) -> Any:
    """Convert an intermediates value (including tensors) to a
    JSON-serializable representation.

    Args:
        tensor: The intermediates value to convert.

    Returns:
        A JSON-serializable representation of the value.
    """
    if isinstance(tensor, torch.Tensor):
        try:
            result = {
                "type": "tensor",
                "shape": list(tensor.shape),
                "dtype": str(tensor.dtype),
                "numel": tensor.numel(),
            }

            return result
        except Exception as e:
            # Handle any errors in tensor conversion
            return {
                "type": "tensor_error",
                "error": str(e),
                "tensor_type": str(type(tensor)),
            }

    elif isinstance(tensor, (list, tuple)):
        # For lists/tuples, recursively convert each element
        container_type = "list" if isinstance(tensor, list) else "tuple"

        # If it's a large list, only include a sample
        if len(tensor) > 20:
            return {
                "type": container_type,
                "length": len(tensor),
                "sample": [
                    convert_intermediates_to_json(item) for item in tensor[:20]
                ],
                "note": f"Showing only first 20 of {len(tensor)} items",
            }
        else:
            return {
                "type": container_type,
                "items": [convert_intermediates_to_json(item) for item in tensor],
            }

    elif isinstance(tensor, dict):
        # For dictionaries, recursively convert each value
        if len(tensor) > 20:
            # For large dicts, only include keys and a sample of values
            keys = list(tensor.keys())
            sample_keys = keys[:20]
            return {
                "type": "dict",
                "length": len(tensor),
                "keys": keys,
                "sample": {
                    k: convert_intermediates_to_json(tensor[k]) for k in sample_keys
                },
                "note": f"Showing only first 20 of {len(tensor)} items",
            }
        else:
            return {
                "type": "dict",
                "items": {
                    k: convert_intermediates_to_json(v) for k, v in tensor.items()
                },
            }

    elif tensor is None:
        return None

    elif isinstance(tensor, (int, float, bool, str)):
        # Primitive types can be directly serialized
        return tensor

    else:
        # For other types, use string representation
        return {"type": str(type(tensor).__name__), "string_repr": str(tensor)}
def save_tensors_metadata_if_too_large(tensor: torch.Tensor, file_path: str) -> bool:
    """Utility function to dump tensor metadata to a file.

    Args:
        tensor: The tensor to dump.
        file_path: Base path where to save the tensor (without extension).

    Returns:
        True if the tensor exceeded ``max_tensor_size`` and only its
        metadata was saved, False otherwise.
    """
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None:
        return False
    if (
        intermediate_log_config.max_tensor_size is not None
        and tensor.numel() > intermediate_log_config.max_tensor_size
    ):
        # Save tensor metadata instead of the full tensor
        tensor_info = {
            "shape": list(tensor.shape),
            "dtype": str(tensor.dtype),
            "device": str(tensor.device),
            "numel": tensor.numel(),
            "skipped": f"Tensor size {tensor.numel()} exceeds max_tensor_size "
            f"{intermediate_log_config.max_tensor_size}",
        }
        os.makedirs(os.path.dirname(f"{file_path}.json"), exist_ok=True)
        with open(f"{file_path}.json", "w") as f:
            json.dump(tensor_info, f, indent=2)
        return True
    return False

def safe_reload_tensor(save_path: str, tensor: Any, reload_dir: Optional[str]) -> Any:
    """Reload a previously saved tensor from ``reload_dir``, if configured.

    Returns None when reloading is disabled, the reloaded tensor on
    success, and the original tensor on failure.
    """
    if reload_dir is None:
        return None
    try:
        intermediate_log_config = get_current_il_config()
        assert intermediate_log_config is not None
        replace_dir = str(intermediate_log_config.output_run_dir)
        reload_path = save_path.replace(replace_dir, reload_dir)
        logger.debug("reload tensor of shape %s from %s", tensor.shape, reload_path)
        return torch.load(reload_path, map_location=tensor.device)
    except Exception as e:
        logger.warning("Failed to load tensor from %s: %s", reload_dir, e)
        return tensor

def save_tensors(
    tensor: Any, file_path: str, reload_input_dir: Optional[str] = None
) -> Any:
    """Utility function to dump tensors to files.

    Args:
        tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
            tensors, or a dictionary containing tensors.
        file_path: Base path where to save the tensor (without extension).
        reload_input_dir: If set, tensors are reloaded from this directory
            (mirroring the output layout) and returned in place of the
            originals.
    """
    # Also save the actual tensor data for tensors
    if isinstance(tensor, torch.Tensor):
        # Check if tensor is too large
        if save_tensors_metadata_if_too_large(tensor, file_path):
            return tensor
        # Get device name
        device_name = str(tensor.device)
        # Skip if device filtering is enabled and this device should not be
        # logged
        intermediate_log_config = get_current_il_config()
        if not should_log_device(intermediate_log_config, device_name):
            logger.debug(
                "Skipping tensor on device %s due to device filter", device_name
            )
            return tensor
        # Append device name to file path
        pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
        try:
            # Save tensor directly without detaching or moving to CPU
            torch.save(tensor, pt_path)
            reloaded_tensor = safe_reload_tensor(pt_path, tensor, reload_input_dir)
            if reloaded_tensor is not None:
                return reloaded_tensor
            logger.debug("Saved tensor of shape %s to %s", tensor.shape, pt_path)
        except Exception as e:
            logger.warning("Failed to save tensor to %s: %s", pt_path, e)
        return tensor

    if isinstance(tensor, (list, tuple)):
        # For collections, also save each item individually
        reloaded_inputs = []
        for i, item in enumerate(tensor):
            reloaded = save_tensors(item, f"{file_path}_{i}", reload_input_dir)
            reloaded_inputs.append(reloaded)
        return tuple(reloaded_inputs) if reloaded_inputs else tensor
    if isinstance(tensor, dict):
        reloaded_inputs = {}
        # For dictionaries, also save each value individually
        for k, v in tensor.items():
            reloaded = save_tensors(v, f"{file_path}_{k}", reload_input_dir)
            reloaded_inputs[k] = reloaded
        return reloaded_inputs if reloaded_inputs else tensor

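The recursive descent above derives one file path per leaf tensor by appending list indices and dict keys to the base path. A pure-Python sketch of just that naming scheme (the `tensor_paths` helper is hypothetical, for illustration only):

```python
def tensor_paths(obj, base):
    # Reproduce the save_tensors naming: lists/tuples append "_<index>",
    # dicts append "_<key>", and leaves map to a single base path.
    if isinstance(obj, (list, tuple)):
        paths = []
        for i, item in enumerate(obj):
            paths += tensor_paths(item, f"{base}_{i}")
        return paths
    if isinstance(obj, dict):
        paths = []
        for k, v in obj.items():
            paths += tensor_paths(v, f"{base}_{k}")
        return paths
    return [base]

paths = tensor_paths({"q": ["a", "b"], "k": "c"}, "inputs")
```

On disk, each leaf path then gets a device suffix and `.pt` extension (e.g. `inputs_q_0_cuda_0.pt`).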
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any) -> None:
    """Hook to increment the global step counter after a forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    if get_current_il_config() is None:
        return
    # Increment the global step counter
    increment_step()
    global _CURRENT_STEP_MODULE_CALL_STEP
    _CURRENT_STEP_MODULE_CALL_STEP = {}

def _prepare_module_log_dir(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
    is_pre_fwd: bool = False,
) -> Path:
    # Create a unique directory for this step if it does not exist yet
    dump_dir = Path(intermediate_log_config.output_run_dir) / f"step_{get_step()}"
    dump_dir.mkdir(exist_ok=True, parents=True)

    # Create module directory
    suffix = ""
    module_call_idx = get_current_step_module_call(module_name)
    if module_call_idx > 0:
        suffix = f"_{module_call_idx}"
    module_dir = dump_dir / (module_name + suffix)
    if is_pre_fwd:
        _log_module_call(intermediate_log_config, module_name + suffix)
    module_dir.mkdir(exist_ok=True, parents=True)
    logger.debug("Logging module %s inputs/outputs to %s", module_name, module_dir)
    return module_dir

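The resulting on-disk layout is `<output_run_dir>/step_<N>/<module_name>[_<call_idx>]`. A standalone sketch of that path construction (the `module_log_dir` helper name and POSIX paths are assumptions for illustration):

```python
from pathlib import Path

def module_log_dir(root: str, step: int, module_name: str, call_idx: int) -> Path:
    # First call of a module in a step gets no suffix; repeats get "_<idx>".
    suffix = f"_{call_idx}" if call_idx > 0 else ""
    return Path(root) / f"step_{step}" / (module_name + suffix)

p = module_log_dir("/tmp/il", 0, "model.layers.0.mlp", 1)
```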
def _log_module_call(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
) -> None:
    logger.debug("Logging module call for %s", module_name)
    # Append the module name to the call log for the current step
    file = (
        Path(intermediate_log_config.output_run_dir)
        / f"step_{get_step()}"
        / "module_calls.txt"
    )
    with open(file, "a") as f:
        f.write(f"{module_name}\n")

def update_current_step_module_call(module_name: str) -> None:
    logger.debug("Updating current step module call for %s", module_name)
    global _CURRENT_STEP_MODULE_CALL_STEP
    if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
        _CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
    else:
        _CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1


def get_current_step_module_call(module_name: str) -> int:
    return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)

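The per-step call counter above assigns index 0 to a module's first forward call within a step and increments on repeats; `step_fwd` clears the table at the end of each step. A minimal standalone sketch of the same counting behavior (names are hypothetical):

```python
_calls: dict = {}

def update_call(name: str) -> None:
    # First call in a step -> 0; each repeat increments the index.
    _calls[name] = _calls.get(name, -1) + 1

def get_call(name: str) -> int:
    return _calls.get(name, 0)

update_call("model.layers.0.mlp")  # first call this step -> index 0
update_call("model.layers.0.mlp")  # second call -> index 1
```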
def prepare_log_current_fwd(module, is_pre_fwd: bool = False) -> Optional[Path]:
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None or not intermediate_log_config.enabled:
        return None
    if not should_log_step(intermediate_log_config):
        return None

    module_name = get_il_module_name(module)
    log_call_idx = get_il_module_call_idx(module)
    current_call_idx = get_current_step_module_call(module_name)
    # A non-negative log_call_idx restricts logging to that call index only.
    should_log = log_call_idx < 0 or current_call_idx == log_call_idx

    log_dir = None
    if is_pre_fwd:
        update_current_step_module_call(module_name)
    if should_log:
        log_dir = _prepare_module_log_dir(
            intermediate_log_config, module_name, is_pre_fwd=is_pre_fwd
        )
    return log_dir

def log_pre_fwd_hook(
    module: torch.nn.Module, inputs: tuple[Any, ...]
) -> tuple[Any, ...]:
    """Hook to capture module inputs before forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.

    Returns:
        The inputs, possibly replaced by tensors reloaded from
        ``reload_input_dir``.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
        dump_intermediates_to_json(inputs, log_dir / "inputs.json")
        intermediate_log_config = get_current_il_config()
        reload_input_dir = None
        if intermediate_log_config is not None:
            reload_input_dir = getattr(
                intermediate_log_config, "reload_input_dir", None
            )
        reloaded_inputs = save_tensors(
            inputs, str(log_dir / "inputs"), reload_input_dir
        )
        if reloaded_inputs is not None:
            return reloaded_inputs
    return inputs

def log_post_fwd_hook(
    module: torch.nn.Module, inputs: tuple[Any, ...], outputs: Any
) -> None:
    """Hook to capture module outputs after forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
        dump_intermediates_to_json(outputs, log_dir / "outputs.json")
        save_tensors(outputs, str(log_dir / "outputs"))
        intermediate_log_config = get_current_il_config()
        assert intermediate_log_config is not None, "IL config should not be None"
        if intermediate_log_config.log_post_fwd_inputs:
            dump_intermediates_to_json(inputs, log_dir / "post_fwd_inputs.json")
            save_tensors(inputs, str(log_dir / "post_fwd_inputs"))

def get_step() -> int:
    """Get the current global step counter.

    Returns:
        The current global step counter.
    """
    return _CURRENT_STEP


def increment_step() -> int:
    """Increment the global step counter.

    Returns:
        The new step counter value.
    """
    global _CURRENT_STEP
    _CURRENT_STEP += 1
    return _CURRENT_STEP


def reset_step() -> None:
    """Reset the global step counter to zero."""
    global _CURRENT_STEP
    _CURRENT_STEP = 0

class IntermediatesLogger:
    """Class to manage logging of intermediate tensors during model
    execution."""

    def __init__(self, config: IntermediateLoggingConfig):
        self.config = config
        self.hooks: list[
            tuple[
                str,
                torch.nn.Module,
                Optional[RemovableHandle],
                Optional[RemovableHandle],
            ]
        ] = []
        logger.debug("Created IntermediatesLogger with config: %s", config)
        path = Path(config.output_run_dir)
        path.mkdir(exist_ok=True, parents=True)
        # Log configuration
        logger.info("Intermediates will be logged in %s", config.output_run_dir)

    def register_hooks(self, model: torch.nn.Module) -> None:
        """Register hooks for the model.

        Args:
            model: The PyTorch model to register hooks for.
        """
        for name, module in model.named_modules():
            if name and should_log_module(self.config, name, module):
                pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
                logger.debug(
                    "Registered pre_fwd hook for %s", module.__class__.__name__
                )
                post_hook = module.register_forward_hook(log_post_fwd_hook)
                logger.debug(
                    "Registered post_fwd hook for %s", module.__class__.__name__
                )
                self.hooks.append((name, module, pre_hook, post_hook))

        # Register a step counter hook for the root model
        step_hook = model.register_forward_hook(step_fwd)
        self.hooks.append(("", model, None, step_hook))
        logger.info("Registered hooks for %s modules", len(self.hooks))

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for _, _, pre_hook, post_hook in self.hooks:
            if pre_hook is not None:
                pre_hook.remove()
            if post_hook is not None:
                post_hook.remove()

        logger.info("Removed %s hooks", len(self.hooks))
        self.hooks = []

def register_intermediate_hooks(
    model: torch.nn.Module, config: Optional[IntermediateLoggingConfig] = None, **kwargs
) -> IntermediatesLogger:
    """Register hooks to log intermediate tensors for a model.

    Args:
        model: The PyTorch model to log intermediates for.
        config: Configuration for intermediate logging. If provided, this takes
            precedence over kwargs.

    Returns:
        An IntermediatesLogger instance that can be used to manage the hooks.
    """
    if config is None:
        # Create config from kwargs
        config = IntermediateLoggingConfig.from_dict(kwargs)

    logger_instance = IntermediatesLogger(config)
    logger_instance.register_hooks(model)
    return logger_instance
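The kwargs fallback above builds the config via `IntermediateLoggingConfig.from_dict`. A standalone sketch of that dict-to-config pattern, using a stand-in dataclass (`ILConfig` and its fields are assumptions for illustration, not the real vLLM config):

```python
from dataclasses import dataclass

@dataclass
class ILConfig:
    enabled: bool = True
    output_run_dir: str = "/tmp/vllm_intermediates"

    @classmethod
    def from_dict(cls, d: dict) -> "ILConfig":
        # Keep only keys that are actual config fields; ignore the rest.
        known = {k: v for k, v in d.items() if k in cls.__dataclass_fields__}
        return cls(**known)

cfg = ILConfig.from_dict({"enabled": False, "unknown_key": 1})
```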
@@ -32,6 +32,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase
+from vllm.v1.intermediates.intermediates_logging import intermediate_logging

 logger = init_logger(__name__)

@@ -344,8 +345,9 @@ class Worker(WorkerBase):
             get_pp_group().recv_tensor_dict(
                 all_gather_group=get_tp_group()))

-        output = self.model_runner.execute_model(scheduler_output,
-                                                 intermediate_tensors)
+        with intermediate_logging(self.vllm_config.intermediate_log_config):
+            output = self.model_runner.execute_model(scheduler_output,
+                                                     intermediate_tensors)

         parallel_config = self.vllm_config.parallel_config
         if parallel_config.distributed_executor_backend != "external_launcher" \

@@ -6,9 +6,10 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, IntermediateLoggingConfig
 from vllm.logger import init_logger
 from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.v1.intermediates.intermediates_logging import register_intermediate_hooks
 from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

 logger = init_logger(__name__)

@@ -63,3 +64,27 @@ class WorkerBase(WorkerBaseV0):
     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""
         return
+
+    def register_intermediate_hooks(self,
+                                    config: Optional[IntermediateLoggingConfig] = None,
+                                    **kwargs) -> None:
+        """Register hooks for intermediate tensor logging.
+
+        This method is called via collective_rpc from the engine core.
+        It registers hooks on the model to dump intermediate tensors
+        during execution.
+
+        Args:
+            config: Configuration for intermediate logging. If provided,
+                this takes precedence over kwargs.
+        """
+        if (self.model_runner is None
+                or not hasattr(self.model_runner, "model")
+                or self.model_runner.model is None):
+            logger.error("Could not register intermediate hooks: "
+                         "model_runner.model is not accessible")
+            return
+        model = self.model_runner.model
+        try:
+            # Register hooks; keep the logger instance for later hook removal.
+            register_intermediate_hooks(model, config, **kwargs)
+            logger.info("Successfully registered intermediate hooks")
+        except Exception:
+            logger.error("Error registering intermediate hooks", exc_info=True)

@@ -128,6 +128,22 @@ class WorkerBase:
     def vocab_size(self) -> int:
         """Get vocabulary size from model configuration."""
         return self.model_config.get_vocab_size()

+    def register_intermediate_hooks(self, config=None, **kwargs) -> None:
+        """Register hooks for intermediate tensor logging.
+
+        This is a stub for v0 workers; the actual implementation lives in
+        v1 workers. It is included here for compatibility with the
+        collective_rpc mechanism.
+
+        Args:
+            config: Configuration for intermediate logging.
+            **kwargs: Configuration parameters for intermediate logging.
+                These are ignored in v0 workers.
+        """
+        logger.warning(
+            "register_intermediate_hooks is not implemented in v0 workers. "
+            "This is only available in v1 workers. No hooks will be registered.")
+        return None

 class DelegateWorkerBase(WorkerBase):