mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
Improve fast_topk function with type hints and documentation (#22530)
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
parent
3d7363e61c
commit
534c45b962
@ -736,7 +736,23 @@ def cast_overflow_tensors(
|
||||
return tensors
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
|
||||
def fast_topk(values: torch.Tensor, topk: int,
|
||||
dim: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Optimized topk implementation that uses torch.max for the k=1 case.
|
||||
|
||||
This function provides better performance for the common case of k=1
|
||||
by using torch.max instead of the more general torch.topk.
|
||||
|
||||
Args:
|
||||
values: Input tensor to find top-k values from
|
||||
topk: Number of top values to return (k). Must be > 0.
|
||||
dim: Dimension along which to compute topk
|
||||
|
||||
Returns:
|
||||
Tuple of (values, indices) where values are the top-k values
|
||||
and indices are their positions along `dim` in the input tensor
|
||||
"""
|
||||
if topk == 1:
|
||||
# Use max along the specified dimension to get both value and index
|
||||
return torch.max(values, dim=dim, keepdim=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user