vllm/tests/v1/tpu/test_topk_topp_sampler.py

# SPDX-License-Identifier: Apache-2.0
import math

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

import torch_xla.core.xla_model as xm

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6


def test_topp_result_sums_past_p():
    with torch.device(xm.xla_device()):
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
        probs = logits.softmax(dim=-1)

        # Random top-p values between 0 and 1.
        p = torch.rand((BATCH_SIZE, ))
        # Set p=1 for ~50% of requests in the batch (top-p disabled).
        p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)
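
        # k equal to the vocab size keeps every token, so top-k is a no-op
        # here and only the top-p path is exercised.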
        no_op_k = torch.tensor([VOCAB_SIZE])
        logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                              k=no_op_k,
                                              p=p)

        # Verify that the probability mass of the unmasked tokens sums to
        # at least p.
        probs.masked_fill_(logits_masked.isinf(), 0)
        masked_prob_sum = probs.sum(dim=-1)
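
    # mark_step flushes the lazily-built XLA graph so the computation above
    # actually executes on the TPU before the values are read back.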
    xm.mark_step()

    # Perform assertion on CPU.
    assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))


def test_topp_basic():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2),
                                math.log(0.3),
                                math.log(0.5)],
                               [math.log(0.5),
                                math.log(0.1),
                                math.log(0.4)]])
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([0.79, 0.79]))

    xm.mark_step()
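
    # Worked expectation: probabilities are accumulated in descending order
    # until the mass reaches p=0.79.
    #   Row 0: 0.5 + 0.3 = 0.8 >= 0.79, so prob 0.2 (index 0) is masked.
    #   Row 1: 0.5 + 0.4 = 0.9 >= 0.79, so prob 0.1 (index 1) is masked.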
    # Expect the smallest elements to be dropped.
    expected_result = logits.clone().cpu()
    expected_result[0, 0] = float("-inf")
    expected_result[1, 1] = float("-inf")
    assert torch.allclose(expected_result, result.cpu())


def test_topp_select_all():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2),
                                math.log(0.3),
                                math.log(0.5)],
                               [math.log(0.5),
                                math.log(0.1),
                                math.log(0.4)]])
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([1.0, 1.0]))

    xm.mark_step()
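
    # With p=1.0 the whole distribution is kept, so the logits should pass
    # through unchanged.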
    assert torch.allclose(logits.cpu(), result.cpu())


def test_topp_with_ties():
    with torch.device(xm.xla_device()):
        # Input has multiple math.log(0.3).
        logits = torch.tensor(
            [[math.log(0.3),
              math.log(0.3),
              math.log(0.3),
              math.log(0.1)]])
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([4]),
                                       p=torch.tensor([0.2]))

    xm.mark_step()
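
    # Numerically, a single 0.3-probability token already reaches p=0.2,
    # but the three tied logits are indistinguishable by value, so all of
    # them survive and only the 0.1 token is masked.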
    # All tie values are included in the top-p set. Tie breaking is left
    # to be done during final sampling (all tie tokens have equal
    # probability of being chosen).
    expected_result = logits.clone().cpu()
    expected_result[0, 3] = float("-inf")
    assert torch.allclose(expected_result, result.cpu())


def test_both_topk_topp():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2),
                                math.log(0.3),
                                math.log(0.5)],
                               [math.log(0.5),
                                math.log(0.1),
                                math.log(0.4)]])
        # Set k=1 for the first row of the batch.
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([1, 3]),
                                       p=torch.tensor([0.79, 0.79]))

    xm.mark_step()
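
    # Worked expectation:
    #   Row 0: k=1 keeps only the largest logit, 0.5 at index 2.
    #   Row 1: k=3 is a no-op, and top-p keeps 0.5 + 0.4 = 0.9 >= 0.79,
    #   masking the 0.1 token at index 1.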
    # Since k=1 for the first row, expect only its largest element to be
    # selected.
    expected_result = logits.clone().cpu()
    expected_result[0, 0] = float("-inf")
    expected_result[0, 1] = float("-inf")
    expected_result[1, 1] = float("-inf")
    assert torch.allclose(expected_result, result.cpu())