[Sampler] Adapt to FlashInfer 0.2.3 sampler API (#15777)

Signed-off-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Bowen Wang 2025-05-16 15:14:03 -07:00 committed by GitHub
parent aef94c6d07
commit 7fdfa01530
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 122 additions and 88 deletions

View File

@ -255,9 +255,10 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \ . /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
# uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
# TESTING: install FlashInfer from source to test 2.7.0 final RC # TESTING: install FlashInfer from source to test 2.7.0 final RC
FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \ FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \ uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.4" ; \
fi fi
COPY examples examples COPY examples examples
COPY benchmarks benchmarks COPY benchmarks benchmarks

View File

@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100]) @pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False]) # @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode() @torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str, frac_seeded: float, n_rep: int, device: str,
@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128]) @pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False]) # @pytest.mark.parametrize("use_flashinfer", [True, False])
# Not testing FlashInfer now, since 0.2.3 API removed the ability
# to pass in uniform samples.
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode() @torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int, def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool): device: str, use_flashinfer: bool):
@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
Test the flashinfer and nonflashinfer backend generate Test the flashinfer and nonflashinfer backend generate
the same output metrics. the same output metrics.
""" """
pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
"the ability to pass in uniform samples.")
torch.set_default_device(device) torch.set_default_device(device)
torch.manual_seed(0) torch.manual_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)

View File

@ -647,6 +647,8 @@ def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER: if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled") pytest.skip("Flashinfer sampler is disabled")
pytest.skip("After FlashInfer 0.2.3, sampling will never fail")
set_random_seed(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)

View File

@ -1,14 +1,20 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest
import torch import torch
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
from torch import Generator from torch import Generator
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
is_flashinfer_available)
DEVICE = "cuda" DEVICE = "cuda"
BATCH_SIZE = 1024 BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024 VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
def test_topk_impl_equivalance(): def test_topk_impl_equivalance():
@ -35,3 +41,67 @@ def test_topk_impl_equivalance():
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p) result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
assert torch.allclose(result1, result2) assert torch.allclose(result1, result2)
def test_flashinfer_sampler():
'''
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.
NOTE: FlashInfer did not directly expose an interface for fused top-k and
top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
'''
if not FLASHINFER_ENABLED:
pytest.skip(
"FlashInfer not installed or not available on this platform.")
with torch.device(DEVICE):
generator = Generator(device=DEVICE).manual_seed(42)
# Generate random logits
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
# Generate various top-k and top-p values
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
p_values = torch.rand(
(BATCH_SIZE, ),
generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
# Sometimes disable top-k (k=vocab_size)
k_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), VOCAB_SIZE)
# Sometimes disable top-p (p=1.0)
p_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), 1.0)
python_logits = apply_top_k_top_p(
logits=logits.clone(),
k=k_values,
p=p_values,
)
python_probs = torch.softmax(python_logits, dim=-1)
# FlashInfer only exposed renorm interfaces for probs so convert first
flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
flashinfer_probs = top_k_renorm_probs(
probs=flashinfer_probs,
top_k=k_values,
)
flashinfer_probs = top_p_renorm_probs(
probs=flashinfer_probs,
top_p=p_values,
)
# Compare the results
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
"FlashInfer and Python sampling implementations do not match!"

View File

@ -123,12 +123,13 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# for rejection sampling # for rejection sampling
if self.use_flashinfer and chain_speculative_sampling is not None: if self.use_flashinfer and chain_speculative_sampling is not None:
batch_size, k, _ = draft_probs.shape batch_size, k, _ = draft_probs.shape
uniform_samples = self._create_uniform_samples(
seeded_seqs, batch_size, k, draft_probs.device) (output_token_ids, accepted_token_num,
output_token_ids, accepted_token_num, emitted_token_num \ emitted_token_num) = chain_speculative_sampling(
= chain_speculative_sampling( draft_probs,
draft_probs, draft_token_ids, uniform_samples, draft_token_ids,
target_with_bonus_probs) target_with_bonus_probs,
)
# num_emitted_tokens returned by flashinfer # num_emitted_tokens returned by flashinfer
# does not include the bonus token # does not include the bonus token

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import warnings
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
@ -24,7 +23,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable # yapf: disable
from flashinfer.sampling import ( from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
@ -33,6 +31,10 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
else: else:
flashinfer_top_k_top_p_sampling = None flashinfer_top_k_top_p_sampling = None
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_sampler() -> torch.nn.Module: def get_sampler() -> torch.nn.Module:
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
@ -545,38 +547,15 @@ def _multinomial(
def _top_k_top_p_multinomial_with_flashinfer( def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1: if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0) probs = probs.repeat_interleave(num_samples, dim=0)
top_ks = top_ks.repeat_interleave(num_samples) top_ks = top_ks.repeat_interleave(num_samples)
top_ps = top_ps.repeat_interleave(num_samples) top_ps = top_ps.repeat_interleave(num_samples)
batch_size = probs.shape[0] batch_next_token_ids = flashinfer_top_k_top_p_sampling(
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if seq_groups is None:
uniform_samples.uniform_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
uniform_samples[:, sample_idx:sample_idx +
stride].uniform_(generator=seq_group.generator)
sample_idx += stride
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
probs, probs,
uniform_samples,
top_ks, top_ks,
top_ps, top_ps,
) )
if not success.all():
warnings.warn("FlashInfer rejection sampling failed, fallback.",
stacklevel=1)
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0])
return batch_next_token_ids.view(-1, num_samples) return batch_next_token_ids.view(-1, num_samples)
@ -712,19 +691,14 @@ def _sample_with_torch(
seq_groups) seq_groups)
if flashinfer_top_k_top_p_sampling is not None: if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[ logger.warning("FlashInfer 0.2.3+ does not support "
sampling_type] = _top_k_top_p_multinomial_with_flashinfer( "per-request generators. Falling back to "
probs[long_sample_indices], "PyTorch-native implementation.")
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices], multinomial_samples[sampling_type] = _multinomial(
max_n_in_batch, probs[long_sample_indices],
seq_groups_arg, max_n_in_batch,
) seq_groups=seq_groups_arg)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_n_in_batch,
seq_groups=seq_groups_arg)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.

View File

@ -31,21 +31,10 @@ class TopKTopPSampler(nn.Module):
if current_platform.is_cuda(): if current_platform.is_cuda():
if is_flashinfer_available: if is_flashinfer_available:
flashinfer_version = flashinfer.__version__ flashinfer_version = flashinfer.__version__
if flashinfer_version >= "0.2.3": if flashinfer_version < "0.2.3":
# FIXME(DefTruth): Currently, we have errors when using logger.warning(
# FlashInfer>=v0.2.3 for top-p & top-k sampling. As a "FlashInfer version >= 0.2.3 required. "
# workaround, we disable FlashInfer for top-p & top-k "Falling back to default sampling implementation.")
# sampling by default while FlashInfer>=v0.2.3.
# The sampling API removes the success return value
# of all sampling API, which is not compatible with
# earlier design.
# https://github.com/flashinfer-ai/flashinfer/releases/
# tag/v0.2.3
logger.info(
"Currently, FlashInfer top-p & top-k sampling sampler "
"is disabled because FlashInfer>=v0.2.3 is not "
"backward compatible. Falling back to the PyTorch-"
"native implementation of top-p & top-k sampling.")
self.forward = self.forward_native self.forward = self.forward_native
elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
@ -106,6 +95,11 @@ class TopKTopPSampler(nn.Module):
# not needed. This is because `random_sample` does not require # not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does. # CPU-GPU synchronization while `flashinfer_sample` does.
return random_sample(probs, generators) return random_sample(probs, generators)
if generators:
logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return flashinfer_sample(probs, k, p, generators) return flashinfer_sample(probs, k, p, generators)
def forward_tpu( def forward_tpu(
@ -280,36 +274,18 @@ def flashinfer_sample(
the synchronization overhead. the synchronization overhead.
""" """
assert not (k is None and p is None) assert not (k is None and p is None)
max_top_k_round = 32
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if len(generators) != batch_size:
uniform_samples.uniform_()
if generators:
for i, generator in generators.items():
uniform_samples[:, i].uniform_(generator=generator)
if k is None: if k is None:
# Top-p only. # Top-p only.
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, uniform_samples, p, deterministic=True) probs, p, deterministic=True)
elif p is None: elif p is None:
# Top-k only. # Top-k only.
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, uniform_samples, k, deterministic=True) probs, k, deterministic=True)
else: else:
# Both top-k and top-p. # Both top-k and top-p.
next_token_ids, success = ( next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
flashinfer.sampling.top_k_top_p_sampling_from_probs( probs, k, p, deterministic=True))
probs, uniform_samples, k, p, deterministic=True))
# NOTE: CPU-GPU synchronization happens here.
if not success.all():
if k is not None:
probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
if p is not None:
probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0], deterministic=True)
return next_token_ids.view(-1) return next_token_ids.view(-1)