mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 01:47:04 +08:00
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
import json
|
|
from dataclasses import dataclass
|
|
|
|
MODEL_TO_NAMES = {
|
|
"r1-distill-llama-8B" : "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
|
|
"llama3-8B" : "meta-llama/Meta-Llama-3-8B-Instruct",
|
|
"llama3.1-8B" : "meta-llama/Llama-3.1-8B-Instruct",
|
|
"llama3.1-70B" : "meta-llama/Llama-3.1-70B-Instruct",
|
|
}
|
|
|
|
@dataclass
|
|
class AccStats:
|
|
lens: list[int]
|
|
probs: list[float] = None
|
|
entropies: list[float] = None
|
|
|
|
def __post_init__(self):
|
|
if self.probs is not None:
|
|
assert len(self.lens) == len(self.probs), "Length of lens and probs must match"
|
|
if self.entropies is not None:
|
|
assert len(self.lens) == len(self.entropies), "Length of lens and entropies must match"
|
|
|
|
# remove the prefill accepted lens
|
|
self.lens = self.lens[1:]
|
|
|
|
# remove the last proposed tokens
|
|
if self.probs:
|
|
self.probs = self.probs[:-1]
|
|
if self.entropies:
|
|
self.entropies = self.entropies[:-1]
|
|
|
|
@property
|
|
def length(self):
|
|
return len(self.lens)
|
|
|
|
# def cleanup(acc_stats: AccStats) ->
|
|
# # Remove the prefill phase
|
|
# data = data[1:]
|
|
# # Cap the maximum value to 10
|
|
# data = [min(x, 10) for x in data]
|
|
# return data
|
|
|
|
def load_data(datapath, tokenizer, verbose=False):
|
|
acceptance_stats = []
|
|
with open(datapath, "r") as f:
|
|
lines = f.readlines()
|
|
for line in lines:
|
|
data = json.loads(line)
|
|
stat = AccStats(
|
|
lens=data['acc']['acc_len'],
|
|
probs=data['acc'].get('acc_prob', None),
|
|
entropies=data['acc'].get('acc_entropy', None)
|
|
)
|
|
acceptance_stats.append(stat)
|
|
if verbose:
|
|
print("Input:", tokenizer.decode(data['prompt_token_ids']))
|
|
print("Output:", tokenizer.decode(data['generated_token_ids']))
|
|
print("=============================================")
|
|
|
|
max_length = max(stats.length for stats in acceptance_stats)
|
|
|
|
print(f"Load {len(acceptance_stats)} with max length {max_length}")
|
|
return acceptance_stats
|