record entropy and prob

LiuXiaoxuanPKU 2025-06-29 22:33:49 -07:00
parent 17bccecb1c
commit 2815bd6143
2 changed files with 54 additions and 6 deletions


@@ -3,6 +3,8 @@ import torch
 import torch.nn as nn
 import triton
 import triton.language as tl
+import torch.nn.functional as F
+from torch.distributions import Categorical
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
@@ -98,12 +100,27 @@ class EagleProposer:
         )
         sample_hidden_states = hidden_states_logits[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
+        all_draft_probs = []
+        all_draft_entropy = []
+        probs = F.softmax(logits, dim=-1, dtype=torch.float32)
         draft_token_ids = logits.argmax(dim=-1)
+        # Get the probabilities of the draft tokens.
+        draft_probs = probs.gather(
+            dim=1,
+            index=draft_token_ids.unsqueeze(1)
+        )
+        dist = Categorical(logits=logits)
+        entropy = dist.entropy().unsqueeze(-1)  # [batch_size, 1]
+        all_draft_probs.append(draft_probs)
+        all_draft_entropy.append(entropy)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
             # [batch_size, 1]
-            return draft_token_ids.view(-1, 1)
+            return draft_token_ids.view(-1, 1), all_draft_probs, all_draft_entropy

         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
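
For reference, a minimal standalone sketch of the per-step bookkeeping this hunk adds: the greedy draft token, the probability the drafter assigned to it, and the entropy of the draft distribution. The toy logits shape is an assumption; in the diff, logits come from self.model.compute_logits.

    import torch
    import torch.nn.functional as F
    from torch.distributions import Categorical

    # Stand-in for self.model.compute_logits(...): [batch_size, vocab_size]
    logits = torch.randn(4, 32000)

    probs = F.softmax(logits, dim=-1, dtype=torch.float32)
    draft_token_ids = logits.argmax(dim=-1)  # [batch_size]
    # Probability the drafter assigned to its own greedy pick.
    draft_probs = probs.gather(dim=1, index=draft_token_ids.unsqueeze(1))  # [batch_size, 1]
    # Entropy of the full draft distribution; low entropy means a peaked, confident draft.
    entropy = Categorical(logits=logits).entropy().unsqueeze(-1)  # [batch_size, 1]

Recording entropy alongside the top-token probability plausibly lets later analysis distinguish steps where the drafter was confidently wrong from steps where its distribution was simply flat.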
@@ -164,9 +181,19 @@ class EagleProposer:
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
+            probs = F.softmax(logits, dim=-1, dtype=torch.float32)
+            draft_probs = probs.gather(
+                dim=1,
+                index=draft_token_ids.unsqueeze(1)
+            )
+            dist = Categorical(logits=logits)
+            entropy = dist.entropy().unsqueeze(-1)  # [batch_size, 1]
+            all_draft_probs.append(draft_probs)
+            all_draft_entropy.append(entropy)

         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        return draft_token_ids
+        return draft_token_ids, all_draft_probs, all_draft_entropy

     @staticmethod
     def prepare_inputs(
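
Taken together, the two hunks change the return contract of propose(): the per-step statistics now come back alongside the stacked token ids. A hedged sketch of the shapes a caller can expect (batch size and step count are invented for illustration):

    import torch

    num_speculative_tokens, batch_size = 3, 4

    # Per the diff, propose() returns:
    #   draft_token_ids:   [batch_size, num_speculative_tokens]
    #   all_draft_probs:   list of num_speculative_tokens tensors, each [batch_size, 1]
    #   all_draft_entropy: list of num_speculative_tokens tensors, each [batch_size, 1]
    all_draft_probs = [torch.rand(batch_size, 1) for _ in range(num_speculative_tokens)]
    all_draft_entropy = [torch.rand(batch_size, 1) for _ in range(num_speculative_tokens)]

    # Concatenating gives one [batch_size, num_speculative_tokens] tensor per statistic.
    probs = torch.cat(all_draft_probs, dim=1)
    entropy = torch.cat(all_draft_entropy, dim=1)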


@@ -1193,8 +1193,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for i, token_ids in enumerate(valid_sampled_token_ids):
             req_id = self.input_batch.req_ids[i]
             if req_id not in self.acceptance_stats:
-                self.acceptance_stats[req_id] = []
-            self.acceptance_stats[req_id].append(len(token_ids))
+                self.acceptance_stats[req_id] = {
+                    'acc_len': [],
+                    'acc_prob': [],
+                    'acc_entropy': [],
+                }
+            self.acceptance_stats[req_id]['acc_len'].append(len(token_ids))

         # Force 1 generated token per request.
         for i, token_ids in enumerate(valid_sampled_token_ids):
             valid_sampled_token_ids[i] = token_ids[:1]
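
The per-request value thus changes from a flat list of accepted lengths to a small record. A hypothetical snapshot of self.acceptance_stats after a couple of steps (request id and all numbers invented for illustration; len(token_ids) counts the tokens kept for that request in a step):

    acceptance_stats = {
        "req-0": {
            # Tokens kept per scheduler step.
            "acc_len": [3, 1],
            # One inner list per proposal round: draft-token probabilities...
            "acc_prob": [[0.91, 0.40, 0.88], [0.12, 0.55, 0.60]],
            # ...and the matching draft-distribution entropies.
            "acc_entropy": [[0.31, 2.10, 0.52], [3.40, 1.80, 1.20]],
        },
    }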
@@ -1274,7 +1278,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.use_aux_hidden_state_outputs:
             target_hidden_states = torch.cat(target_hidden_states, dim=-1)
-        draft_token_ids = self.drafter.propose(
+        draft_token_ids, draft_probs, draft_entropy = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
             target_hidden_states=target_hidden_states,
@@ -1286,6 +1290,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         spec_token_ids = draft_token_ids.tolist()
+        for req_id in self.input_batch.req_ids:
+            if req_id not in self.acceptance_stats:
+                self.acceptance_stats[req_id] = {
+                    'acc_len': [],
+                    'acc_prob': [],
+                    'acc_entropy': [],
+                }
+            req_index = self.input_batch.req_id_to_index[req_id]
+            step_probs, step_entropy = [], []
+            for prob, entropy in zip(draft_probs, draft_entropy):
+                step_probs.append(prob[req_index].item())
+                step_entropy.append(entropy[req_index].item())
+            self.acceptance_stats[req_id]['acc_prob'].append(step_probs)
+            self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy)

         # Clear KVConnector state after all KVs are generated.
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
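
The commit only records the statistics; nothing in the diff consumes them yet. A minimal downstream sketch, assuming direct access to the acceptance_stats dict populated above (the helper name is hypothetical):

    from statistics import mean

    def summarize_acceptance(acceptance_stats):
        # Aggregate the per-request stats recorded by the hunks above.
        for req_id, stats in acceptance_stats.items():
            avg_len = mean(stats['acc_len']) if stats['acc_len'] else 0.0
            # Flatten the per-step inner lists before averaging.
            flat_probs = [p for step in stats['acc_prob'] for p in step]
            flat_ents = [e for step in stats['acc_entropy'] for e in step]
            avg_prob = mean(flat_probs) if flat_probs else 0.0
            avg_ent = mean(flat_ents) if flat_ents else 0.0
            print(f"{req_id}: tokens/step={avg_len:.2f}, "
                  f"mean draft prob={avg_prob:.3f}, mean entropy={avg_ent:.3f}")

Correlating acc_len with acc_prob and acc_entropy in this way is the natural use of the record: it shows whether drafts with high assigned probability and low entropy are in fact accepted more often.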