LiuXiaoxuanPKU c335930d75 benchmark
2025-05-02 09:23:30 -07:00

58 lines
1.9 KiB
Python

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from transformers import AutoTokenizer
# Hugging Face checkpoint name for each short model alias used in data paths.
MODEL_TO_NAMES = {
    "r1-distill-llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}

# Which run to visualise: model alias, method name, and dataset name.
model = "r1-distill-llama-8B"
method = "ngram"
dataset = "aime"

# JSONL file holding one acceptance-statistics record per request.
datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"

# Slow (non-fast) tokenizer for the chosen model; load_data uses it to print
# the decoded prompt and output of every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model], use_fast=False)
def cleanup(data):
    """Drop the prefill entry and clamp the remaining values at 10.

    Args:
        data: Sliceable sequence of numeric values (the ``"acc"`` field of
            one record); its first element is the prefill phase.

    Returns:
        A new list containing ``data[1:]`` with every value capped at 10.
    """
    decode_steps = data[1:]  # index 0 is the prefill phase, not a decode step
    return [min(count, 10) for count in decode_steps]
def load_data(datapath):
    """Load per-request acceptance statistics from a JSONL file.

    Each non-blank line must be a JSON object with an ``"acc"`` list, which is
    passed through ``cleanup``. It must also have ``"prompt_token_ids"`` and
    ``"generated_token_ids"``, which are decoded with the module-level
    ``tokenizer`` and printed for inspection.

    Every row is right-padded with ``-2`` to the length of the longest row, so
    the result is rectangular and can be plotted directly as a heatmap.

    Args:
        datapath: Path to the JSONL file.

    Returns:
        A list of equal-length lists, one per record. The list is empty when
        the file contains no records.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected keys.
    """
    acceptance_stats = []
    with open(datapath, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. stray newlines) instead of failing in
            # json.loads.
            if not line.strip():
                continue
            data = json.loads(line)
            acceptance_stats.append(cleanup(data['acc']))
            print("Input:", tokenizer.decode(data['prompt_token_ids']))
            print("Output:", tokenizer.decode(data['generated_token_ids']))
            print("=============================================")
    # Pad the acceptance stats to the same length. default=0 keeps an empty
    # file from raising ValueError inside max().
    max_length = max((len(stats) for stats in acceptance_stats), default=0)
    for stats in acceptance_stats:
        stats.extend([-2] * (max_length - len(stats)))
    print(f"Load {len(acceptance_stats)} with max length {max_length}")
    return acceptance_stats
acceptance_stats = load_data(datapath)

# load_data right-pads shorter requests with -2. Mask those cells so they are
# left blank instead of being coloured like real data. This also keeps the -2
# sentinel from stretching the colour scale below the actual values.
stats_array = np.asarray(acceptance_stats)
padding_mask = stats_array == -2

fig, ax = plt.subplots()
sns.heatmap(
    stats_array,
    mask=padding_mask,
    cmap="YlGnBu",
    ax=ax,
    # Label the colour bar directly rather than faking a right-hand y-axis.
    cbar_kws={"label": "# of Accepted Tokens"},
)
ax.set_xlabel("Position")
ax.set_ylabel("Request ID")
fig.tight_layout()

# savefig does not create missing directories, so ensure figures/<model>/ exists.
output_path = Path("figures") / model / f"{method}_{dataset}_acceptance_stats.png"
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path)