mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 06:09:10 +08:00
[Core] Use flashinfer sampling kernel when available (#7137)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent
ff7ec82c4d
commit
f710fb5265
@ -192,7 +192,9 @@ steps:
|
|||||||
- vllm/model_executor/layers
|
- vllm/model_executor/layers
|
||||||
- vllm/sampling_metadata.py
|
- vllm/sampling_metadata.py
|
||||||
- tests/samplers
|
- tests/samplers
|
||||||
command: pytest -v -s samplers
|
commands:
|
||||||
|
- pytest -v -s samplers
|
||||||
|
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||||
|
|
||||||
- label: LogitsProcessor Test # 5min
|
- label: LogitsProcessor Test # 5min
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
|
|||||||
@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
|
|||||||
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||||
#################### vLLM installation IMAGE ####################
|
#################### vLLM installation IMAGE ####################
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from transformers import GenerationConfig, GenerationMixin
|
from transformers import GenerationConfig, GenerationMixin
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.model_executor.utils import set_random_seed
|
from vllm.model_executor.utils import set_random_seed
|
||||||
@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
|||||||
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
|
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
|
||||||
for prob in probs], None)
|
for prob in probs], None)
|
||||||
|
|
||||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
# top-k and top-p is only calculated when flashinfer kernel is not available
|
||||||
|
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
|
||||||
|
patch("vllm.model_executor.layers.sampler."
|
||||||
|
"flashinfer_top_k_top_p_sampling", None):
|
||||||
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||||
|
|
||||||
assert sample_probs is not None
|
assert sample_probs is not None
|
||||||
@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
|||||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
def test_flashinfer_fallback(seed: int, device: str):
|
||||||
|
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||||
|
pytest.skip("Flashinfer sampler is disabled")
|
||||||
|
|
||||||
|
set_random_seed(seed)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
batch_size = random.randint(1, 256)
|
||||||
|
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||||
|
|
||||||
|
def failing_flashinfer_sampling(*_args, **_kwargs):
|
||||||
|
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
|
n=random.randint(1, 10),
|
||||||
|
seed=random.randint(0, 10000),
|
||||||
|
)
|
||||||
|
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||||
|
sampling_params, device)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"vllm.model_executor.layers.sampler."
|
||||||
|
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
|
||||||
|
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||||
|
sampling_params, device)
|
||||||
|
|
||||||
|
assert sampler_output == fallback_sampler_output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
def test_sampler_repetition_penalty_mixed(device: str):
|
def test_sampler_repetition_penalty_mixed(device: str):
|
||||||
|
|
||||||
|
|||||||
@ -30,6 +30,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
|
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
|
||||||
VLLM_TRACE_FUNCTION: int = 0
|
VLLM_TRACE_FUNCTION: int = 0
|
||||||
VLLM_ATTENTION_BACKEND: Optional[str] = None
|
VLLM_ATTENTION_BACKEND: Optional[str] = None
|
||||||
|
VLLM_USE_FLASHINFER_SAMPLER: bool = False
|
||||||
VLLM_PP_LAYER_PARTITION: Optional[str] = None
|
VLLM_PP_LAYER_PARTITION: Optional[str] = None
|
||||||
VLLM_CPU_KVCACHE_SPACE: int = 0
|
VLLM_CPU_KVCACHE_SPACE: int = 0
|
||||||
VLLM_CPU_OMP_THREADS_BIND: str = ""
|
VLLM_CPU_OMP_THREADS_BIND: str = ""
|
||||||
@ -256,6 +257,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_ATTENTION_BACKEND":
|
"VLLM_ATTENTION_BACKEND":
|
||||||
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
|
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
|
||||||
|
|
||||||
|
# If set, vllm will use flashinfer sampler
|
||||||
|
"VLLM_USE_FLASHINFER_SAMPLER":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
|
||||||
|
|
||||||
# Pipeline stage partition strategy
|
# Pipeline stage partition strategy
|
||||||
"VLLM_PP_LAYER_PARTITION":
|
"VLLM_PP_LAYER_PARTITION":
|
||||||
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
|
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
"""A layer that samples the next tokens from the model's outputs."""
|
"""A layer that samples the next tokens from the model's outputs."""
|
||||||
import itertools
|
import itertools
|
||||||
|
import warnings
|
||||||
|
from importlib.util import find_spec
|
||||||
from math import inf
|
from math import inf
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
|
|||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
from vllm.model_executor.layers.ops.sample import sample as sample_triton
|
from vllm.model_executor.layers.ops.sample import sample as sample_triton
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
||||||
SamplingTensors,
|
SamplingTensors,
|
||||||
SequenceGroupToSample)
|
SequenceGroupToSample)
|
||||||
@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
|||||||
PromptLogprobs, SampleLogprobs, SamplerOutput,
|
PromptLogprobs, SampleLogprobs, SamplerOutput,
|
||||||
SequenceOutput)
|
SequenceOutput)
|
||||||
|
|
||||||
|
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
|
||||||
|
import flashinfer.sampling
|
||||||
|
# yapf: disable
|
||||||
|
from flashinfer.sampling import (
|
||||||
|
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
|
||||||
|
|
||||||
|
# yapf: enable
|
||||||
|
else:
|
||||||
|
flashinfer_top_k_top_p_sampling = None
|
||||||
|
|
||||||
# (num_token_ids, num_parent_ids) per sequence group.
|
# (num_token_ids, num_parent_ids) per sequence group.
|
||||||
SampleResultType = List[Tuple[List[int], List[int]]]
|
SampleResultType = List[Tuple[List[int], List[int]]]
|
||||||
|
|
||||||
@ -123,7 +136,7 @@ class Sampler(nn.Module):
|
|||||||
logits = logits.to(torch.float)
|
logits = logits.to(torch.float)
|
||||||
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
|
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
|
||||||
|
|
||||||
if do_top_p_top_k:
|
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
|
||||||
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
|
||||||
sampling_tensors.top_ks)
|
sampling_tensors.top_ks)
|
||||||
|
|
||||||
@ -476,14 +489,7 @@ def _multinomial(
|
|||||||
seq_groups: Optional[List[SequenceGroupToSample]] = None,
|
seq_groups: Optional[List[SequenceGroupToSample]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if num_samples > 1:
|
if num_samples > 1:
|
||||||
# This is equivalent to torch.repeat_interleaved (which also
|
probs = probs.repeat_interleave(num_samples, dim=0)
|
||||||
# forces a GPU<->CPU sync).
|
|
||||||
# This allows us to do sampling with replacement by creating
|
|
||||||
# num_samples copies of each row in the tensor, and then
|
|
||||||
# batch sampling the resulting tensor.
|
|
||||||
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
|
|
||||||
probs.shape[1]).contiguous().view(
|
|
||||||
-1, probs.shape[1])
|
|
||||||
q = torch.empty_like(probs)
|
q = torch.empty_like(probs)
|
||||||
if seq_groups is None:
|
if seq_groups is None:
|
||||||
q.exponential_()
|
q.exponential_()
|
||||||
@ -491,17 +497,57 @@ def _multinomial(
|
|||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
for seq_group in seq_groups:
|
for seq_group in seq_groups:
|
||||||
seq_ids = seq_group.seq_ids
|
seq_ids = seq_group.seq_ids
|
||||||
next_sample_idx = sample_idx + len(seq_ids) * num_samples
|
stride = len(seq_ids) * num_samples
|
||||||
q[sample_idx:next_sample_idx].exponential_(
|
assert seq_group.generator is not None
|
||||||
generator=seq_group.generator)
|
q[sample_idx:sample_idx +
|
||||||
sample_idx = next_sample_idx
|
stride].exponential_(generator=seq_group.generator)
|
||||||
|
sample_idx += stride
|
||||||
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||||
|
|
||||||
|
|
||||||
|
def _top_k_top_p_multinomial_with_flashinfer(
|
||||||
|
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
|
||||||
|
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
|
||||||
|
max_top_k_round = 32
|
||||||
|
if num_samples > 1:
|
||||||
|
probs = probs.repeat_interleave(num_samples, dim=0)
|
||||||
|
top_ks = top_ks.repeat_interleave(num_samples)
|
||||||
|
top_ps = top_ps.repeat_interleave(num_samples)
|
||||||
|
batch_size = probs.shape[0]
|
||||||
|
uniform_samples = torch.empty((max_top_k_round, batch_size),
|
||||||
|
device=probs.device)
|
||||||
|
if seq_groups is None:
|
||||||
|
uniform_samples.uniform_()
|
||||||
|
else:
|
||||||
|
sample_idx = 0
|
||||||
|
for seq_group in seq_groups:
|
||||||
|
seq_ids = seq_group.seq_ids
|
||||||
|
stride = len(seq_ids) * num_samples
|
||||||
|
assert seq_group.generator is not None
|
||||||
|
uniform_samples[:, sample_idx:sample_idx +
|
||||||
|
stride].uniform_(generator=seq_group.generator)
|
||||||
|
sample_idx += stride
|
||||||
|
batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
|
||||||
|
probs,
|
||||||
|
uniform_samples,
|
||||||
|
top_ks,
|
||||||
|
top_ps,
|
||||||
|
)
|
||||||
|
if not success.all():
|
||||||
|
warnings.warn("FlashInfer rejection sampling failed, fallback.",
|
||||||
|
stacklevel=1)
|
||||||
|
probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
|
||||||
|
probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
|
||||||
|
batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
|
||||||
|
probs, uniform_samples[0])
|
||||||
|
return batch_next_token_ids.view(-1, num_samples)
|
||||||
|
|
||||||
|
|
||||||
def _sample_with_torch(
|
def _sample_with_torch(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
|
sampling_tensors: SamplingTensors,
|
||||||
include_gpu_probs_tensor: bool,
|
include_gpu_probs_tensor: bool,
|
||||||
modify_greedy_probs: bool,
|
modify_greedy_probs: bool,
|
||||||
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
|
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
|
||||||
@ -564,18 +610,28 @@ def _sample_with_torch(
|
|||||||
sampling_params = seq_group.sampling_params
|
sampling_params = seq_group.sampling_params
|
||||||
max_best_of_in_batch = max(max_best_of_in_batch,
|
max_best_of_in_batch = max(max_best_of_in_batch,
|
||||||
sampling_params.best_of)
|
sampling_params.best_of)
|
||||||
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
|
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
|
||||||
"seq_groups": seq_groups,
|
seq_groups)
|
||||||
}
|
|
||||||
|
|
||||||
|
if flashinfer_top_k_top_p_sampling is not None:
|
||||||
|
multinomial_samples[
|
||||||
|
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
|
||||||
|
probs[long_sample_indices],
|
||||||
|
sampling_tensors.top_ks[long_sample_indices],
|
||||||
|
sampling_tensors.top_ps[long_sample_indices],
|
||||||
|
max_best_of_in_batch,
|
||||||
|
seq_groups_arg,
|
||||||
|
)
|
||||||
|
else:
|
||||||
multinomial_samples[sampling_type] = _multinomial(
|
multinomial_samples[sampling_type] = _multinomial(
|
||||||
probs[long_sample_indices], max_best_of_in_batch,
|
probs[long_sample_indices],
|
||||||
**seeded_args)
|
max_best_of_in_batch,
|
||||||
|
seq_groups=seq_groups_arg)
|
||||||
|
|
||||||
if sampled_token_ids_tensor is not None:
|
if sampled_token_ids_tensor is not None:
|
||||||
# Store sampled tokens in output tensor.
|
# Store sampled tokens in output tensor.
|
||||||
sampled_token_ids_tensor[
|
sampled_token_ids_tensor[long_sample_indices] = \
|
||||||
long_sample_indices] = multinomial_samples[sampling_type]
|
multinomial_samples[sampling_type].to(torch.long)
|
||||||
|
|
||||||
elif sampling_type == SamplingType.BEAM:
|
elif sampling_type == SamplingType.BEAM:
|
||||||
beam_search_logprobs = logprobs[sample_indices]
|
beam_search_logprobs = logprobs[sample_indices]
|
||||||
@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
probs: torch.Tensor, logprobs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
|
logprobs: torch.Tensor,
|
||||||
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
|
sampling_metadata: SamplingMetadata,
|
||||||
|
sampling_tensors: SamplingTensors,
|
||||||
|
include_gpu_probs_tensor: bool,
|
||||||
|
modify_greedy_probs: bool,
|
||||||
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
|
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -713,6 +772,7 @@ def _sample(
|
|||||||
probs,
|
probs,
|
||||||
logprobs,
|
logprobs,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
|
sampling_tensors,
|
||||||
include_gpu_probs_tensor=include_gpu_probs_tensor,
|
include_gpu_probs_tensor=include_gpu_probs_tensor,
|
||||||
modify_greedy_probs=modify_greedy_probs,
|
modify_greedy_probs=modify_greedy_probs,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user