# Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# Synced 2025-12-09 23:54:56 +08:00
# SPDX-License-Identifier: Apache-2.0
import torch
|
|
from torch import Generator
|
|
|
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
|
|
|
DEVICE = "cuda"
|
|
|
|
BATCH_SIZE = 1024
|
|
VOCAB_SIZE = 128 * 1024
|
|
|
|
|
|
def test_topk_impl_equivalance():
    """Top-k-only sampling must match top-k combined with a no-op top-p.

    Runs ``apply_top_k_top_p`` twice on identical random logits — once with
    ``p=None`` (top-k-only code path) and once with ``p=1.0`` (top-p keeps
    the full distribution, so only top-k has an effect) — and asserts the
    two results are numerically identical.
    """
    with torch.device(DEVICE):
        generator = Generator(device=DEVICE).manual_seed(33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

        # Per-request top-k values drawn uniformly from [1, 9].
        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)

        # For roughly half of the requests, raise k to the full vocab size,
        # which effectively disables top-k filtering for those rows.
        disable_mask = torch.randint(0,
                                     2, (BATCH_SIZE, ),
                                     generator=generator,
                                     dtype=bool)
        k.masked_fill_(disable_mask, VOCAB_SIZE)

        # Path 1: top-k only (p disabled).
        only_topk = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

        # Path 2: top-k plus a top-p of 1.0, which is a no-op.
        with_noop_topp = apply_top_k_top_p(logits=logits.clone(),
                                           k=k,
                                           p=torch.tensor([1.0]))

        assert torch.allclose(only_topk, with_noop_topp)