mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
[CPU]Avoid repeated random sample compile (#28260)
Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
This commit is contained in:
parent
315068eb4a
commit
7bdb42b2f2
@ -127,15 +127,6 @@ class TopKTopPSampler(nn.Module):
|
||||
elif self.logprobs_mode == "processed_logprobs":
|
||||
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# Note: this is a workaround for
|
||||
# https://github.com/pytorch/pytorch/pull/151218
|
||||
@torch.compile(dynamic=True)
|
||||
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
return probs.div(q).argmax(dim=-1).view(-1)
|
||||
|
||||
if len(generators) != logits.shape[0]:
|
||||
return compiled_random_sample(logits), logits_to_return
|
||||
else:
|
||||
@ -148,6 +139,16 @@ class TopKTopPSampler(nn.Module):
|
||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||
|
||||
|
||||
# Note: this is a workaround for
|
||||
# https://github.com/pytorch/pytorch/pull/151218
|
||||
@torch.compile(dynamic=True)
|
||||
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
return probs.div(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user