[CPU] Avoid repeated random sample compile (#28260)

Signed-off-by: Zhang Xiangze <Xiangze.Zhang@arm.com>
Zhang Xiangze 2025-11-07 19:03:57 +08:00 committed by GitHub
parent 315068eb4a
commit 7bdb42b2f2

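Why the move helps (not part of the commit message): the compiled sampling helper used to be defined inside the CPU sampling forward path, so the @torch.compile(dynamic=True) decorator re-ran on every sampling step; hoisting it to module level sets the compiled wrapper up once. The sketch below is a minimal illustration of that pattern, not vLLM code; the function names, vocab size, and timing loop are made up for the example.

# Minimal sketch contrasting a per-call compiled closure with a module-level
# compiled function. Assumes only `torch`; numbers depend on machine and build.
import time

import torch


# Decorated once at import time; later calls reuse the same compiled wrapper.
@torch.compile(dynamic=True)
def sample_module_level(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs).exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)


def sample_with_inner_compile(logits: torch.Tensor) -> torch.Tensor:
    # The decorator re-runs on every call and rebuilds the compiled wrapper,
    # which is the pattern this commit removes from the sampling path.
    @torch.compile(dynamic=True)
    def _sample(logits: torch.Tensor) -> torch.Tensor:
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs).exponential_()
        return probs.div(q).argmax(dim=-1).view(-1)

    return _sample(logits)


if __name__ == "__main__":
    logits = torch.randn(8, 32000)
    for name, fn in (("inner_compile", sample_with_inner_compile),
                     ("module_level", sample_module_level)):
        start = time.perf_counter()
        for _ in range(5):
            fn(logits)
        print(f"{name}: {time.perf_counter() - start:.3f}s")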

@@ -127,15 +127,6 @@ class TopKTopPSampler(nn.Module):
         elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
 
-        # Note: this is a workaround for
-        # https://github.com/pytorch/pytorch/pull/151218
-        @torch.compile(dynamic=True)
-        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            q = torch.empty_like(probs)
-            q.exponential_()
-            return probs.div(q).argmax(dim=-1).view(-1)
-
         if len(generators) != logits.shape[0]:
             return compiled_random_sample(logits), logits_to_return
         else:
@@ -148,6 +139,16 @@ class TopKTopPSampler(nn.Module):
             return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
 
+# Note: this is a workaround for
+# https://github.com/pytorch/pytorch/pull/151218
+@torch.compile(dynamic=True)
+def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    q = torch.empty_like(probs)
+    q.exponential_()
+    return probs.div(q).argmax(dim=-1).view(-1)
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: torch.Tensor | None,
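
A note on the sampling trick inside compiled_random_sample (not part of the diff): dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax is the exponential-race formulation of categorical sampling. argmax(p / q) equals argmin(q / p), and the smallest q_i / p_i lands on index i with probability p_i, so no explicit multinomial call is needed. A quick empirical check, with made-up probabilities and sample count chosen for the example:

# Assumes only `torch`; verifies that argmax(p / Exponential(1)) samples
# each index i with probability p_i.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

q = torch.empty(n, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=-1)
freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # expected to be close to tensor([0.1, 0.2, 0.3, 0.4])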