mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 12:16:13 +08:00
[V1][TPU] TPU-optimized top-p implementation (avoids scattering). (#15736)
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com> Co-authored-by: root <root@t1v-n-822696b7-w-0.us-central2-b.c.tpu-prod-env-large-adhoc.internal>
This commit is contained in:
parent
55acf86bf8
commit
1b84eff03a
@ -36,7 +36,9 @@ docker run --privileged --net host --shm-size=16G -it \
|
|||||||
&& echo TEST_6 \
|
&& echo TEST_6 \
|
||||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
|
||||||
&& echo TEST_7 \
|
&& echo TEST_7 \
|
||||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
|
||||||
|
&& echo TEST_8 \
|
||||||
|
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \
|
||||||
|
|
||||||
|
|
||||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||||
|
|||||||
132
tests/v1/tpu/test_topk_topp_sampler.py
Normal file
132
tests/v1/tpu/test_topk_topp_sampler.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
|
||||||
|
|
||||||
|
if not current_platform.is_tpu():
|
||||||
|
pytest.skip("This test needs a TPU.", allow_module_level=True)
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
BATCH_SIZE = 1024
|
||||||
|
VOCAB_SIZE = 128 * 1024
|
||||||
|
TOLERANCE = 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_topp_result_sums_past_p():
|
||||||
|
with torch.device(xm.xla_device()):
|
||||||
|
xm.set_rng_state(seed=33)
|
||||||
|
|
||||||
|
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
||||||
|
probs = logits.softmax(dim=-1)
|
||||||
|
|
||||||
|
# Random top-p values between 0 and 1.
|
||||||
|
p = torch.rand((BATCH_SIZE, ))
|
||||||
|
|
||||||
|
# Set p=1 for ~50% of requests in the batch (top-p disabled).
|
||||||
|
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)
|
||||||
|
|
||||||
|
no_op_k = torch.tensor([VOCAB_SIZE])
|
||||||
|
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||||
|
k=no_op_k,
|
||||||
|
p=p)
|
||||||
|
|
||||||
|
# Verify that the masked logit's probability sums to at least p.
|
||||||
|
probs.masked_fill_(logits_masked.isinf(), 0)
|
||||||
|
masked_prob_sum = probs.sum(dim=-1)
|
||||||
|
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
# Perform assertion on CPU.
|
||||||
|
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
|
||||||
|
|
||||||
|
|
||||||
|
def test_topp_basic():
|
||||||
|
with torch.device(xm.xla_device()):
|
||||||
|
logits = torch.tensor([[math.log(0.2),
|
||||||
|
math.log(0.3),
|
||||||
|
math.log(0.5)],
|
||||||
|
[math.log(0.5),
|
||||||
|
math.log(0.1),
|
||||||
|
math.log(0.4)]])
|
||||||
|
|
||||||
|
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||||
|
k=torch.tensor([3, 3]),
|
||||||
|
p=torch.tensor([0.79, 0.79]))
|
||||||
|
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
# Expect the smallest elements to be dropped.
|
||||||
|
expected_result = logits.clone().cpu()
|
||||||
|
expected_result[0, 0] = float("-inf")
|
||||||
|
expected_result[1, 1] = float("-inf")
|
||||||
|
assert torch.allclose(expected_result, result.cpu())
|
||||||
|
|
||||||
|
|
||||||
|
def test_topp_select_all():
|
||||||
|
with torch.device(xm.xla_device()):
|
||||||
|
logits = torch.tensor([[math.log(0.2),
|
||||||
|
math.log(0.3),
|
||||||
|
math.log(0.5)],
|
||||||
|
[math.log(0.5),
|
||||||
|
math.log(0.1),
|
||||||
|
math.log(0.4)]])
|
||||||
|
|
||||||
|
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||||
|
k=torch.tensor([3, 3]),
|
||||||
|
p=torch.tensor([1.0, 1.0]))
|
||||||
|
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
assert torch.allclose(logits.cpu(), result.cpu())
|
||||||
|
|
||||||
|
|
||||||
|
def test_topp_with_ties():
|
||||||
|
with torch.device(xm.xla_device()):
|
||||||
|
# Input has multiple math.log(0.3).
|
||||||
|
logits = torch.tensor(
|
||||||
|
[[math.log(0.3),
|
||||||
|
math.log(0.3),
|
||||||
|
math.log(0.3),
|
||||||
|
math.log(0.1)]])
|
||||||
|
|
||||||
|
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||||
|
k=torch.tensor([4]),
|
||||||
|
p=torch.tensor([0.2]))
|
||||||
|
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
# All tie values are included in the top-p set. Tie breaking is left
|
||||||
|
# to be done during final sampling (all tie tokens have equal
|
||||||
|
# probability of being chosen).
|
||||||
|
expected_result = logits.clone().cpu()
|
||||||
|
expected_result[0, 3] = float("-inf")
|
||||||
|
assert torch.allclose(expected_result, result.cpu())
|
||||||
|
|
||||||
|
|
||||||
|
def test_both_topk_topp():
|
||||||
|
with torch.device(xm.xla_device()):
|
||||||
|
logits = torch.tensor([[math.log(0.2),
|
||||||
|
math.log(0.3),
|
||||||
|
math.log(0.5)],
|
||||||
|
[math.log(0.5),
|
||||||
|
math.log(0.1),
|
||||||
|
math.log(0.4)]])
|
||||||
|
|
||||||
|
# Set k=1 for the first batch.
|
||||||
|
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||||
|
k=torch.tensor([1, 3]),
|
||||||
|
p=torch.tensor([0.79, 0.79]))
|
||||||
|
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
# Since for the first batch k=1, expect only the largest element gets
|
||||||
|
# selected.
|
||||||
|
expected_result = logits.clone().cpu()
|
||||||
|
expected_result[0, 0] = float("-inf")
|
||||||
|
expected_result[0, 1] = float("-inf")
|
||||||
|
expected_result[1, 1] = float("-inf")
|
||||||
|
assert torch.allclose(expected_result, result.cpu())
|
||||||
@ -122,23 +122,48 @@ class TopKTopPSampler(nn.Module):
|
|||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
p: Optional[torch.Tensor],
|
p: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# If only top-k is specified, use pytorch's builtin topk op. This leads
|
logits = apply_top_k_top_p_tpu(logits, k, p)
|
||||||
# to significant speed up on TPU compared to using apply_top_k_top_p.
|
|
||||||
if k is not None and p is None:
|
|
||||||
topk_values, topk_indices = torch.topk(logits, k, dim=-1)
|
|
||||||
|
|
||||||
mask = torch.ones_like(logits, dtype=torch.bool)
|
|
||||||
mask.scatter_(-1, topk_indices, False)
|
|
||||||
logits.masked_fill_(mask, float('-inf'))
|
|
||||||
else:
|
|
||||||
# TODO Placeholder for TPU optimized topp kernel
|
|
||||||
# logits = apply_top_k_top_p(logits, k, p)
|
|
||||||
pass
|
|
||||||
|
|
||||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
return random_sample(probs, generators)
|
return random_sample(probs, generators)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_top_k_top_p_tpu(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
p: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply top-k and top-p optimized for TPU.
|
||||||
|
|
||||||
|
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||||
|
This is achieved by finding a "cut-off" element in the original logit, and
|
||||||
|
after thresholding the logit using this cut-off, the remaining elements
|
||||||
|
shall constitute the top-p set.
|
||||||
|
|
||||||
|
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
||||||
|
logit), all tie elements are included in the top-p set. In other words,
|
||||||
|
this function does not break ties. Instead, these tie tokens have equal
|
||||||
|
chance of being chosen during final sampling, so we can consider the tie
|
||||||
|
being broken then.
|
||||||
|
"""
|
||||||
|
if k is not None:
|
||||||
|
logits = apply_top_k_only(logits, k)
|
||||||
|
|
||||||
|
if p is not None:
|
||||||
|
probs = logits.softmax(dim=-1)
|
||||||
|
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||||
|
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||||
|
top_p_mask[:, -1] = False # at least one
|
||||||
|
|
||||||
|
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||||
|
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||||
|
elements_to_discard = probs < top_p_cutoff
|
||||||
|
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def apply_top_k_top_p(
|
def apply_top_k_top_p(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
@ -199,7 +224,7 @@ def apply_top_k_only(
|
|||||||
max_top_k = k.max()
|
max_top_k = k.max()
|
||||||
# topk.values tensor has shape [batch_size, max_top_k].
|
# topk.values tensor has shape [batch_size, max_top_k].
|
||||||
# Convert top k to 0-based index in range [0, max_top_k).
|
# Convert top k to 0-based index in range [0, max_top_k).
|
||||||
k_index = k.sub_(1).unsqueeze(1)
|
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
|
||||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||||
# Handle non-topk rows.
|
# Handle non-topk rows.
|
||||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user