[Core] Use flashinfer sampling kernel when available (#7137)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent ff7ec82c4d · commit f710fb5265
.buildkite/test-pipeline.yaml
@@ -192,7 +192,9 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
-  command: pytest -v -s samplers
+  commands:
+  - pytest -v -s samplers
+  - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

 - label: LogitsProcessor Test # 5min
   mirror_hardwares: [amd]
Dockerfile
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
 #################### vLLM installation IMAGE ####################
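If you want to confirm which FlashInfer wheel actually landed in the image, a quick standalone check (not part of the commit; it assumes the distribution is named flashinfer, as the wheel filename suggests):

```python
# Quick sanity check (not part of the diff): report the installed FlashInfer
# distribution version inside the built image.
from importlib.metadata import version

print("flashinfer:", version("flashinfer"))  # expected 0.1.4 after this change
```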
tests/samplers/test_sampler.py
@@ -8,6 +8,7 @@ import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin

+import vllm.envs as envs
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
@@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
         return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                  for prob in probs], None)

-    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+    # top-k and top-p are only calculated when the flashinfer kernel is not available
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
+        patch("vllm.model_executor.layers.sampler."
+              "flashinfer_top_k_top_p_sampling", None):
         sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

     assert sample_probs is not None
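For readers unfamiliar with the patching trick above: unittest.mock replaces a name where it is looked up, which is why the test targets the sampler module's flashinfer_top_k_top_p_sampling reference rather than the flashinfer package itself. A standalone sketch with a hypothetical module object (not vLLM code):

```python
# Standalone sketch of the patching technique used in the test above.
# The fake module and attribute name here are hypothetical.
import types
from unittest.mock import patch

sampler_module = types.SimpleNamespace(fast_sampling_kernel=lambda probs: "fast path")

with patch.object(sampler_module, "fast_sampling_kernel", None):
    # Inside the context, code that checks the module-level reference sees
    # None and takes the slow fallback path.
    assert sampler_module.fast_sampling_kernel is None

# Outside the context the original reference is restored.
assert sampler_module.fast_sampling_kernel is not None
```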
@@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_flashinfer_fallback(seed: int, device: str):
+    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
+        pytest.skip("Flashinfer sampler is disabled")
+
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+
+    def failing_flashinfer_sampling(*_args, **_kwargs):
+        return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                sampling_params, device)
+
+    with patch(
+            "vllm.model_executor.layers.sampler."
+            "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
+        fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                             sampling_params, device)
+
+    assert sampler_output == fallback_sampler_output
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_repetition_penalty_mixed(device: str):
vllm/envs.py
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
+    VLLM_USE_FLASHINFER_SAMPLER: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -256,6 +257,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_ATTENTION_BACKEND":
     lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

+    # If set, vllm will use flashinfer sampler
+    "VLLM_USE_FLASHINFER_SAMPLER":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
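As a standalone illustration of how that lambda interprets the flag (not part of the commit): leaving the variable unset or setting it to "0" keeps the kernel off, while "1" turns it on.

```python
# Standalone sketch: how VLLM_USE_FLASHINFER_SAMPLER is parsed by the lambda
# above. Unset (default "0") or "0" disables the kernel; "1" enables it.
import os

os.environ.pop("VLLM_USE_FLASHINFER_SAMPLER", None)
assert bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))) is False

os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"
assert bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))) is True
```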
vllm/model_executor/layers/sampler.py
@@ -1,5 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import warnings
+from importlib.util import find_spec
 from math import inf
 from typing import Dict, List, Optional, Tuple
@@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
 if HAS_TRITON:
     from vllm.model_executor.layers.ops.sample import sample as sample_triton

+import vllm.envs as envs
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                     SamplingTensors,
                                                     SequenceGroupToSample)
@@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceOutput)

+if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
+    import flashinfer.sampling
+    # yapf: disable
+    from flashinfer.sampling import (
+        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
+
+    # yapf: enable
+else:
+    flashinfer_top_k_top_p_sampling = None
+
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
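The block above is a guarded import: the kernel is bound only when the flag is on and the package is importable; otherwise the name is None so later code can simply test it. A self-contained sketch of the same pattern, with a local flag standing in for envs.VLLM_USE_FLASHINFER_SAMPLER:

```python
# Standalone sketch of the guarded-import pattern above: the optional kernel
# is imported only when the feature flag is on and the package is installed;
# otherwise the symbol is bound to None so callers can simply test for it.
from importlib.util import find_spec

USE_FLASHINFER = True  # stand-in for envs.VLLM_USE_FLASHINFER_SAMPLER

if USE_FLASHINFER and find_spec("flashinfer"):
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
else:
    flashinfer_top_k_top_p_sampling = None

print("fast sampling kernel available:",
      flashinfer_top_k_top_p_sampling is not None)
```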
@@ -123,7 +136,7 @@ class Sampler(nn.Module):
             logits = logits.to(torch.float)
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

-        if do_top_p_top_k:
+        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
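The extra condition skips the logit-masking pass when FlashInfer will enforce top-k/top-p inside its sampling kernel instead. For intuition, here is a simplified standalone version of that masking step (an illustration only, not vLLM's _apply_top_k_top_p; it assumes scalar k and p):

```python
# Simplified illustration (not vLLM's _apply_top_k_top_p): mask logits outside
# the top-k set and outside the top-p nucleus by setting them to -inf.
import torch

def mask_top_k_top_p(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # Top-k: drop everything below the k-th largest logit in each row.
    kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))

    # Top-p: keep the smallest prefix of sorted tokens whose probability mass
    # reaches p (the token that crosses the boundary is kept).
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    cutoff = (probs.cumsum(dim=-1) - probs) > p
    sorted_logits = sorted_logits.masked_fill(cutoff, float("-inf"))
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx,
                                                          sorted_logits)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
print(mask_top_k_top_p(logits, k=3, p=0.9))
```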
@@ -476,14 +489,7 @@ def _multinomial(
     seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
-        # This is equivalent to torch.repeat_interleaved (which also
-        # forces a GPU<->CPU sync).
-        # This allows us to do sampling with replacement by creating
-        # num_samples copies of each row in the tensor, and then
-        # batch sampling the resulting tensor.
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
         q.exponential_()
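The change above replaces the manual expand/contiguous/view chain with repeat_interleave. A quick standalone check (not part of the commit) that both produce the same row layout:

```python
# Standalone check: repeat_interleave(num_samples, dim=0) yields the same
# rows, in the same order (0,0,...,1,1,...), as the old expand/view chain.
import torch

probs = torch.rand(3, 5)
num_samples = 2

old = probs[:, None, :].expand(probs.shape[0], num_samples,
                               probs.shape[1]).contiguous().view(
                                   -1, probs.shape[1])
new = probs.repeat_interleave(num_samples, dim=0)

assert torch.equal(old, new)
```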
@@ -491,17 +497,57 @@ def _multinomial(
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
-            next_sample_idx = sample_idx + len(seq_ids) * num_samples
-            q[sample_idx:next_sample_idx].exponential_(
-                generator=seq_group.generator)
-            sample_idx = next_sample_idx
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            q[sample_idx:sample_idx +
+              stride].exponential_(generator=seq_group.generator)
+            sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
 
 
+def _top_k_top_p_multinomial_with_flashinfer(
+        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
+        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
+    max_top_k_round = 32
+    if num_samples > 1:
+        probs = probs.repeat_interleave(num_samples, dim=0)
+        top_ks = top_ks.repeat_interleave(num_samples)
+        top_ps = top_ps.repeat_interleave(num_samples)
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
+        probs,
+        uniform_samples,
+        top_ks,
+        top_ps,
+    )
+    if not success.all():
+        warnings.warn("FlashInfer rejection sampling failed, fallback.",
+                      stacklevel=1)
+        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
+        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0])
+    return batch_next_token_ids.view(-1, num_samples)
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
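A detail worth noting in the helper above: each sequence group fills its own column slice of the uniform-noise matrix from its own generator, which keeps seeded requests reproducible inside a mixed batch. A standalone sketch of just that bookkeeping (not vLLM code; seeds and widths are made up):

```python
# Standalone sketch of the seeded-noise bookkeeping above: each group of
# columns is filled from its own torch.Generator, so results are reproducible
# per request regardless of what else is in the batch.
import torch

rounds, batch_size = 32, 4  # rounds mirrors max_top_k_round above

def fill(seeds):
    noise = torch.empty((rounds, batch_size))
    col = 0
    for seed, width in seeds:
        gen = torch.Generator().manual_seed(seed)
        noise[:, col:col + width].uniform_(generator=gen)
        col += width
    return noise

a = fill([(123, 2), (456, 2)])
b = fill([(123, 2), (456, 2)])
assert torch.equal(a, b)  # same seeds -> identical noise -> identical samples
```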
@@ -564,18 +610,28 @@ def _sample_with_torch(
                 sampling_params = seq_group.sampling_params
                 max_best_of_in_batch = max(max_best_of_in_batch,
                                            sampling_params.best_of)
-            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
-                "seq_groups": seq_groups,
-            }
+            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
+                              seq_groups)
 
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices], max_best_of_in_batch,
-                **seeded_args)
+            if flashinfer_top_k_top_p_sampling is not None:
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_best_of_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_best_of_in_batch,
+                    seq_groups=seq_groups_arg)
 
             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
-                sampled_token_ids_tensor[
-                    long_sample_indices] = multinomial_samples[sampling_type]
+                sampled_token_ids_tensor[long_sample_indices] = \
+                    multinomial_samples[sampling_type].to(torch.long)
 
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
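The .to(torch.long) added above matters because the FlashInfer kernel returns int32 token ids while the preallocated output buffer is int64, and indexed assignment in PyTorch generally expects matching dtypes. A standalone illustration (not vLLM code):

```python
# Standalone illustration of why the explicit cast is needed: the kernel's
# ids come back as int32, but the output buffer is int64, and advanced-index
# assignment expects matching dtypes.
import torch

out = torch.zeros(8, 1, dtype=torch.long)          # preallocated id buffer
idx = torch.tensor([2, 5])                          # rows to fill
ids = torch.tensor([[7], [9]], dtype=torch.int32)   # e.g. int32 kernel output

try:
    out[idx] = ids                  # dtype mismatch: typically raises
except RuntimeError as e:
    print("without cast:", e)

out[idx] = ids.to(torch.long)       # explicit cast, as done in the diff
print(out.squeeze(-1))
```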
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
 
 
 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
@@ -713,6 +772,7 @@ def _sample(
         probs,
         logprobs,
         sampling_metadata,
+        sampling_tensors,
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
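For context, a minimal usage sketch of the new path once a matching flashinfer wheel is installed (not part of the commit; the model name is only an example, and the flag is set before importing vllm because the sampler module checks it at import time):

```python
# Hypothetical end-to-end sketch: enable the FlashInfer sampler via the new
# environment variable before importing vLLM, then sample with top-k/top-p.
import os
os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"  # must precede the vllm import

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m")  # example model; any supported model works
params = SamplingParams(temperature=0.8, top_p=0.95, top_k=50)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```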