mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 06:27:04 +08:00
record entropy and prob
This commit is contained in:
parent
17bccecb1c
commit
2815bd6143
@ -3,6 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
@ -98,12 +100,27 @@ class EagleProposer:
|
||||
)
|
||||
sample_hidden_states = hidden_states_logits[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
|
||||
all_draft_probs = []
|
||||
all_draft_entropy = []
|
||||
|
||||
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
# Get the probabilities of the draft tokens.
|
||||
draft_probs = probs.gather(
|
||||
dim=1,
|
||||
index=draft_token_ids.unsqueeze(1)
|
||||
)
|
||||
dist = Categorical(logits=logits)
|
||||
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
|
||||
all_draft_probs.append(draft_probs)
|
||||
all_draft_entropy.append(entropy)
|
||||
|
||||
|
||||
# Early exit if there is only one draft token to be generated.
|
||||
if self.num_speculative_tokens == 1:
|
||||
# [batch_size, 1]
|
||||
return draft_token_ids.view(-1, 1)
|
||||
return draft_token_ids.view(-1, 1), all_draft_probs, all_draft_entropy
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
@ -164,9 +181,19 @@ class EagleProposer:
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
draft_probs = probs.gather(
|
||||
dim=1,
|
||||
index=draft_token_ids.unsqueeze(1)
|
||||
)
|
||||
dist = Categorical(logits=logits)
|
||||
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
|
||||
all_draft_probs.append(draft_probs)
|
||||
all_draft_entropy.append(entropy)
|
||||
|
||||
# [batch_size, num_speculative_tokens]
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
return draft_token_ids, all_draft_probs, all_draft_entropy
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs(
|
||||
|
||||
@ -1193,8 +1193,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
if req_id not in self.acceptance_stats:
|
||||
self.acceptance_stats[req_id] = []
|
||||
self.acceptance_stats[req_id].append(len(token_ids))
|
||||
self.acceptance_stats[req_id] = {
|
||||
'acc_len': [],
|
||||
'acc_prob': [],
|
||||
'acc_entropy': [],
|
||||
}
|
||||
self.acceptance_stats[req_id]['acc_len'].append(len(token_ids))
|
||||
# Force 1 generated token per request.
|
||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
||||
valid_sampled_token_ids[i] = token_ids[:1]
|
||||
@ -1274,7 +1278,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
target_hidden_states = torch.cat(target_hidden_states, dim=-1)
|
||||
draft_token_ids = self.drafter.propose(
|
||||
draft_token_ids, draft_probs, draft_entropy = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
@ -1286,6 +1290,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
|
||||
for req_id in self.input_batch.req_ids:
|
||||
if req_id not in self.acceptance_stats:
|
||||
self.acceptance_stats[req_id] = {
|
||||
'acc_len': [],
|
||||
'acc_prob': [],
|
||||
'acc_entropy': [],
|
||||
}
|
||||
req_index = self.input_batch.req_id_to_index[req_id]
|
||||
step_probs, step_entropy = [], []
|
||||
for prob, entropy in zip(
|
||||
draft_probs, draft_entropy):
|
||||
step_probs.append(prob[req_index].item())
|
||||
step_entropy.append(entropy[req_index].item())
|
||||
|
||||
self.acceptance_stats[req_id]['acc_prob'].append(step_probs)
|
||||
self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy)
|
||||
|
||||
# Clear KVConnector state after all KVs are generated.
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().clear_connector_metadata()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user