From 534c45b9620d4d97cf2ea2cdee77e8461844a243 Mon Sep 17 00:00:00 2001
From: ZiTian Zhao
Date: Sun, 10 Aug 2025 11:25:42 +0800
Subject: [PATCH] Improve fast_topk function with type hints and documentation
 (#22530)

Signed-off-by: zitian.zhao
---
 vllm/model_executor/models/utils.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index c69df6e616618..6c27fedc61b17 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -736,7 +736,23 @@ def cast_overflow_tensors(
     return tensors
 
 
-def fast_topk(values, topk, dim):
+def fast_topk(values: torch.Tensor, topk: int,
+              dim: int) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Optimized topk implementation that uses torch.max for k=1 case.
+
+    This function provides better performance for the common case of k=1
+    by using torch.max instead of the more general torch.topk.
+
+    Args:
+        values: Input tensor to find top-k values from
+        topk: Number of top values to return (k). Must be > 0.
+        dim: Dimension along which to compute topk
+
+    Returns:
+        Tuple of (values, indices) where values are the top-k values
+        and indices are their corresponding indices in the input tensor
+    """
     if topk == 1:
         # Use max along the specified dimension to get both value and index
         return torch.max(values, dim=dim, keepdim=True)
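
For context (not part of the patch): a minimal usage sketch of the patched helper, assuming vLLM is installed so that fast_topk can be imported from vllm.model_executor.models.utils. It exercises the k=1 fast path and checks that it returns the same values and indices as torch.topk; the example tensor shape is an arbitrary illustration.

    import torch

    from vllm.model_executor.models.utils import fast_topk

    # e.g. per-token logits over a vocabulary (illustrative shape)
    logits = torch.randn(4, 32000)

    # k=1 takes the torch.max fast path; keepdim=True preserves the reduced
    # dimension, so the output shapes match torch.topk(logits, 1, dim=-1).
    top_vals, top_idx = fast_topk(logits, topk=1, dim=-1)
    ref_vals, ref_idx = torch.topk(logits, 1, dim=-1)

    assert torch.equal(top_vals, ref_vals)
    assert torch.equal(top_idx, ref_idx)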