mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-25 17:54:33 +08:00
[Sampler] Adapt to FlashInfer 0.2.3 sampler API (#15777)
Signed-off-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
aef94c6d07
commit
7fdfa01530
@ -255,9 +255,10 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
# uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
|
||||
# TESTING: install FlashInfer from source to test 2.7.0 final RC
|
||||
FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.4" ; \
|
||||
fi
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
|
||||
@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
@pytest.mark.parametrize("n_rep", [100])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# @pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# Not testing FlashInfer now, since 0.2.3 API removed the ability
|
||||
# to pass in uniform samples.
|
||||
@pytest.mark.parametrize("use_flashinfer", [False])
|
||||
@torch.inference_mode()
|
||||
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
frac_seeded: float, n_rep: int, device: str,
|
||||
@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# @pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
# Not testing FlashInfer now, since 0.2.3 API removed the ability
|
||||
# to pass in uniform samples.
|
||||
@pytest.mark.parametrize("use_flashinfer", [False])
|
||||
@torch.inference_mode()
|
||||
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
|
||||
device: str, use_flashinfer: bool):
|
||||
@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
|
||||
Test the flashinfer and nonflashinfer backend generate
|
||||
the same output metrics.
|
||||
"""
|
||||
|
||||
pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
|
||||
"the ability to pass in uniform samples.")
|
||||
|
||||
torch.set_default_device(device)
|
||||
torch.manual_seed(0)
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
|
||||
@ -647,6 +647,8 @@ def test_flashinfer_fallback(seed: int, device: str):
|
||||
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
pytest.skip("Flashinfer sampler is disabled")
|
||||
|
||||
pytest.skip("After FlashInfer 0.2.3, sampling will never fail")
|
||||
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
|
||||
@ -1,14 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
|
||||
from torch import Generator
|
||||
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
|
||||
is_flashinfer_available)
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
|
||||
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
|
||||
|
||||
|
||||
def test_topk_impl_equivalance():
|
||||
|
||||
@ -35,3 +41,67 @@ def test_topk_impl_equivalance():
|
||||
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
|
||||
assert torch.allclose(result1, result2)
|
||||
|
||||
|
||||
def test_flashinfer_sampler():
|
||||
'''
|
||||
This test verifies that the FlashInfer top-k and top-p sampling
|
||||
implementation produces the same results as the Python implementation.
|
||||
|
||||
NOTE: FlashInfer did not directly expose an interface for fused top-k and
|
||||
top-p prob renorm (it did provide fused sampling but we cannot compare
|
||||
sampling results due to randomness), so we will compare the probability
|
||||
renormed consequently by top-k and then top-p of FlashInfer implementation.
|
||||
'''
|
||||
|
||||
if not FLASHINFER_ENABLED:
|
||||
pytest.skip(
|
||||
"FlashInfer not installed or not available on this platform.")
|
||||
|
||||
with torch.device(DEVICE):
|
||||
generator = Generator(device=DEVICE).manual_seed(42)
|
||||
|
||||
# Generate random logits
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
|
||||
|
||||
# Generate various top-k and top-p values
|
||||
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
|
||||
p_values = torch.rand(
|
||||
(BATCH_SIZE, ),
|
||||
generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
|
||||
|
||||
# Sometimes disable top-k (k=vocab_size)
|
||||
k_values.masked_fill_(
|
||||
torch.randint(0,
|
||||
2, (BATCH_SIZE, ),
|
||||
generator=generator,
|
||||
dtype=torch.bool), VOCAB_SIZE)
|
||||
|
||||
# Sometimes disable top-p (p=1.0)
|
||||
p_values.masked_fill_(
|
||||
torch.randint(0,
|
||||
2, (BATCH_SIZE, ),
|
||||
generator=generator,
|
||||
dtype=torch.bool), 1.0)
|
||||
|
||||
python_logits = apply_top_k_top_p(
|
||||
logits=logits.clone(),
|
||||
k=k_values,
|
||||
p=p_values,
|
||||
)
|
||||
python_probs = torch.softmax(python_logits, dim=-1)
|
||||
|
||||
# FlashInfer only exposed renorm interfaces for probs so convert first
|
||||
flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
|
||||
flashinfer_probs = top_k_renorm_probs(
|
||||
probs=flashinfer_probs,
|
||||
top_k=k_values,
|
||||
)
|
||||
flashinfer_probs = top_p_renorm_probs(
|
||||
probs=flashinfer_probs,
|
||||
top_p=p_values,
|
||||
)
|
||||
|
||||
# Compare the results
|
||||
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
|
||||
"FlashInfer and Python sampling implementations do not match!"
|
||||
|
||||
@ -123,12 +123,13 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
|
||||
# for rejection sampling
|
||||
if self.use_flashinfer and chain_speculative_sampling is not None:
|
||||
batch_size, k, _ = draft_probs.shape
|
||||
uniform_samples = self._create_uniform_samples(
|
||||
seeded_seqs, batch_size, k, draft_probs.device)
|
||||
output_token_ids, accepted_token_num, emitted_token_num \
|
||||
= chain_speculative_sampling(
|
||||
draft_probs, draft_token_ids, uniform_samples,
|
||||
target_with_bonus_probs)
|
||||
|
||||
(output_token_ids, accepted_token_num,
|
||||
emitted_token_num) = chain_speculative_sampling(
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
target_with_bonus_probs,
|
||||
)
|
||||
|
||||
# num_emitted_tokens returned by flashinfer
|
||||
# does not include the bonus token
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
@ -24,7 +23,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
|
||||
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
|
||||
|
||||
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
||||
import flashinfer.sampling
|
||||
# yapf: disable
|
||||
from flashinfer.sampling import (
|
||||
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
|
||||
@ -33,6 +31,10 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
||||
else:
|
||||
flashinfer_top_k_top_p_sampling = None
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_sampler() -> torch.nn.Module:
|
||||
if envs.VLLM_USE_V1:
|
||||
@ -545,38 +547,15 @@ def _multinomial(
|
||||
def _top_k_top_p_multinomial_with_flashinfer(
|
||||
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
|
||||
num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
|
||||
max_top_k_round = 32
|
||||
if num_samples > 1:
|
||||
probs = probs.repeat_interleave(num_samples, dim=0)
|
||||
top_ks = top_ks.repeat_interleave(num_samples)
|
||||
top_ps = top_ps.repeat_interleave(num_samples)
|
||||
batch_size = probs.shape[0]
|
||||
uniform_samples = torch.empty((max_top_k_round, batch_size),
|
||||
device=probs.device)
|
||||
if seq_groups is None:
|
||||
uniform_samples.uniform_()
|
||||
else:
|
||||
sample_idx = 0
|
||||
for seq_group in seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
stride = len(seq_ids) * num_samples
|
||||
assert seq_group.generator is not None
|
||||
uniform_samples[:, sample_idx:sample_idx +
|
||||
stride].uniform_(generator=seq_group.generator)
|
||||
sample_idx += stride
|
||||
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
|
||||
batch_next_token_ids = flashinfer_top_k_top_p_sampling(
|
||||
probs,
|
||||
uniform_samples,
|
||||
top_ks,
|
||||
top_ps,
|
||||
)
|
||||
if not success.all():
|
||||
warnings.warn("FlashInfer rejection sampling failed, fallback.",
|
||||
stacklevel=1)
|
||||
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
|
||||
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
|
||||
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
|
||||
probs, uniform_samples[0])
|
||||
return batch_next_token_ids.view(-1, num_samples)
|
||||
|
||||
|
||||
@ -712,19 +691,14 @@ def _sample_with_torch(
|
||||
seq_groups)
|
||||
|
||||
if flashinfer_top_k_top_p_sampling is not None:
|
||||
multinomial_samples[
|
||||
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
|
||||
probs[long_sample_indices],
|
||||
sampling_tensors.top_ks[long_sample_indices],
|
||||
sampling_tensors.top_ps[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups_arg,
|
||||
)
|
||||
else:
|
||||
multinomial_samples[sampling_type] = _multinomial(
|
||||
probs[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups=seq_groups_arg)
|
||||
logger.warning("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation.")
|
||||
|
||||
multinomial_samples[sampling_type] = _multinomial(
|
||||
probs[long_sample_indices],
|
||||
max_n_in_batch,
|
||||
seq_groups=seq_groups_arg)
|
||||
|
||||
if sampled_token_ids_tensor is not None:
|
||||
# Store sampled tokens in output tensor.
|
||||
|
||||
@ -31,21 +31,10 @@ class TopKTopPSampler(nn.Module):
|
||||
if current_platform.is_cuda():
|
||||
if is_flashinfer_available:
|
||||
flashinfer_version = flashinfer.__version__
|
||||
if flashinfer_version >= "0.2.3":
|
||||
# FIXME(DefTruth): Currently, we have errors when using
|
||||
# FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
|
||||
# workaround, we disable FlashInfer for top-p & top-k
|
||||
# sampling by default while FlashInfer>=v0.2.3.
|
||||
# The sampling API removes the success return value
|
||||
# of all sampling API, which is not compatible with
|
||||
# earlier design.
|
||||
# https://github.com/flashinfer-ai/flashinfer/releases/
|
||||
# tag/v0.2.3
|
||||
logger.info(
|
||||
"Currently, FlashInfer top-p & top-k sampling sampler "
|
||||
"is disabled because FlashInfer>=v0.2.3 is not "
|
||||
"backward compatible. Falling back to the PyTorch-"
|
||||
"native implementation of top-p & top-k sampling.")
|
||||
if flashinfer_version < "0.2.3":
|
||||
logger.warning(
|
||||
"FlashInfer version >= 0.2.3 required. "
|
||||
"Falling back to default sampling implementation.")
|
||||
self.forward = self.forward_native
|
||||
elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
|
||||
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
|
||||
@ -106,6 +95,11 @@ class TopKTopPSampler(nn.Module):
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
return random_sample(probs, generators)
|
||||
if generators:
|
||||
logger.warning("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation.")
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
return flashinfer_sample(probs, k, p, generators)
|
||||
|
||||
def forward_tpu(
|
||||
@ -280,36 +274,18 @@ def flashinfer_sample(
|
||||
the synchronization overhead.
|
||||
"""
|
||||
assert not (k is None and p is None)
|
||||
max_top_k_round = 32
|
||||
batch_size = probs.shape[0]
|
||||
uniform_samples = torch.empty((max_top_k_round, batch_size),
|
||||
device=probs.device)
|
||||
if len(generators) != batch_size:
|
||||
uniform_samples.uniform_()
|
||||
if generators:
|
||||
for i, generator in generators.items():
|
||||
uniform_samples[:, i].uniform_(generator=generator)
|
||||
|
||||
if k is None:
|
||||
# Top-p only.
|
||||
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
|
||||
probs, uniform_samples, p, deterministic=True)
|
||||
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
|
||||
probs, p, deterministic=True)
|
||||
elif p is None:
|
||||
# Top-k only.
|
||||
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
|
||||
probs, uniform_samples, k, deterministic=True)
|
||||
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
|
||||
probs, k, deterministic=True)
|
||||
else:
|
||||
# Both top-k and top-p.
|
||||
next_token_ids, success = (
|
||||
flashinfer.sampling.top_k_top_p_sampling_from_probs(
|
||||
probs, uniform_samples, k, p, deterministic=True))
|
||||
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
|
||||
probs, k, p, deterministic=True))
|
||||
|
||||
# NOTE: CPU-GPU synchronization happens here.
|
||||
if not success.all():
|
||||
if k is not None:
|
||||
probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
|
||||
if p is not None:
|
||||
probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
|
||||
next_token_ids = flashinfer.sampling.sampling_from_probs(
|
||||
probs, uniform_samples[0], deterministic=True)
|
||||
return next_token_ids.view(-1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user