[refactor] remove triton based sampler (#8524)

parent cca61642e0
commit 546034b466
@@ -1,52 +0,0 @@
import random

import pytest
import torch

from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
    device = "cuda"
    for seed in range(512):
        set_random_seed(seed)
        rows = random.randint(1, 512)
        cols = random.randint(1, 64000)
        if use_3d:
            third_dim = random.randint(2, 10)
            dims = [rows, third_dim, cols]
        else:
            dims = [rows, cols]
        seeds = torch.randint(torch.iinfo(torch.long).min,
                              torch.iinfo(torch.long).max, (rows, ),
                              device=device)

        # Test that the same seed produces the same output
        out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out2)
        # del to save memory
        del out2

        out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out3)
        # del to save memory
        del out3

        # Initialize out tensor with garbage to ensure that it is overwritten
        out_with_tensor = seeded_uniform(
            *dims,
            out=torch.full(
                (*dims, ),
                -1,
                dtype=dtype,
                device=device,
            ),
            seeds=seeds,
            dtype=dtype,
        )
        torch.testing.assert_close(out, out_with_tensor)
@@ -1,209 +0,0 @@
import gc
from unittest.mock import patch

import pytest
import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.sample import (_sample_triton,
                                                   _uniform_to_exponential,
                                                   sample)
from vllm.model_executor.sampling_metadata import SamplingTensors
from vllm.model_executor.utils import set_random_seed
from vllm.triton_utils.libentry import LibEntry
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
                                      get_num_triton_sampler_splits)

SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100


@pytest.fixture(autouse=True)
def _cleanup():
    yield
    gc.collect()
    torch.cuda.empty_cache()


@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
    idx = tl.arange(0, n)
    x = tl.load(input + idx)
    y = _uniform_to_exponential(x)
    tl.store(output + idx, y)


def test_uniform_to_exponential():
    """Test that we can convert uniform to exponential without div by 0."""
    input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
                         dtype=torch.float32,
                         device="cuda")
    output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
    _uniform_to_exponential_kernel[(1, )](input, output, 2)
    assert torch.all(torch.isfinite(output))
    assert torch.all(output > 0)
    assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
                              modify_greedy_probs, seed, vocab_size,
                              save_logprobs):
    set_random_seed(seed)
    bs = 8
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = (torch.rand(
            (1, bs), device="cuda") < 0.5).expand(n_splits, bs)
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, bs),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, bs),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, bs),
                          device="cuda").mul_(random_sampling_mask)
    # The current _sample_triton does not utilize the
    # libentry decoration. The purpose of adding this patch is to test
    # the correctness of libentry.
    with patch("vllm.model_executor.layers.ops.sample._sample_triton",
               LibEntry(_sample_triton)):
        sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
            probs=probs,
            logprobs=logprobs,
            sample_indices=sample_indices,
            seeds=seeds,
            max_best_of=max_best_of,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            _save_modified_probs=True)
    assert sampled_tokens.shape == (bs, max_best_of)
    for i in range(bs):
        assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
        request_uses_random_sampling = random_sampling_mask[0, i]
        if modify_greedy_probs and not request_uses_random_sampling:
            # If we are modifying greedy probs and the request is greedy,
            # we want to make sure the probs tensor is modified in place
            torch.testing.assert_close(
                probs[i][sampled_tokens[i]],
                torch.full_like(probs[i][sampled_tokens[i]], 1.0))
            assert torch.sum(probs[i]) == 1.0
            torch.testing.assert_close(
                sampled_modified_probs[i][0],
                torch.full_like(sampled_modified_probs[i][0], 1.0))
        elif request_uses_random_sampling:
            # If the request is random, we want to make sure
            # sampled_modified_probs tensor has noise added
            # (and thus is different from probs tensor)
            assert not torch.allclose(sampled_modified_probs[i][0],
                                      probs[i][sampled_tokens[i]])
        elif not request_uses_random_sampling:
            # If the request is greedy and we are not modifying greedy probs,
            # we want to make sure sampled_modified_probs tensor is the same as
            # the probs tensor.
            torch.testing.assert_close(sampled_modified_probs[i],
                                       probs[i][sampled_tokens[i]])

    if save_logprobs:
        assert sampled_logprobs.shape == (bs, max_best_of)
        for i in range(bs):
            for best_of in range(max_best_of):
                assert torch.all(sampled_logprobs[i] == logprobs[i][
                    sampled_tokens[i, best_of]])
    else:
        assert sampled_logprobs is None


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                modify_greedy_probs, seed, vocab_size):

    set_random_seed(seed)
    prompt_sizes = [16, 32, 64, 128] * 2
    samples = 8
    bs = samples + sum(prompt_sizes)
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.tensor(prompt_sizes,
                                  dtype=torch.long,
                                  device="cuda").cumsum_(0)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = torch.rand(
            (n_splits, samples), device="cuda") < 0.5
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, samples),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, samples),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, samples),
                          device="cuda").mul_(random_sampling_mask)
    # ditto
    with patch("vllm.model_executor.layers.ops.sample._sample_triton",
               LibEntry(_sample_triton)):
        sampled_tokens, sampled_logprobs, _ = sample(
            probs=probs,
            logprobs=logprobs,
            sample_indices=sample_indices,
            seeds=seeds,
            max_best_of=max_best_of,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=True)
    assert sampled_tokens.shape == (samples, max_best_of)
    assert sampled_logprobs.shape == (samples, max_best_of)
    for i, t in enumerate(sample_indices):
        assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
        for best_of in range(max_best_of):
            assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
                             [sampled_tokens[i, best_of]])


@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
    """Ensure that we get a different child seed from base
    seed + extra entropy"""
    starting_seed = seed
    seq_seed = None
    extra_entropy = 1
    for i in range(512):
        new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
                                                           i,
                                                           seeds_to_generate=1,
                                                           is_greedy=False)[0]
        new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
            starting_seed,
            i,
            extra_entropy,
            seeds_to_generate=1,
            is_greedy=False)[0]
        assert new_seq_seed_extra_entropy != new_seq_seed
        assert seq_seed != new_seq_seed
        seq_seed = new_seq_seed
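A note on the MULTI_SPLIT_VOCAB_SIZE cases above: when the vocab is wider than Triton's column limit, sample() splits the vocab, samples each split independently, and then keeps, per row, the token from the split whose noise-modified probability is largest. Because max is associative, the winner among the split winners equals the winner over the full vocab. Below is a condensed, torch-only restatement of that combine step; it is an illustrative sketch, not the removed implementation, and the variable names are assumptions.

import torch

def combine_split_winners(tokens_per_split: torch.Tensor,
                          modified_probs_per_split: torch.Tensor) -> torch.Tensor:
    # tokens_per_split: [n_splits, n, n_best], already offset into the full vocab
    # modified_probs_per_split: [n_splits, n, n_best], probs divided by exponential noise
    _, best_split = torch.max(modified_probs_per_split, dim=0, keepdim=True)
    # For every (row, best_of) slot, take the token from the winning split.
    return tokens_per_split.gather(0, best_split).squeeze(0)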
@@ -1,157 +0,0 @@
from typing import Optional, Union

import torch
import triton
import triton.language as tl


def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
    ]
    """
    n_dims = len(size)

    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out


@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current element.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
                            three_d_idx * out_3d_stride)
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)
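To make the seeding contract above concrete: with n_cols = 1000, full_block_size is 1024, philox_block_size is 256, and n_slices is 4, so a single philox call fills the whole row. The sketch below is a rough CPU reference for the per-row (and, for 3D outputs, per-slice seed ^ index) behaviour; it will not bit-match the philox stream and the torch.Generator loop is an assumption used only to illustrate the determinism guarantees.

import torch

def seeded_uniform_reference(n_rows: int, n_3d: int, n_cols: int,
                             seeds: torch.Tensor) -> torch.Tensor:
    # Row r, slice j draws from a generator seeded with seeds[r] ^ j,
    # mirroring the `seed ^= three_d_idx` derivation in the kernel.
    out = torch.empty(n_rows, n_3d, n_cols)
    for r in range(n_rows):
        for j in range(n_3d):
            g = torch.Generator().manual_seed(int(seeds[r]) ^ j)
            out[r, j] = torch.rand(n_cols, generator=g)
    return out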
@@ -1,394 +0,0 @@
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits

_EPS: tl.constexpr = 1e-6


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Add offset to sampled tokens
            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sample tokens from probs with per-sequence seeds.

    Can sample from a subset of sequences through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
        max_best_of: Number of samples to generate per sequence.
            Sequence seed will be incremented by 1 each time.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if save_modified_probs else None
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_logprobs_size = (0, 0)
        logprobs = probs

    assert logprobs is not None
    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_modified_probs_size = (0, 0)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)

        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False) -> torch.Tensor:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [batch_size, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than the number of
    # columns in probs
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8

    # Enqueue kernel. The 1D launch grid is simple: we have one kernel
    # instance per row of the probs matrix
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # tl.rand returns values in [0, 1), so we clamp lower bound
    # to _EPS to avoid log(0) and thus division by 0 later
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be > than n_cols
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # clamp sampled token to n_cols - 1
    # this should not be necessary, but we do it
    # just in case
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the row into SRAM, using a mask since block_size
        # may be > than n_cols
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
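For context on the kernel above: dividing each probability row by i.i.d. Exp(1) noise and taking the argmax is the exponential-race form of the Gumbel-max trick, so the chosen index is an exact sample from the categorical distribution proportional to the row. A small self-contained check (not part of the removed module; names are illustrative):

import torch

def empirical_distribution(probs: torch.Tensor, n_draws: int = 200_000) -> torch.Tensor:
    # argmax(p / E) with E ~ Exp(1) follows Categorical(p); verify by counting winners.
    noise = torch.distributions.Exponential(1.0).sample((n_draws, probs.numel()))
    winners = torch.argmax(probs / noise, dim=-1)
    return torch.bincount(winners, minlength=probs.numel()).float() / n_draws

print(empirical_distribution(torch.tensor([0.1, 0.2, 0.3, 0.4])))
# prints a tensor close to [0.1000, 0.2000, 0.3000, 0.4000]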
@@ -10,12 +10,6 @@ import msgspec
import torch
import torch.nn as nn

from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.model_executor.layers.ops.sample import sample as sample_triton

import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
@@ -23,6 +17,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
@@ -740,7 +735,7 @@ def _sample_with_torch(
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

@@ -777,7 +772,7 @@ def _sample_with_torch(
    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
@@ -863,88 +858,6 @@ def _sample_with_torch(
    )


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample],
                                torch.Tensor, torch.Tensor]] = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
@@ -974,10 +887,6 @@ def _sample(
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
@@ -1,4 +1,3 @@
import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -8,15 +7,10 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                           SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad,
                        maybe_expand_dim)
                        is_pin_memory_available, make_tensor_with_pad)

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER = False


@dataclass
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[])
        sample_indices=[],
    )


class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations
    """
    """Used to cache SamplingMetadata objects between scheduler iterations"""

    def __init__(self):
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
@@ -124,12 +118,12 @@ class SamplingMetadata:
            The first tuple is [1, 2] (sampled index within original logit),
            and the second tuple is [0, 1] (sampled index within pruned logit).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.


    """

    def __init__(
@@ -165,16 +159,19 @@ class SamplingMetadata:
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        selected_token_indices = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            t: async_tensor_h2d(
                seq_ids,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for t, seq_ids in categorized_sample_indices.items()
        }

@@ -201,8 +198,8 @@ def _prepare_seq_groups(
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
                                                         List[int]], int, ]:
    """Prepare sequence groups and indices for sampling.

    Args:
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Index to sample from a sample tensor. It is used by triton sample kernel.
    # See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

@@ -264,10 +258,10 @@ def _prepare_seq_groups(
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = \
            sample_obj.prompt_logprob_indices if cache is not None else []
        sample_indices: List[int] = \
            sample_obj.sample_indices if cache is not None else []
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
@@ -333,11 +327,8 @@ def _prepare_seq_groups(
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len
            sample_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
@@ -356,7 +347,8 @@ def _prepare_seq_groups(
            generator=generator,
            is_prompt=is_prompt,
            prompt_logprob_indices=list(prompt_logprob_indices),
            sample_indices=list(sample_indices))
            sample_indices=list(sample_indices),
        )

        seq_groups.append(sample_obj)

@@ -378,9 +370,6 @@ class SamplingTensors:
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

@@ -391,15 +380,7 @@ class SamplingTensors:
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
        *,
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
        extra_entropy: extra entropy to use when generating seeds.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
@@ -409,19 +390,10 @@ class SamplingTensors:
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        sampling_seeds: List[int] = []
        sample_indices: List[int] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        if _USE_TRITON_SAMPLER:
            prompt_best_of: List[int] = []

            # We need one base seed per Triton slice.
            seeds_to_generate = (extra_seeds_to_generate +
                                 get_num_triton_sampler_splits(vocab_size))

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
@@ -452,7 +424,7 @@ class SamplingTensors:
                do_penalties = True

            is_prompt = seq_group.is_prompt
            if (is_prompt and sampling_params.prompt_logprobs is not None):
            if is_prompt and sampling_params.prompt_logprobs is not None:
                # For tokens in the prompt that we only need to get
                # their logprobs
                query_len = seq_group.query_len
@@ -477,28 +449,6 @@ class SamplingTensors:
                frequency_penalties += [f] * len(seq_ids)
                repetition_penalties += [r] * len(seq_ids)

            if _USE_TRITON_SAMPLER:
                if is_prompt:
                    prompt_best_of.append(sampling_params.best_of)
                    query_len = seq_group.query_len
                    assert query_len is not None

                seed = sampling_params.seed
                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY

                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    extra_entropy = extra_entropy or ()
                    seq_seeds = cls._get_sequence_seeds(
                        seed,
                        seq_data.get_len(),
                        *extra_entropy,
                        seq_id,
                        seeds_to_generate=seeds_to_generate,
                        is_greedy=is_greedy)
                    sampling_seeds.append(seq_seeds)
                sample_indices.extend(seq_group.sample_indices)

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
@@ -518,23 +468,37 @@ class SamplingTensors:
                output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, top_ps, top_ks, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, sampling_seeds,
            sample_indices, prompt_tokens, output_tokens, vocab_size,
            extra_seeds_to_generate, device, dtype)
            temperatures,
            top_ps,
            top_ks,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype,
        )
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

    @classmethod
    def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   top_ks: List[int], min_ps: List[float],
                   presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   sampling_seeds: List[int], sample_indices: List[int],
                   prompt_tokens: List[array], output_tokens: List[array],
                   vocab_size: int, extra_seeds_to_generate: int,
                   device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
    def from_lists(
        cls,
        temperatures: List[float],
        top_ps: List[float],
        top_ks: List[int],
        min_ps: List[float],
        presence_penalties: List[float],
        frequency_penalties: List[float],
        repetition_penalties: List[float],
        prompt_tokens: List[array],
        output_tokens: List[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()
@@ -603,34 +567,9 @@ class SamplingTensors:
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        sample_indices_t = torch.tensor(
            sample_indices,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # need to transpose and make contiguous to
        # copy the tensor correctly.
        # [batch_size, n_seeds] -> [n_seeds, batch_size]
        sampling_seeds_t = torch.tensor(
            sampling_seeds,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        ).t().contiguous()

        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        # How many seeds the sample operation itself will need.
        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
                                                 non_blocking=True)
        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
        if not extra_seeds_gpu.numel():
            extra_seeds_gpu = None
        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -644,38 +583,4 @@ class SamplingTensors:
                                                non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
            sampling_seeds=sampling_seeds_gpu,
            sample_indices=sample_indices_t.to(device=device,
                                               non_blocking=True),
            extra_seeds=extra_seeds_gpu,
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: int,
        *extra_entropy: int,
        seeds_to_generate: int,
        is_greedy: bool,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if not is_greedy:
            if seed is None:
                randint_fn = random.randint
            else:
                generator = random.Random(str((seed, ) + extra_entropy))
                randint_fn = generator.randint
            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
            # If the user/random sets seed = 0 but request should
            # have sampling, we need to change it to something
            # else. We use a constant in that case.
            # This way we don't need to create and load a bool
            # matrix in the sampling kernel, which reduces CPU
            # overhead and latency.
            seq_seeds = [
                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
                for _ in range(seeds_to_generate)
            ]
        else:
            # For the kernel, seed == 0 means greedy decoding.
            seq_seeds = [0] * seeds_to_generate
        return seq_seeds

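The seed plumbing removed above encodes the sampling mode in the seed itself: 0 always means greedy, and a randomly drawn 0 for a seeded request is swapped for _SEED_0_REPLACEMENT so the kernel never needs a separate greedy/random mask. A trimmed-down restatement of that convention follows; it is a sketch only, not the removed method (which also derived one seed per Triton vocab split).

import random

_SEED_0_REPLACEMENT = 3403598558  # constant from the removed module

def child_seed(base_seed, *extra_entropy, is_greedy: bool) -> int:
    if is_greedy:
        return 0  # the kernel treats seed == 0 as greedy decoding
    rng = (random.Random(str((base_seed, ) + extra_entropy))
           if base_seed is not None else random)
    lo, hi = -(2**63), 2**63 - 1
    # `or` swaps an (unlikely) draw of 0 for the replacement constant.
    return rng.randint(lo, hi) or _SEED_0_REPLACEMENT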
@@ -1,13 +0,0 @@
import math

# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.
    """
    return math.ceil(n_cols / MAX_TRITON_N_COLS)
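As a quick sanity check of the split count used throughout the removed code (and by MULTI_SPLIT_VOCAB_SIZE in the tests above):

import math

MAX_TRITON_N_COLS = 131072

assert math.ceil(32000 / MAX_TRITON_N_COLS) == 1   # llama-sized vocab: one launch
assert math.ceil((MAX_TRITON_N_COLS + 100) / MAX_TRITON_N_COLS) == 2  # just past the limit
assert math.ceil(262144 / MAX_TRITON_N_COLS) == 2  # 2 * 131072 still needs only two splits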
@@ -270,7 +270,7 @@ class LRUCache(Generic[T]):


class PyObjectCache:
    """Used to cache python objects to avoid object allocations
    """Used to cache python objects to avoid object allocations
    across scheduler iterations.
    """

@@ -289,7 +289,7 @@ class PyObjectCache:
        self._obj_cache.append(self._obj_builder())

    def get_object(self):
        """Returns a pre-allocated cached object. If there is not enough
        """Returns a pre-allocated cached object. If there is not enough
        objects, then the cache size will double.
        """
        if self._index >= len(self._obj_cache):
@@ -837,15 +837,6 @@ def async_tensor_h2d(
    return t.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Expand the tensor to the target_dims."""
    if tensor.ndim < target_dims:
        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
    return tensor


def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    return torch.tensor([], dtype=dtype).element_size()
@@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless(
def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.


    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""
@@ -1136,10 +1127,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
    def _pull_args_from_config(args: List[str]) -> List[str]:
        """Method to pull arguments specified in the config file
        into the command-line args variable.

        The arguments in config file will be inserted between

        The arguments in config file will be inserted between
        the argument list.


        example:
        ```yaml
            port: 12323
@@ -1150,21 +1141,21 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
            --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--config', 'config.yaml',
            "facebook/opt-12B",
            '--config', 'config.yaml',
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            '-tp', '2'
        ]
        ```

        Please note how the config args are inserted after the sub command.
        this way the order of priorities is maintained when these are args
        this way the order of priorities is maintained when these are args
        parsed by super().
        """
        assert args.count(
@@ -1190,7 +1181,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):

    @staticmethod
    def _load_config_file(file_path: str) -> List[str]:
        """Loads a yaml file and returns the key value pairs as a
        """Loads a yaml file and returns the key value pairs as a
        flattened list with argparse like pattern
        ```yaml
            port: 12323
@@ -1201,7 +1192,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
            '--port': '12323',
            '--tensor-parallel-size': '4'
        ]


        """

        extension: str = file_path.split('.')[-1]