From 7bdb42b2f22f14bf450e084b8f9938f598c08f9b Mon Sep 17 00:00:00 2001
From: Zhang Xiangze
Date: Fri, 7 Nov 2025 19:03:57 +0800
Subject: [PATCH] [CPU]Avoid repeated random sample compile (#28260)

Signed-off-by: Zhang Xiangze
---
 vllm/v1/sample/ops/topk_topp_sampler.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 7a4b224822bd..02ea658b7f20 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -127,15 +127,6 @@ class TopKTopPSampler(nn.Module):
         elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
 
-        # Note: this is a workaround for
-        # https://github.com/pytorch/pytorch/pull/151218
-        @torch.compile(dynamic=True)
-        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            q = torch.empty_like(probs)
-            q.exponential_()
-            return probs.div(q).argmax(dim=-1).view(-1)
-
         if len(generators) != logits.shape[0]:
             return compiled_random_sample(logits), logits_to_return
         else:
@@ -148,6 +139,16 @@ class TopKTopPSampler(nn.Module):
             return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
 
 
+# Note: this is a workaround for
+# https://github.com/pytorch/pytorch/pull/151218
+@torch.compile(dynamic=True)
+def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    q = torch.empty_like(probs)
+    q.exponential_()
+    return probs.div(q).argmax(dim=-1).view(-1)
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: torch.Tensor | None,
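
Editor's note: the patch above avoids repeated compilation by defining
compiled_random_sample once at module scope instead of re-creating (and
re-wrapping with torch.compile) a nested function on every forward call.
Below is a minimal, self-contained sketch of that pattern, assuming a
recent PyTorch with torch.compile available. The helper names
(sample_nested, sample_hoisted) and the toy shapes are illustrative, not
part of vLLM, and exact caching behavior can vary across PyTorch versions.

import torch


def sample_nested(logits: torch.Tensor) -> torch.Tensor:
    # Anti-pattern removed by the patch: each call creates a fresh function
    # object and a fresh torch.compile wrapper, so compilation work can be
    # repeated on every invocation (the behavior reported on CPU).
    @torch.compile(dynamic=True)
    def _sample(logits: torch.Tensor) -> torch.Tensor:
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        q = torch.empty_like(probs)
        # Exponential-race trick: argmax(p_i / E_i) with E_i ~ Exp(1)
        # draws from the categorical distribution given by probs.
        q.exponential_()
        return probs.div(q).argmax(dim=-1).view(-1)

    return _sample(logits)


# Pattern introduced by the patch: decorate once at module scope so the
# compiled artifact is built on first use and reused afterwards;
# dynamic=True keeps one graph valid across varying batch sizes.
@torch.compile(dynamic=True)
def sample_hoisted(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)


if __name__ == "__main__":
    for step in range(3):
        logits = torch.randn(4, 32)
        token_ids = sample_hoisted(logits)  # compiles only on step 0
        print(step, token_ids.shape)

The hoisted form also matches how the surrounding file already defines
module-level helpers such as apply_top_k_top_p, so the fix keeps the
compiled sampler out of the per-request hot path without changing its
numerics.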