2025-08-03 20:06:15 -07:00

64 lines
2.1 KiB
Python

import json
from dataclasses import dataclass
from typing import Optional
# Short model aliases -> full model identifiers.
# NOTE(review): the values look like Hugging Face Hub repo IDs; confirm with
# whatever code consumes this mapping.
MODEL_TO_NAMES = {
    "r1-distill-llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "llama3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama3.1-8B": "meta-llama/Llama-3.1-8B-Instruct",
    "llama3.1-70B": "meta-llama/Llama-3.1-70B-Instruct",
}
@dataclass
class AccStats:
    """Acceptance statistics for one record, trimmed into alignment.

    On construction, ``__post_init__`` first checks that any provided
    ``probs`` / ``entropies`` list is as long as ``lens``. It then trims:

    * ``lens`` loses its first element (the prefill accepted length);
    * ``probs`` and ``entropies`` lose their last element (the last
      proposed tokens).

    After trimming, every list that is not ``None`` has ``length`` elements.
    """

    # Accepted lengths (presumably one entry per decoding step — confirm).
    lens: list[int]
    # Optional per-entry probabilities; must match ``len(lens)`` on input.
    probs: Optional[list[float]] = None
    # Optional per-entry entropies; must match ``len(lens)`` on input.
    entropies: Optional[list[float]] = None

    def __post_init__(self):
        """Validate list lengths, then trim the fields into alignment.

        Raises:
            ValueError: If ``probs`` or ``entropies`` is given and its length
                differs from ``len(lens)``.
        """
        # Raise rather than assert so the check is not stripped by ``python -O``.
        if self.probs is not None and len(self.probs) != len(self.lens):
            raise ValueError("Length of lens and probs must match")
        if self.entropies is not None and len(self.entropies) != len(self.lens):
            raise ValueError("Length of lens and entropies must match")
        # remove the prefill accepted lens
        self.lens = self.lens[1:]
        # remove the last proposed tokens; slicing builds new lists, so the
        # caller's input lists are never mutated
        if self.probs is not None:
            self.probs = self.probs[:-1]
        if self.entropies is not None:
            self.entropies = self.entropies[:-1]

    @property
    def length(self):
        """Number of entries left in ``lens`` after trimming."""
        return len(self.lens)
# def cleanup(acc_stats: AccStats) ->
# # Remove the prefill phase
# data = data[1:]
# # Cap the maximum value to 10
# data = [min(x, 10) for x in data]
# return data
def load_data(datapath, tokenizer, verbose=False):
    """Load acceptance statistics from a JSON-lines file.

    Each non-blank line is parsed as a JSON object. Its ``acc`` mapping must
    contain ``acc_len`` and may contain ``acc_prob`` / ``acc_entropy``. These
    become the ``lens`` / ``probs`` / ``entropies`` of one ``AccStats``.

    Args:
        datapath: Path to the JSON-lines file (read as UTF-8).
        tokenizer: Object exposing ``decode(token_ids)``; only used when
            ``verbose`` is true.
        verbose: If true, print the decoded ``prompt_token_ids`` and
            ``generated_token_ids`` of every record.

    Returns:
        A list with one ``AccStats`` per record, in file order. A summary line
        with the record count and the largest ``length`` is always printed.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``acc`` / ``acc_len`` (or, when
            ``verbose`` is true, the token-id fields).
        Any error raised by ``AccStats`` (e.g. mismatched list lengths) is
        propagated unchanged.
    """
    acceptance_stats = []
    with open(datapath, "r", encoding="utf-8") as f:
        # Iterate lazily instead of readlines() so the file is not loaded twice
        # into memory (raw lines + parsed records).
        for line in f:
            # Skip blank lines (e.g. a trailing empty line) that json.loads
            # would reject.
            if not line.strip():
                continue
            data = json.loads(line)
            acc = data['acc']
            stat = AccStats(
                lens=acc['acc_len'],
                probs=acc.get('acc_prob'),
                entropies=acc.get('acc_entropy'),
            )
            acceptance_stats.append(stat)
            if verbose:
                print("Input:", tokenizer.decode(data['prompt_token_ids']))
                print("Output:", tokenizer.decode(data['generated_token_ids']))
                print("=============================================")
    # default=0 keeps an empty file from crashing max() with ValueError.
    max_length = max((stats.length for stats in acceptance_stats), default=0)
    print(f"Load {len(acceptance_stats)} with max length {max_length}")
    return acceptance_stats