record entropy and prob

This commit is contained in:
LiuXiaoxuanPKU 2025-06-29 22:34:49 -07:00
parent 2815bd6143
commit 54be44ee74
2 changed files with 15 additions and 23 deletions

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
import torch.nn.functional as F
from torch.distributions import Categorical from torch.distributions import Categorical
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
@ -100,7 +100,6 @@ class EagleProposer:
) )
sample_hidden_states = hidden_states_logits[last_token_indices] sample_hidden_states = hidden_states_logits[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
all_draft_probs = [] all_draft_probs = []
all_draft_entropy = [] all_draft_entropy = []
@ -108,19 +107,16 @@ class EagleProposer:
probs = F.softmax(logits, dim=-1, dtype=torch.float32) probs = F.softmax(logits, dim=-1, dtype=torch.float32)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
# Get the probabilities of the draft tokens. # Get the probabilities of the draft tokens.
draft_probs = probs.gather( draft_probs = probs.gather(dim=1, index=draft_token_ids.unsqueeze(1))
dim=1, dist = Categorical(logits=logits)
index=draft_token_ids.unsqueeze(1) entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
)
dist = Categorical(logits=logits)
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
all_draft_probs.append(draft_probs) all_draft_probs.append(draft_probs)
all_draft_entropy.append(entropy) all_draft_entropy.append(entropy)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1: if self.num_speculative_tokens == 1:
return draft_token_ids.view(-1, 1), all_draft_probs, all_draft_entropy return draft_token_ids.view(-1,
1), all_draft_probs, all_draft_entropy
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
@ -182,12 +178,10 @@ class EagleProposer:
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
probs = F.softmax(logits, dim=-1, dtype=torch.float32) probs = F.softmax(logits, dim=-1, dtype=torch.float32)
draft_probs = probs.gather( draft_probs = probs.gather(dim=1,
dim=1, index=draft_token_ids.unsqueeze(1))
index=draft_token_ids.unsqueeze(1) dist = Categorical(logits=logits)
) entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
dist = Categorical(logits=logits)
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
all_draft_probs.append(draft_probs) all_draft_probs.append(draft_probs)
all_draft_entropy.append(entropy) all_draft_entropy.append(entropy)

View File

@ -49,7 +49,6 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders) scatter_mm_placeholders)
import json
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
@ -282,7 +281,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
self.acceptance_stats = {} self.acceptance_stats = {}
def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
@ -1007,7 +1006,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]: ) -> Union[ModelRunnerOutput, torch.Tensor]:
# Update KVConnector with the KVConnector metadata forward(). # Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().bind_connector_metadata( get_kv_transfer_group().bind_connector_metadata(
@ -1202,7 +1201,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Force 1 generated token per request. # Force 1 generated token per request.
for i, token_ids in enumerate(valid_sampled_token_ids): for i, token_ids in enumerate(valid_sampled_token_ids):
valid_sampled_token_ids[i] = token_ids[:1] valid_sampled_token_ids[i] = token_ids[:1]
# Mask out the sampled tokens that should not be sampled. # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices: for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear() valid_sampled_token_ids[i].clear()
@ -1299,11 +1298,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
} }
req_index = self.input_batch.req_id_to_index[req_id] req_index = self.input_batch.req_id_to_index[req_id]
step_probs, step_entropy = [], [] step_probs, step_entropy = [], []
for prob, entropy in zip( for prob, entropy in zip(draft_probs, draft_entropy):
draft_probs, draft_entropy):
step_probs.append(prob[req_index].item()) step_probs.append(prob[req_index].item())
step_entropy.append(entropy[req_index].item()) step_entropy.append(entropy[req_index].item())
self.acceptance_stats[req_id]['acc_prob'].append(step_probs) self.acceptance_stats[req_id]['acc_prob'].append(step_probs)
self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy) self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy)