[Core] Use flashinfer sampling kernel when available (#7137)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Peng Guanwen 2024-08-19 11:24:03 +08:00 committed by GitHub
parent ff7ec82c4d
commit f710fb5265
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 129 additions and 27 deletions

View File

@ -192,7 +192,9 @@ steps:
- vllm/model_executor/layers - vllm/model_executor/layers
- vllm/sampling_metadata.py - vllm/sampling_metadata.py
- tests/samplers - tests/samplers
command: pytest -v -s samplers commands:
- pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
- label: LogitsProcessor Test # 5min - label: LogitsProcessor Test # 5min
mirror_hardwares: [amd] mirror_hardwares: [amd]

View File

@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################

View File

@ -8,6 +8,7 @@ import pytest
import torch import torch
from transformers import GenerationConfig, GenerationMixin from transformers import GenerationConfig, GenerationMixin
import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]] return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None) for prob in probs], None)
with patch("vllm.model_executor.layers.sampler._sample", mock_sample): # top-k and top-p is only calculated when flashinfer kernel is not available
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
patch("vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", None):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata) sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
assert sample_probs is not None assert sample_probs is not None
@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled")
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
def failing_flashinfer_sampling(*_args, **_kwargs):
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
with patch(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
assert sampler_output == fallback_sampler_output
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str): def test_sampler_repetition_penalty_mixed(device: str):

View File

@ -30,6 +30,7 @@ if TYPE_CHECKING:
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
@ -256,6 +257,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND": "VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
# Pipeline stage partition strategy # Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": "VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),

View File

@ -1,5 +1,7 @@
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import warnings
from importlib.util import find_spec
from math import inf from math import inf
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
if HAS_TRITON: if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.layers.ops.sample import sample as sample_triton
import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata, from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors, SamplingTensors,
SequenceGroupToSample) SequenceGroupToSample)
@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SamplerOutput, PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput) SequenceOutput)
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None
# (num_token_ids, num_parent_ids) per sequence group. # (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]] SampleResultType = List[Tuple[List[int], List[int]]]
@ -123,7 +136,7 @@ class Sampler(nn.Module):
logits = logits.to(torch.float) logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k: if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks) sampling_tensors.top_ks)
@ -476,14 +489,7 @@ def _multinomial(
seq_groups: Optional[List[SequenceGroupToSample]] = None, seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if num_samples > 1: if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also probs = probs.repeat_interleave(num_samples, dim=0)
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs) q = torch.empty_like(probs)
if seq_groups is None: if seq_groups is None:
q.exponential_() q.exponential_()
@ -491,17 +497,57 @@ def _multinomial(
sample_idx = 0 sample_idx = 0
for seq_group in seq_groups: for seq_group in seq_groups:
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples stride = len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_( assert seq_group.generator is not None
generator=seq_group.generator) q[sample_idx:sample_idx +
sample_idx = next_sample_idx stride].exponential_(generator=seq_group.generator)
sample_idx += stride
return probs.div_(q).argmax(dim=1).view(-1, num_samples) return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
max_top_k_round = 32
if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0)
top_ks = top_ks.repeat_interleave(num_samples)
top_ps = top_ps.repeat_interleave(num_samples)
batch_size = probs.shape[0]
uniform_samples = torch.empty((max_top_k_round, batch_size),
device=probs.device)
if seq_groups is None:
uniform_samples.uniform_()
else:
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
uniform_samples[:, sample_idx:sample_idx +
stride].uniform_(generator=seq_group.generator)
sample_idx += stride
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
probs,
uniform_samples,
top_ks,
top_ps,
)
if not success.all():
warnings.warn("FlashInfer rejection sampling failed, fallback.",
stacklevel=1)
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
probs, uniform_samples[0])
return batch_next_token_ids.view(-1, num_samples)
def _sample_with_torch( def _sample_with_torch(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]: ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
@ -564,18 +610,28 @@ def _sample_with_torch(
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch, max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of) sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else { seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
"seq_groups": seq_groups, seq_groups)
}
multinomial_samples[sampling_type] = _multinomial( if flashinfer_top_k_top_p_sampling is not None:
probs[long_sample_indices], max_best_of_in_batch, multinomial_samples[
**seeded_args) sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_best_of_in_batch,
seq_groups=seq_groups_arg)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[long_sample_indices] = \
long_sample_indices] = multinomial_samples[sampling_type] multinomial_samples[sampling_type].to(torch.long)
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices] beam_search_logprobs = logprobs[sample_indices]
@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
def _sample( def _sample(
probs: torch.Tensor, logprobs: torch.Tensor, probs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, logprobs: torch.Tensor,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]: ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
""" """
Args: Args:
@ -713,6 +772,7 @@ def _sample(
probs, probs,
logprobs, logprobs,
sampling_metadata, sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor, include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs, modify_greedy_probs=modify_greedy_probs,
) )