diff --git a/docker/Dockerfile b/docker/Dockerfile
index 17adb7a92dc19..3ee84eb55a61c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -255,9 +255,10 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 RUN --mount=type=cache,target=/root/.cache/uv \
     . /etc/environment && \
     if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
+        # uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
         # TESTING: install FlashInfer from source to test 2.7.0 final RC
         FLASHINFER_ENABLE_AOT=1 TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' \
-        uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.2.post1" ; \
+        uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.4" ; \
     fi
 COPY examples examples
 COPY benchmarks benchmarks
diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py
index 8884f8ae70b8e..6ef61f2ff4069 100644
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_flashinfer", [True, False])
+# @pytest.mark.parametrize("use_flashinfer", [True, False])
+# Not testing FlashInfer now, since 0.2.3 API removed the ability
+# to pass in uniform samples.
+@pytest.mark.parametrize("use_flashinfer", [False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                    frac_seeded: float, n_rep: int, device: str,
@@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_flashinfer", [True, False])
+# @pytest.mark.parametrize("use_flashinfer", [True, False])
+# Not testing FlashInfer now, since 0.2.3 API removed the ability
+# to pass in uniform samples.
+@pytest.mark.parametrize("use_flashinfer", [False])
 @torch.inference_mode()
 def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                             device: str, use_flashinfer: bool):
@@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
     Test the flashinfer and nonflashinfer backend generate the same output
     metrics.
     """
+
+    pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
+                "the ability to pass in uniform samples.")
+
     torch.set_default_device(device)
     torch.manual_seed(0)
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 90340f8cff039..7b19d5750906d 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -647,6 +647,8 @@ def test_flashinfer_fallback(seed: int, device: str):
     if not envs.VLLM_USE_FLASHINFER_SAMPLER:
         pytest.skip("Flashinfer sampler is disabled")
 
+    pytest.skip("After FlashInfer 0.2.3, sampling will never fail")
+
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py
index 8a5076412cfae..a8a713d446b79 100644
--- a/tests/v1/sample/test_topk_topp_sampler.py
+++ b/tests/v1/sample/test_topk_topp_sampler.py
@@ -1,14 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
 import torch
+from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
 from torch import Generator
 
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.platforms import current_platform
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  is_flashinfer_available)
 
 DEVICE = "cuda"
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
 
+FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
+
 
 def test_topk_impl_equivalance():
@@ -35,3 +41,67 @@ def test_topk_impl_equivalance():
 
     result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
     assert torch.allclose(result1, result2)
+
+
+def test_flashinfer_sampler():
+    '''
+    This test verifies that the FlashInfer top-k and top-p sampling
+    implementation produces the same results as the Python implementation.
+
+    NOTE: FlashInfer does not directly expose an interface for fused top-k
+    and top-p prob renorm (it does provide fused sampling, but sampling
+    results cannot be compared due to randomness), so we instead compare
+    probabilities renormalized sequentially by top-k and then top-p.
+    '''
+
+    if not FLASHINFER_ENABLED:
+        pytest.skip(
+            "FlashInfer not installed or not available on this platform.")
+
+    with torch.device(DEVICE):
+        generator = Generator(device=DEVICE).manual_seed(42)
+
+        # Generate random logits
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+        # Generate various top-k and top-p values
+        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+        p_values = torch.rand(
+            (BATCH_SIZE, ),
+            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+        # Sometimes disable top-k (k=vocab_size)
+        k_values.masked_fill_(
+            torch.randint(0,
+                          2, (BATCH_SIZE, ),
+                          generator=generator,
+                          dtype=torch.bool), VOCAB_SIZE)
+
+        # Sometimes disable top-p (p=1.0)
+        p_values.masked_fill_(
+            torch.randint(0,
+                          2, (BATCH_SIZE, ),
+                          generator=generator,
+                          dtype=torch.bool), 1.0)
+
+        python_logits = apply_top_k_top_p(
+            logits=logits.clone(),
+            k=k_values,
+            p=p_values,
+        )
+        python_probs = torch.softmax(python_logits, dim=-1)
+
+        # FlashInfer only exposes renorm interfaces for probs, so convert first
+        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+        flashinfer_probs = top_k_renorm_probs(
+            probs=flashinfer_probs,
+            top_k=k_values,
+        )
+        flashinfer_probs = top_p_renorm_probs(
+            probs=flashinfer_probs,
+            top_p=p_values,
+        )
+
+        # Compare the results
+        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+            "FlashInfer and Python sampling implementations do not match!"
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 26a2760f76f62..af82b9dc93b70 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -123,12 +123,13 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         # for rejection sampling
         if self.use_flashinfer and chain_speculative_sampling is not None:
             batch_size, k, _ = draft_probs.shape
-            uniform_samples = self._create_uniform_samples(
-                seeded_seqs, batch_size, k, draft_probs.device)
-            output_token_ids, accepted_token_num, emitted_token_num \
-                = chain_speculative_sampling(
-                    draft_probs, draft_token_ids, uniform_samples,
-                    target_with_bonus_probs)
+
+            (output_token_ids, accepted_token_num,
+             emitted_token_num) = chain_speculative_sampling(
+                 draft_probs,
+                 draft_token_ids,
+                 target_with_bonus_probs,
+             )
 
             # num_emitted_tokens returned by flashinfer
             # does not include the bonus token
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 2e2c46edfeaf9..d6b910e4b75a0 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
-import warnings
 from collections.abc import Iterator
 from dataclasses import dataclass
 from importlib.util import find_spec
@@ -24,7 +23,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
-    import flashinfer.sampling
     # yapf: disable
     from flashinfer.sampling import (
         top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
@@ -33,6 +31,10 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
 else:
     flashinfer_top_k_top_p_sampling = None
 
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 def get_sampler() -> torch.nn.Module:
     if envs.VLLM_USE_V1:
@@ -545,38 +547,15 @@ def _multinomial(
 def _top_k_top_p_multinomial_with_flashinfer(
         probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
         num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
-    max_top_k_round = 32
     if num_samples > 1:
         probs = probs.repeat_interleave(num_samples, dim=0)
         top_ks = top_ks.repeat_interleave(num_samples)
         top_ps = top_ps.repeat_interleave(num_samples)
-    batch_size = probs.shape[0]
-    uniform_samples = torch.empty((max_top_k_round, batch_size),
-                                  device=probs.device)
-    if seq_groups is None:
-        uniform_samples.uniform_()
-    else:
-        sample_idx = 0
-        for seq_group in seq_groups:
-            seq_ids = seq_group.seq_ids
-            stride = len(seq_ids) * num_samples
-            assert seq_group.generator is not None
-            uniform_samples[:, sample_idx:sample_idx +
-                            stride].uniform_(generator=seq_group.generator)
-            sample_idx += stride
-    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
+    batch_next_token_ids = flashinfer_top_k_top_p_sampling(
         probs,
-        uniform_samples,
         top_ks,
         top_ps,
     )
-    if not success.all():
-        warnings.warn("FlashInfer rejection sampling failed, fallback.",
-                      stacklevel=1)
-        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
-        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
-        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
-            probs, uniform_samples[0])
     return batch_next_token_ids.view(-1, num_samples)
 
 
@@ -712,19 +691,14 @@ def _sample_with_torch(
                               seq_groups)
 
             if flashinfer_top_k_top_p_sampling is not None:
-                multinomial_samples[
-                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
-                        probs[long_sample_indices],
-                        sampling_tensors.top_ks[long_sample_indices],
-                        sampling_tensors.top_ps[long_sample_indices],
-                        max_n_in_batch,
-                        seq_groups_arg,
-                    )
-            else:
-                multinomial_samples[sampling_type] = _multinomial(
-                    probs[long_sample_indices],
-                    max_n_in_batch,
-                    seq_groups=seq_groups_arg)
+                logger.warning("FlashInfer 0.2.3+ does not support "
+                               "per-request generators. Falling back to "
+                               "PyTorch-native implementation.")
+
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[long_sample_indices],
+                max_n_in_batch,
+                seq_groups=seq_groups_arg)
 
         if sampled_token_ids_tensor is not None:
             # Store sampled tokens in output tensor.
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 745b81ded3f11..5d8b3f423b025 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -31,21 +31,10 @@ class TopKTopPSampler(nn.Module):
         if current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
-                if flashinfer_version >= "0.2.3":
-                    # FIXME(DefTruth): Currently, we have errors when using
-                    # FlashInfer>=v0.2.3 for top-p & top-k sampling. As a
-                    # workaround, we disable FlashInfer for top-p & top-k
-                    # sampling by default while FlashInfer>=v0.2.3.
-                    # The sampling API removes the success return value
-                    # of all sampling API, which is not compatible with
-                    # earlier design.
-                    # https://github.com/flashinfer-ai/flashinfer/releases/
-                    # tag/v0.2.3
-                    logger.info(
-                        "Currently, FlashInfer top-p & top-k sampling sampler "
-                        "is disabled because FlashInfer>=v0.2.3 is not "
-                        "backward compatible. Falling back to the PyTorch-"
-                        "native implementation of top-p & top-k sampling.")
+                if flashinfer_version < "0.2.3":
+                    logger.warning(
+                        "FlashInfer version >= 0.2.3 required. "
+                        "Falling back to default sampling implementation.")
                     self.forward = self.forward_native
                 elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                     # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
@@ -106,6 +95,11 @@ class TopKTopPSampler(nn.Module):
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
             return random_sample(probs, generators)
+        if generators:
+            logger.warning("FlashInfer 0.2.3+ does not support "
+                           "per-request generators. Falling back to "
+                           "PyTorch-native implementation.")
+            return self.forward_native(logits, generators, k, p)
         return flashinfer_sample(probs, k, p, generators)
 
     def forward_tpu(
@@ -280,36 +274,18 @@ def flashinfer_sample(
     the synchronization overhead.
     """
    assert not (k is None and p is None)
-    max_top_k_round = 32
-    batch_size = probs.shape[0]
-    uniform_samples = torch.empty((max_top_k_round, batch_size),
-                                  device=probs.device)
-    if len(generators) != batch_size:
-        uniform_samples.uniform_()
-    if generators:
-        for i, generator in generators.items():
-            uniform_samples[:, i].uniform_(generator=generator)
 
     if k is None:
         # Top-p only.
-        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
-            probs, uniform_samples, p, deterministic=True)
+        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
+            probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
-        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
-            probs, uniform_samples, k, deterministic=True)
+        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
+            probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids, success = (
-            flashinfer.sampling.top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, k, p, deterministic=True))
+        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
+            probs, k, p, deterministic=True))
 
-    # NOTE: CPU-GPU synchronization happens here.
-    if not success.all():
-        if k is not None:
-            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
-        if p is not None:
-            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
-        next_token_ids = flashinfer.sampling.sampling_from_probs(
-            probs, uniform_samples[0], deterministic=True)
     return next_token_ids.view(-1)