From 2815bd6143dd87332ac64e81d6402a3ef42bfb8f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Sun, 29 Jun 2025 22:33:49 -0700 Subject: [PATCH] record entropy and prob --- vllm/v1/spec_decode/eagle.py | 33 +++++++++++++++++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 27 +++++++++++++++++++++--- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1de14584d3968..4adbd2571843f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,6 +3,8 @@ import torch import torch.nn as nn import triton import triton.language as tl +import torch.nn.functional as F +from torch.distributions import Categorical from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context @@ -98,12 +100,27 @@ class EagleProposer: ) sample_hidden_states = hidden_states_logits[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + + + all_draft_probs = [] + all_draft_entropy = [] + + probs = F.softmax(logits, dim=-1, dtype=torch.float32) draft_token_ids = logits.argmax(dim=-1) + # Get the probabilities of the draft tokens. + draft_probs = probs.gather( + dim=1, + index=draft_token_ids.unsqueeze(1) + ) + dist = Categorical(logits=logits) + entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1] + all_draft_probs.append(draft_probs) + all_draft_entropy.append(entropy) + # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) + return draft_token_ids.view(-1, 1), all_draft_probs, all_draft_entropy # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -164,9 +181,19 @@ class EagleProposer: draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) + probs = F.softmax(logits, dim=-1, dtype=torch.float32) + draft_probs = probs.gather( + dim=1, + index=draft_token_ids.unsqueeze(1) + ) + dist = Categorical(logits=logits) + entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1] + all_draft_probs.append(draft_probs) + all_draft_entropy.append(entropy) + # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids + return draft_token_ids, all_draft_probs, all_draft_entropy @staticmethod def prepare_inputs( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ddb9233695220..3fe5f3a2399d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1193,8 +1193,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): for i, token_ids in enumerate(valid_sampled_token_ids): req_id = self.input_batch.req_ids[i] if req_id not in self.acceptance_stats: - self.acceptance_stats[req_id] = [] - self.acceptance_stats[req_id].append(len(token_ids)) + self.acceptance_stats[req_id] = { + 'acc_len': [], + 'acc_prob': [], + 'acc_entropy': [], + } + self.acceptance_stats[req_id]['acc_len'].append(len(token_ids)) # Force 1 generated token per request. for i, token_ids in enumerate(valid_sampled_token_ids): valid_sampled_token_ids[i] = token_ids[:1] @@ -1274,7 +1278,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat(target_hidden_states, dim=-1) - draft_token_ids = self.drafter.propose( + draft_token_ids, draft_probs, draft_entropy = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1286,6 +1290,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) spec_token_ids = draft_token_ids.tolist() + for req_id in self.input_batch.req_ids: + if req_id not in self.acceptance_stats: + self.acceptance_stats[req_id] = { + 'acc_len': [], + 'acc_prob': [], + 'acc_entropy': [], + } + req_index = self.input_batch.req_id_to_index[req_id] + step_probs, step_entropy = [], [] + for prob, entropy in zip( + draft_probs, draft_entropy): + step_probs.append(prob[req_index].item()) + step_entropy.append(entropy[req_index].item()) + + self.acceptance_stats[req_id]['acc_prob'].append(step_probs) + self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy) + # Clear KVConnector state after all KVs are generated. if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata()