mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 20:57:56 +08:00
record entropy and prob
This commit is contained in:
parent
2815bd6143
commit
54be44ee74
@ -1,9 +1,9 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
@ -101,26 +101,22 @@ class EagleProposer:
|
|||||||
sample_hidden_states = hidden_states_logits[last_token_indices]
|
sample_hidden_states = hidden_states_logits[last_token_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
|
|
||||||
|
|
||||||
all_draft_probs = []
|
all_draft_probs = []
|
||||||
all_draft_entropy = []
|
all_draft_entropy = []
|
||||||
|
|
||||||
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
|
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
# Get the probabilities of the draft tokens.
|
# Get the probabilities of the draft tokens.
|
||||||
draft_probs = probs.gather(
|
draft_probs = probs.gather(dim=1, index=draft_token_ids.unsqueeze(1))
|
||||||
dim=1,
|
dist = Categorical(logits=logits)
|
||||||
index=draft_token_ids.unsqueeze(1)
|
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
|
||||||
)
|
|
||||||
dist = Categorical(logits=logits)
|
|
||||||
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
|
|
||||||
all_draft_probs.append(draft_probs)
|
all_draft_probs.append(draft_probs)
|
||||||
all_draft_entropy.append(entropy)
|
all_draft_entropy.append(entropy)
|
||||||
|
|
||||||
|
|
||||||
# Early exit if there is only one draft token to be generated.
|
# Early exit if there is only one draft token to be generated.
|
||||||
if self.num_speculative_tokens == 1:
|
if self.num_speculative_tokens == 1:
|
||||||
return draft_token_ids.view(-1, 1), all_draft_probs, all_draft_entropy
|
return draft_token_ids.view(-1,
|
||||||
|
1), all_draft_probs, all_draft_entropy
|
||||||
|
|
||||||
# Generate the remaining draft tokens.
|
# Generate the remaining draft tokens.
|
||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
@ -182,12 +178,10 @@ class EagleProposer:
|
|||||||
draft_token_ids_list.append(draft_token_ids)
|
draft_token_ids_list.append(draft_token_ids)
|
||||||
|
|
||||||
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
|
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
|
||||||
draft_probs = probs.gather(
|
draft_probs = probs.gather(dim=1,
|
||||||
dim=1,
|
index=draft_token_ids.unsqueeze(1))
|
||||||
index=draft_token_ids.unsqueeze(1)
|
dist = Categorical(logits=logits)
|
||||||
)
|
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
|
||||||
dist = Categorical(logits=logits)
|
|
||||||
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
|
|
||||||
all_draft_probs.append(draft_probs)
|
all_draft_probs.append(draft_probs)
|
||||||
all_draft_entropy.append(entropy)
|
all_draft_entropy.append(entropy)
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,6 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
|||||||
|
|
||||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||||
scatter_mm_placeholders)
|
scatter_mm_placeholders)
|
||||||
import json
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr
|
import xgrammar as xgr
|
||||||
@ -1299,8 +1298,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
}
|
}
|
||||||
req_index = self.input_batch.req_id_to_index[req_id]
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
step_probs, step_entropy = [], []
|
step_probs, step_entropy = [], []
|
||||||
for prob, entropy in zip(
|
for prob, entropy in zip(draft_probs, draft_entropy):
|
||||||
draft_probs, draft_entropy):
|
|
||||||
step_probs.append(prob[req_index].item())
|
step_probs.append(prob[req_index].item())
|
||||||
step_entropy.append(entropy[req_index].item())
|
step_entropy.append(entropy[req_index].item())
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user