Improve fast_topk function with type hints and documentation (#22530)

Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
ZiTian Zhao 2025-08-10 11:25:42 +08:00 committed by GitHub
parent 3d7363e61c
commit 534c45b962
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -736,7 +736,23 @@ def cast_overflow_tensors(
return tensors
def fast_topk(values, topk, dim):
def fast_topk(values: torch.Tensor, topk: int,
dim: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Optimized topk implementation that uses torch.max for k=1 case.
This function provides better performance for the common case of k=1
by using torch.max instead of the more general torch.topk.
Args:
values: Input tensor to find top-k values from
topk: Number of top values to return (k). Must be > 0.
dim: Dimension along which to compute topk
Returns:
Tuple of (values, indices) where values are the top-k values
and indices are their corresponding indices in the input tensor
"""
if topk == 1:
# Use max along the specified dimension to get both value and index
return torch.max(values, dim=dim, keepdim=True)