mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 05:05:01 +08:00
Improve fast_topk function with type hints and documentation (#22530)
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
This commit is contained in:
parent
3d7363e61c
commit
534c45b962
@@ -736,7 +736,23 @@ def cast_overflow_tensors(
|
|||||||
return tensors
|
return tensors
|
||||||
|
|
||||||
|
|
||||||
def fast_topk(values, topk, dim):
|
def fast_topk(values: torch.Tensor, topk: int,
|
||||||
|
dim: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Optimized topk implementation that uses torch.max for k=1 case.
|
||||||
|
|
||||||
|
This function provides better performance for the common case of k=1
|
||||||
|
by using torch.max instead of the more general torch.topk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: Input tensor to find top-k values from
|
||||||
|
topk: Number of top values to return (k). Must be > 0.
|
||||||
|
dim: Dimension along which to compute topk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (values, indices) where values are the top-k values
|
||||||
|
and indices are their corresponding indices in the input tensor
|
||||||
|
"""
|
||||||
if topk == 1:
|
if topk == 1:
|
||||||
# Use max along the specified dimension to get both value and index
|
# Use max along the specified dimension to get both value and index
|
||||||
return torch.max(values, dim=dim, keepdim=True)
|
return torch.max(values, dim=dim, keepdim=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user