# Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# Synced 2026-04-06 03:17:03 +08:00 (109 lines, 3.9 KiB, Python)
import json
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
from transformers import AutoTokenizer
|
|
from .common import MODEL_TO_NAMES, load_data
|
|
import requests
|
|
import os
|
|
from pathlib import Path
|
|
|
|
class AcceptanceStatsClient:
    """Client for loading, summarizing, and plotting acceptance statistics.

    The statistics are whatever the module-level ``load_data`` helper
    (imported from ``.common``) returns for ``data_path``. The methods below
    treat that value as a 2-D, row-per-request collection of numbers.
    """

    def __init__(self, model_name, method, dataset, data_path=None):
        """Store the run identifiers and load the tokenizer.

        Args:
            model_name: Key into ``MODEL_TO_NAMES``; also used in the default
                data path, the plot title, and the figure output directory.
            method: Method label used in file names and the plot title.
            dataset: Dataset label used in file names and the plot title.
            data_path: Path of the statistics file. When ``None``, a default
                ``.jsonl`` path is derived from the other three arguments.
        """
        self.model_name = model_name
        self.method = method
        self.dataset = dataset

        if data_path is None:
            self.data_path = f"/data/lily/batch-sd/data/{model_name}/{method}_{dataset}_acceptance_stats.jsonl"
        else:
            self.data_path = data_path

        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model_name], use_fast=False)
        # Populated lazily by load_data(); None means "not loaded yet".
        self.acceptance_stats = None

    def load_data(self):
        """Load the acceptance statistics from ``self.data_path``.

        Delegates to the module-level ``load_data`` helper (the method name
        does not shadow it, because method bodies resolve names in module
        scope), caches the result on the instance, and returns it.
        """
        self.acceptance_stats = load_data(self.data_path, self.tokenizer)
        return self.acceptance_stats

    def plot_heatmap(self, output_dir="figures"):
        """Plot the acceptance statistics as a heatmap and save it as a PDF.

        The figure is written to
        ``<output_dir>/<model_name>/<method>_<dataset>_acceptance_stats.pdf``;
        missing directories are created.

        Args:
            output_dir: Root directory for generated figures.

        Returns:
            The matplotlib figure. It is not closed here, so the caller owns it.
        """
        if self.acceptance_stats is None:
            self.load_data()

        fig, ax = plt.subplots(figsize=(12, 8))
        # Draw on the axes created above instead of relying on pyplot's
        # implicit "current axes".
        sns.heatmap(self.acceptance_stats, cmap="YlGnBu", ax=ax)
        plt.xlabel("Position")
        plt.ylabel("Request ID")

        # A tick-less twin axis is used only to carry a label on the right.
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_yticks([])
        ax2.set_ylabel("# of Accepted Tokens", labelpad=10)

        plt.title(f"Acceptance Statistics: {self.model_name} - {self.method} - {self.dataset}")
        plt.tight_layout()

        # Create the output directory if it doesn't exist.
        output_path = Path(output_dir) / self.model_name
        output_path.mkdir(parents=True, exist_ok=True)

        output_file = output_path / f"{self.method}_{self.dataset}_acceptance_stats.pdf"
        plt.savefig(output_file)
        print(f"Saved heatmap to {output_file}")
        return fig

    def get_summary_stats(self):
        """Return summary statistics about the acceptance data.

        Loads the data first if it has not been loaded yet.

        Returns:
            A dict with the keys:
              * ``total_requests``: number of rows.
              * ``max_position``: number of columns kept in ``avg_by_position``.
              * ``avg_acceptance_rate``: mean of ``avg_by_request``
                (``0.0`` when there are no rows).
              * ``avg_by_position``: per-column mean, for columns holding at
                least one non-negative value.
              * ``avg_by_request``: per-row mean (``0.0`` for an empty row).
        """
        if self.acceptance_stats is None:
            self.load_data()

        rows = self.acceptance_stats

        # zip(*rows) walks the columns and stops at the shortest row.
        # NOTE(review): columns are kept when they contain any non-negative
        # value, which suggests negatives may be padding; those negatives are
        # still included in the mean. Confirm against load_data in .common.
        avg_by_position = [
            sum(col) / len(col)
            for col in zip(*rows)
            if any(v >= 0 for v in col)
        ]

        # len() is used rather than truthiness so array-like rows work too;
        # an empty row contributes 0.0 instead of raising ZeroDivisionError.
        avg_by_request = [
            sum(row) / len(row) if len(row) else 0.0
            for row in rows
        ]

        # No rows means there is nothing to average; report 0.0.
        if avg_by_request:
            avg_acceptance_rate = sum(avg_by_request) / len(avg_by_request)
        else:
            avg_acceptance_rate = 0.0

        return {
            "total_requests": len(rows),
            "max_position": len(avg_by_position),
            "avg_acceptance_rate": avg_acceptance_rate,
            "avg_by_position": avg_by_position,
            "avg_by_request": avg_by_request,
        }
|
|
|
|
# Example run configuration, consumed by the ``__main__`` section below.
# Alternatives that have been used in place of the active values:
#   model:    "r1-distill-llama-8B"
#   method:   "ngram"
#   dataset:  "aime", "cnndailymail"
#   datapath: f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
model, method, dataset = "llama3.1-8B", "eagle3", "mtbench"
datapath = "acceptance_stats.jsonl"

# Module-level tokenizer for the configured model (loaded at import time).
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_TO_NAMES[model],
    use_fast=False,
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Go through the client rather than loading the data directly.
    stats_client = AcceptanceStatsClient(model, method, dataset, datapath)
    acceptance_stats = stats_client.load_data()

    # Report the headline numbers.
    summary = stats_client.get_summary_stats()
    print("Summary Statistics:")
    print(f"Total Requests: {summary['total_requests']}")
    print(f"Max Position: {summary['max_position']}")
    print(f"Average Acceptance Rate: {summary['avg_acceptance_rate']:.2f}")

    # Heatmap rendering is switched off by default; flip the flag to enable.
    plot_heatmap = False
    if plot_heatmap:
        stats_client.plot_heatmap()
|
|
|