[refactor] remove triton based sampler (#8524)

Simon Mo, 2024-09-16 20:04:48 -07:00 (committed by GitHub)
parent cca61642e0
commit 546034b466
9 changed files with 75 additions and 1095 deletions

tests/kernels/test_rand.py (deleted)

@@ -1,52 +0,0 @@
import random
import pytest
import torch
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
device = "cuda"
for seed in range(512):
set_random_seed(seed)
rows = random.randint(1, 512)
cols = random.randint(1, 64000)
if use_3d:
third_dim = random.randint(2, 10)
dims = [rows, third_dim, cols]
else:
dims = [rows, cols]
seeds = torch.randint(torch.iinfo(torch.long).min,
torch.iinfo(torch.long).max, (rows, ),
device=device)
# Test that the same seed produces the same output
out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
torch.testing.assert_close(out, out2)
# del to save memory
del out2
out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
torch.testing.assert_close(out, out3)
# del to save memory
del out3
# Initialize out tensor with garbage to ensure that it is overwritten
out_with_tensor = seeded_uniform(
*dims,
out=torch.full(
(*dims, ),
-1,
dtype=dtype,
device=device,
),
seeds=seeds,
dtype=dtype,
)
torch.testing.assert_close(out, out_with_tensor)

tests/kernels/test_sampler.py (deleted)

@@ -1,209 +0,0 @@
import gc
from unittest.mock import patch
import pytest
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.sample import (_sample_triton,
_uniform_to_exponential,
sample)
from vllm.model_executor.sampling_metadata import SamplingTensors
from vllm.model_executor.utils import set_random_seed
from vllm.triton_utils.libentry import LibEntry
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
get_num_triton_sampler_splits)
SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100
@pytest.fixture(autouse=True)
def _cleanup():
yield
gc.collect()
torch.cuda.empty_cache()
@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
idx = tl.arange(0, n)
x = tl.load(input + idx)
y = _uniform_to_exponential(x)
tl.store(output + idx, y)
def test_uniform_to_exponential():
"""Test that we can convert uniform to exponential without div by 0."""
input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
dtype=torch.float32,
device="cuda")
output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
_uniform_to_exponential_kernel[(1, )](input, output, 2)
assert torch.all(torch.isfinite(output))
assert torch.all(output > 0)
assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))
@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size,
save_logprobs):
set_random_seed(seed)
bs = 8
probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
for i in range(bs):
probs[i, i * (vocab_size // bs)] = 1.0
logprobs = torch.rand_like(probs)
sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if random_sampling == "mixed":
random_sampling_mask = (torch.rand(
(1, bs), device="cuda") < 0.5).expand(n_splits, bs)
elif random_sampling:
random_sampling_mask = torch.ones((n_splits, bs),
dtype=torch.bool,
device="cuda")
else:
random_sampling_mask = torch.zeros((n_splits, bs),
dtype=torch.bool,
device="cuda")
seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, bs),
device="cuda").mul_(random_sampling_mask)
# The current _sample_triton does not utilize the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch("vllm.model_executor.layers.ops.sample._sample_triton",
LibEntry(_sample_triton)):
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
_save_modified_probs=True)
assert sampled_tokens.shape == (bs, max_best_of)
for i in range(bs):
assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
request_uses_random_sampling = random_sampling_mask[0, i]
if modify_greedy_probs and not request_uses_random_sampling:
# If we are modifying greedy probs and the request is greedy,
# we want to make sure the probs tensor is modified in place
torch.testing.assert_close(
probs[i][sampled_tokens[i]],
torch.full_like(probs[i][sampled_tokens[i]], 1.0))
assert torch.sum(probs[i]) == 1.0
torch.testing.assert_close(
sampled_modified_probs[i][0],
torch.full_like(sampled_modified_probs[i][0], 1.0))
elif request_uses_random_sampling:
# If the request is random, we want to make sure
# sampled_modified_probs tensor has noise added
# (and thus is different from probs tensor)
assert not torch.allclose(sampled_modified_probs[i][0],
probs[i][sampled_tokens[i]])
elif not request_uses_random_sampling:
# If the request is greedy and we are not modifying greedy probs,
# we want to make sure sampled_modified_probs tensor is the same as
# the probs tensor.
torch.testing.assert_close(sampled_modified_probs[i],
probs[i][sampled_tokens[i]])
if save_logprobs:
assert sampled_logprobs.shape == (bs, max_best_of)
for i in range(bs):
for best_of in range(max_best_of):
assert torch.all(sampled_logprobs[i] == logprobs[i][
sampled_tokens[i, best_of]])
else:
assert sampled_logprobs is None
@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size):
set_random_seed(seed)
prompt_sizes = [16, 32, 64, 128] * 2
samples = 8
bs = samples + sum(prompt_sizes)
probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
for i in range(bs):
probs[i, i * (vocab_size // bs)] = 1.0
logprobs = torch.rand_like(probs)
sample_indices = torch.tensor(prompt_sizes,
dtype=torch.long,
device="cuda").cumsum_(0)
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if random_sampling == "mixed":
random_sampling_mask = torch.rand(
(n_splits, samples), device="cuda") < 0.5
elif random_sampling:
random_sampling_mask = torch.ones((n_splits, samples),
dtype=torch.bool,
device="cuda")
else:
random_sampling_mask = torch.zeros((n_splits, samples),
dtype=torch.bool,
device="cuda")
seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, samples),
device="cuda").mul_(random_sampling_mask)
# ditto: see the LibEntry comment in the previous test
with patch("vllm.model_executor.layers.ops.sample._sample_triton",
LibEntry(_sample_triton)):
sampled_tokens, sampled_logprobs, _ = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=True)
assert sampled_tokens.shape == (samples, max_best_of)
assert sampled_logprobs.shape == (samples, max_best_of)
for i, t in enumerate(sample_indices):
assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
for best_of in range(max_best_of):
assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
[sampled_tokens[i, best_of]])
@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
"""Ensure that we get a different child seed from base
seed + extra entropy"""
starting_seed = seed
seq_seed = None
extra_entropy = 1
for i in range(512):
new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
i,
seeds_to_generate=1,
is_greedy=False)[0]
new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
starting_seed,
i,
extra_entropy,
seeds_to_generate=1,
is_greedy=False)[0]
assert new_seq_seed_extra_entropy != new_seq_seed
assert seq_seed != new_seq_seed
seq_seed = new_seq_seed

vllm/model_executor/layers/ops/rand.py (deleted)

@@ -1,157 +0,0 @@
from typing import Optional, Union
import torch
import triton
import triton.language as tl
def seeded_uniform(
*size,
seeds: torch.Tensor,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
pin_memory: Optional[bool] = False,
) -> torch.Tensor:
"""Similar to torch.rand, but allows for seeds to be set per row.
seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
If it is 3d, the additional seeds needed will be derived automatically
in a deterministic fashion:
[
row 0: [columns_with_seed_0], [columns_with_seed_0 ^ 1], ...
]
"""
n_dims = len(size)
if n_dims > 3:
raise ValueError("seeded_uniform only supports up to 3D tensors")
if out is None:
out = torch.empty(*size,
dtype=dtype,
device=device,
pin_memory=pin_memory)
elif out.shape != size:
raise ValueError("shape of out and size must be the same")
if n_dims == 3:
n_rows, n_3d, n_cols = out.shape
stride_row = out.stride(0)
stride_3d = out.stride(1)
elif n_dims == 2:
n_rows, n_cols = out.shape
n_3d = 1
stride_row = out.stride(0)
stride_3d = 1
else:
n_cols = out.shape[0]
n_rows = 1
n_3d = 1
stride_row = 1
stride_3d = 1
if seeds.ndim != 1:
raise ValueError("seeds must be a 1D tensor")
if seeds.numel() != n_rows:
raise ValueError(
"seeds must have the same number of elements as out has rows")
# The philox PRNG Triton uses generates 4 random numbers at once.
# Therefore, the most efficient use of it is to divide the
# block size by 4, and then save the generated random numbers to
# each of the 4 slices of the tensor.
full_block_size = triton.next_power_of_2(n_cols)
philox_block_size = max(full_block_size // 4, 1)
n_slices = full_block_size // philox_block_size
num_warps = 4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if philox_block_size >= 8192:
num_warps = 32
elif philox_block_size >= 4096:
num_warps = 16
elif philox_block_size >= 2048:
num_warps = 8
_seeded_uniform_triton[(n_rows, n_3d)](
out,
seeds,
stride_row,
stride_3d,
seeds.stride(0),
n_rows,
n_3d,
n_cols,
n_slices=n_slices,
num_warps=num_warps,
block_size=philox_block_size,
)
return out
@triton.jit
def _seeded_uniform_triton(
out_ptr: torch.Tensor,
seed_ptr: torch.Tensor,
out_row_stride: int,
out_3d_stride: int,
seed_row_stride: int,
n_rows: int,
n_3d: int,
n_cols: int,
n_slices: tl.constexpr,
block_size: tl.constexpr,
):
"""
Generate a random float32 number in [0, 1) for each element in the output
tensor. The random numbers in a row are generated using the seed for that row.
Args:
out_ptr: The output tensor.
seed_ptr: The per-row seeds to use for random number generation.
out_row_stride: The stride between rows of the output tensor.
out_3d_stride: The stride between 3D slices of the output tensor.
seed_row_stride: The stride between rows of the seed tensor.
n_rows: The number of rows in the output tensor.
n_3d: The size of the second dimension of the output tensor,
if the output tensor is 3D.
n_cols: The number of columns in the output tensor.
n_slices: The number of philox outputs to use.
"""
tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
# Get the row index.
row_idx = tl.program_id(axis=0)
three_d_idx = tl.program_id(axis=1)
philox_offsets = tl.arange(0, block_size)
# Get the seed for the current element.
seed = tl.load(seed_ptr + row_idx * seed_row_stride)
if three_d_idx > 0:
seed ^= three_d_idx
# Generate random numbers in [0, 1).
out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)
output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
three_d_idx * out_3d_stride)
out1_offsets = philox_offsets
tl.store(output_row_start_ptr + out1_offsets,
out1,
mask=out1_offsets < n_cols)
if n_slices > 1:
out2_offsets = tl.arange(block_size, block_size * 2)
tl.store(output_row_start_ptr + out2_offsets,
out2,
mask=out2_offsets < n_cols)
if n_slices > 2:
out3_offsets = tl.arange(block_size * 2, block_size * 3)
tl.store(output_row_start_ptr + out3_offsets,
out3,
mask=out3_offsets < n_cols)
if n_slices > 3:
out4_offsets = tl.arange(block_size * 3, block_size * 4)
tl.store(output_row_start_ptr + out4_offsets,
out4,
mask=out4_offsets < n_cols)
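For reference, a minimal sketch of how the removed `seeded_uniform` helper was driven (this assumes a CUDA device and a checkout from before this commit, where the module still exists; it is not part of the diff). Rows that share a seed produce identical values, which is what made per-sequence sampling reproducible.

```python
import torch
from vllm.model_executor.layers.ops.rand import seeded_uniform  # pre-#8524 only

# One seed per row; equal seeds must yield identical rows.
seeds = torch.tensor([42, 42, 7], dtype=torch.long, device="cuda")
out = seeded_uniform(3, 16, seeds=seeds, dtype=torch.float32, device="cuda")
assert torch.equal(out[0], out[1])      # same seed -> identical row
assert not torch.equal(out[0], out[2])  # different seed -> different row
```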

vllm/model_executor/layers/ops/sample.py (deleted)

@@ -1,394 +0,0 @@
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.triton_utils.sample import get_num_triton_sampler_splits
_EPS: tl.constexpr = 1e-6
def _multi_split_sample(
probs: torch.Tensor,
seeds: torch.Tensor,
n_splits: int,
sampled_tokens_size: Tuple[int, int],
sampled_logprobs_size: Tuple[int, int],
sample_indices: torch.Tensor,
logprobs: torch.Tensor,
*,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
):
"""Sample tokens where vocab size is split into multiple parts
(too large for Triton otherwise)."""
assert seeds.ndim == 2 and seeds.shape[0] == n_splits
split_probs = probs.tensor_split(n_splits, 1)
split_logprobs = logprobs.tensor_split(n_splits, 1)
sampled_tokens_tmp = [
torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
for _ in range(n_splits)
]
sampled_logprobs_tmp = [
torch.empty(sampled_logprobs_size,
dtype=probs.dtype,
device=probs.device) for _ in range(n_splits)
]
# We are purposefully using sampled_tokens_size as we need to always
# save modified probs in this case.
sampled_modified_probs_tmp = [
torch.empty(sampled_tokens_size,
dtype=probs.dtype,
device=probs.device) for _ in range(n_splits)
]
for i in range(n_splits):
n_samples = sample_indices.shape[0]
n_cols = split_probs[i].shape[1]
n_best = sampled_tokens_tmp[i].shape[1]
uniform_noise = seeded_uniform(n_samples,
n_best,
n_cols,
seeds=seeds[i].flatten(),
device=split_probs[i].device,
dtype=split_probs[i].dtype)
# TODO(yard1): See if we can remove the contiguous() calls.
# Will need kernel support.
_sample(
split_probs[i].contiguous(),
split_logprobs[i].contiguous(),
sample_indices,
sampled_tokens_tmp[i],
sampled_logprobs_tmp[i],
sampled_modified_probs_tmp[i],
seeds[i],
uniform_noise,
modify_greedy_probs=False,
save_logprobs=save_logprobs,
save_modified_probs=True,
)
if i > 0:
# Add offset to sampled tokens
sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
sampled_tokens = torch.stack(sampled_tokens_tmp)
sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
# Reduce the results from the splits.
sampled_modified_probs, indices = torch.max(sampled_modified_probs,
dim=0,
keepdim=True)
sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
if save_logprobs:
sampled_logprobs = torch.stack(sampled_logprobs_tmp)
sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
else:
sampled_logprobs = None
sampled_modified_probs = sampled_modified_probs.squeeze(0)
if modify_greedy_probs:
# We need to modify the greedy probs for the sampled tokens.
# We can't do this in the kernel as we need to know the
# sampled tokens.
probs.fill_(0.0)
probs.scatter_(1, sampled_tokens, 1.0)
return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
def sample(
probs: torch.Tensor,
seeds: torch.Tensor,
*,
max_best_of: int = 1,
sample_indices: Optional[torch.Tensor] = None,
logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
_save_modified_probs: bool = False, # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Sample tokens from probs. with per-sequence seeds.
Can sample from a subset of sequences through sample_indices.
Args:
probs: Probabilities to sample from.
shape = [batch_size, vocab_size]
seeds: Per-sequence seed values.
shape = [math.ceil(vocab_size / MAX_TRITON_N_COLS), n]
max_best_of: Number of samples to generate per sequence.
Sequence seed will be incremented by 1 each time.
sample_indices: Indices of sequences to sample from.
If not provided, will sample from all sequences.
shape = [n]
logprobs: Log-probabilities over the vocabulary. Only used for
saving the logprobs of the sampled tokens if save_logprobs is True.
shape = [batch_size, vocab_size]
modify_greedy_probs: Whether to modify the greedy probabilities
for speculative sampling (sampled token = 1.0,
everything else = 0.0).
save_logprobs: Whether to save the log-probabilities of the
sampled tokens to a tensor.
_save_modified_probs: Whether to save the modified probabilities
(including gumbel noise) of the sampled tokens to a tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
This is exposed only for testing.
Returns:
sampled_tokens: shape = [n, max_best_of]
sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
sampled_modified_probs: shape = [n, max_best_of]
if save_modified_probs else None
"""
if sample_indices is None:
sample_indices = torch.arange(0, probs.shape[0], device=probs.device)
sampled_tokens_size = (sample_indices.size(0), max_best_of)
if save_logprobs:
if logprobs is None:
raise ValueError(
"logprobs tensor must be provided if save_logprobs is True")
sampled_logprobs_size = sampled_tokens_size
else:
# Empty tensors to invoke the kernel
sampled_logprobs_size = (0, 0)
logprobs = probs
assert logprobs is not None
if _save_modified_probs:
sampled_modified_probs_size = sampled_tokens_size
else:
# Empty tensors to invoke the kernel
sampled_modified_probs_size = (0, 0)
# If the number of columns in probs is too large for Triton to handle,
# we split the tensor and sample from each split separately, and then
# do an argmax+gather to combine the results.
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if n_splits > 1:
(sampled_tokens, sampled_logprobs,
sampled_modified_probs) = _multi_split_sample(
probs,
seeds,
n_splits,
sampled_tokens_size,
sampled_logprobs_size,
sample_indices,
logprobs=logprobs,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs)
else:
sampled_tokens = torch.empty(sampled_tokens_size,
dtype=torch.long,
device=probs.device)
sampled_logprobs = torch.empty(sampled_logprobs_size,
dtype=probs.dtype,
device=probs.device)
sampled_modified_probs = torch.empty(sampled_modified_probs_size,
dtype=probs.dtype,
device=probs.device)
n_samples = sample_indices.shape[0]
n_cols = probs.shape[1]
uniform_noise = seeded_uniform(n_samples,
max_best_of,
n_cols,
seeds=seeds.flatten(),
device=probs.device,
dtype=probs.dtype)
_sample(
probs,
logprobs,
sample_indices,
sampled_tokens,
sampled_logprobs,
sampled_modified_probs,
seeds,
uniform_noise,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
save_modified_probs=_save_modified_probs,
)
return (sampled_tokens, sampled_logprobs if save_logprobs else None,
sampled_modified_probs if _save_modified_probs else None)
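The argmax+gather combine described in the docstring above can be restated with plain tensors. A sketch with illustrative shapes and made-up values (not vLLM API): each split proposes a (modified probability, global token id) pair, and the split with the largest modified probability wins per sample.

```python
import torch

split_modified_probs = torch.tensor([[0.2, 0.9],
                                     [0.8, 0.1]])   # [n_splits, n_samples]
split_tokens = torch.tensor([[3, 7],
                             [131075, 131073]])     # ids already offset per split
# Pick the winning split per sample, then gather its token.
best, idx = torch.max(split_modified_probs, dim=0, keepdim=True)
tokens = split_tokens.gather(0, idx).squeeze(0)     # tensor([131075, 7])
```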
def _sample(probs: torch.Tensor,
logprobs: torch.Tensor,
sample_indices: torch.Tensor,
output_samples: torch.Tensor,
output_logprobs: torch.Tensor,
output_modified_probs: torch.Tensor,
seeds: torch.Tensor,
uniform_noise: torch.Tensor,
*,
modify_greedy_probs: bool = False,
save_logprobs: bool = True,
save_modified_probs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sample tokens from probs.
Args:
probs [batch_size, vocab_size]: probs to sample from.
logprobs [batch_size, vocab_size]: logprobs (used when
save_logprobs is True).
sample_indices [n]: Indices of the samples to use for each row of probs.
output_samples [n, n_best]: Output tensor to store samples in.
output_logprobs [n, n_best]: Output tensor to store logprobs in.
output_modified_probs [n, n_best]: Output tensor to store
probs of chosen tokens in (modified with noise).
seeds [n]: Seeds to use for sampling. If the seed is 0, we use
greedy sampling. Note this is ONLY used for determining
whether to use random sampling or not. The actual random
noise should be passed as uniform_noise.
uniform_noise [batch_size, n_best, vocab_size]: Uniform
noise to use for random sampling (will be converted
to exponential gumbel noise by the kernel).
modify_greedy_probs: If True, we modify the probs tensor in-place
to encode the sampling method used for each row. This is used
in speculative decoding. Only applies in greedy decoding.
save_logprobs: If True, we save the logprobs of the sampled tokens
in the output_logprobs tensor.
save_modified_probs: If True, we save the modified probs (with noise)
of the sampled tokens in the output_modified_probs tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
"""
n_samples = sample_indices.shape[0]
n_cols = probs.shape[1]
n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1
# The block size is the smallest power of two greater than or equal
# to the number of columns in probs
block_size = triton.next_power_of_2(n_cols)
num_warps = 4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if block_size >= 8192:
num_warps = 32
elif block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
# Enqueue kernel. The 2D launch grid is simple: we have one kernel
# instance per (row, best_of) pair of the probs matrix
_sample_triton[(n_samples, n_best)](
sample_indices,
output_samples,
output_logprobs,
output_modified_probs,
probs,
logprobs,
seeds,
uniform_noise,
output_samples.stride(0),
probs.stride(0),
uniform_noise.stride(0),
uniform_noise.stride(1) if n_best > 1 else 1,
n_samples,
n_cols,
n_best,
num_warps=num_warps,
block_size=block_size,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
save_modified_probs=save_modified_probs,
)
return output_samples, output_logprobs, output_modified_probs
@triton.jit
def _uniform_to_exponential(uniform_noise):
"""Convert uniform samples to exponential samples."""
# tl.rand returns values in [0, 1), so we clamp lower bound
# to _EPS to avoid log(0) and thus division by 0 later
lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
uniform_noise = tl.maximum(uniform_noise, lb)
# Use the inversion method to turn uniform samples
# into exponential samples
exponential_noise = -tl.log(uniform_noise)
return exponential_noise
@triton.jit
def _sample_triton(
sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
output_logprobs_ptr: torch.Tensor,
output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
uniform_noise_ptr: torch.Tensor, output_row_stride: int,
probs_row_stride: int, uniform_noise_row_stride: int,
uniform_noise_best_stride: int, n_samples: int, n_cols: int,
n_best: int, block_size: tl.constexpr,
modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
save_modified_probs: tl.constexpr):
# The rows are independent, so we parallelize across those
sample_idx = tl.program_id(0)
best_idx = tl.program_id(1)
# Load the row index from DRAM
row_idx = tl.load(sample_indices_ptr + sample_idx)
seed = tl.load(seeds_ptr + sample_idx)
uses_random_sampling = seed != 0
# The stride represents how much we need to increase the
# pointer to advance 1 row
row_start_ptr = probs_ptr + row_idx * probs_row_stride
# The block size is the next power of two greater than or equal to n_cols,
# so we can fit each row in a single block
col_offsets = tl.arange(0, block_size)
# Load the row into SRAM, using a mask since block_size
# may be greater than n_cols
row = tl.load(row_start_ptr + col_offsets,
mask=col_offsets < n_cols,
other=float("-inf"))
if uses_random_sampling:
uniform_noise_start_ptr = (uniform_noise_ptr +
sample_idx * uniform_noise_row_stride +
best_idx * uniform_noise_best_stride)
uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
mask=col_offsets < n_cols,
other=0.5)
exponential_noise = _uniform_to_exponential(uniform_noise)
row /= exponential_noise
sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
# clamp sampled token to n_cols - 1
# this should not be necessary, but we do it
# just in case
if sampled_token >= n_cols:
sampled_token = n_cols - 1
# Write back output to DRAM
output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
best_idx)
tl.store(output_row_start_ptr, sampled_token)
if modify_greedy_probs: # noqa
if not uses_random_sampling:
# Set the probability of the sampled token to 1, all other
# tokens to zero. This is used in speculative decoding where
# the sampling method must be encoded within the sampled
# probability distributions.
row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
tl.store(row_start_ptr + col_offsets,
row,
mask=col_offsets < n_cols)
if save_modified_probs:
output_row_start_ptr = (output_modified_probs_ptr +
sample_idx * output_row_stride + best_idx)
tl.store(output_row_start_ptr, sampled_value)
if save_logprobs:
# Load the row into SRAM, using a mask since block_size
# may be greater than n_cols
sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
sampled_token)
# Write back output to DRAM
output_row_start_ptr = (output_logprobs_ptr +
sample_idx * output_row_stride + best_idx)
tl.store(output_row_start_ptr, sampled_logprob)
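The `row /= exponential_noise` followed by `tl.max` above is the exponential-race form of the Gumbel-max trick: with i.i.d. E_i ~ Exp(1), argmax_i p_i / E_i selects index i with probability p_i. A quick empirical check in plain PyTorch (a restatement for intuition, not vLLM API):

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(100_000):
    uniform = torch.rand(3).clamp_min(1e-6)  # clamp like _EPS to avoid log(0)
    exponential = -torch.log(uniform)        # inversion method, as in the kernel
    counts[torch.argmax(probs / exponential)] += 1
print(counts / counts.sum())  # approximately tensor([0.1, 0.2, 0.7])
```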

vllm/model_executor/layers/sampler.py

@@ -10,12 +10,6 @@ import msgspec
import torch
import torch.nn as nn
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton
import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
@@ -23,6 +17,7 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
@@ -740,7 +735,7 @@ def _sample_with_torch(
) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
@@ -777,7 +772,7 @@ def _sample_with_torch(
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0]
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
@@ -863,88 +858,6 @@
)
def _sample_with_triton_kernel(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> SampleResultType:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata: Dict[SamplingType,
Tuple[List[int], List[SequenceGroupToSample],
torch.Tensor, torch.Tensor]] = {}
max_best_of_in_batch = 1
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0]
sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups,
sample_indices,
sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sampled_tokens, _, _ = sample_triton(
probs=probs,
seeds=sampling_tensors.sampling_seeds,
max_best_of=max_best_of_in_batch,
sample_indices=sampling_tensors.sample_indices,
logprobs=logprobs,
# don't save logprobs because we have logic for that below
# TODO: use this instead of the CPU-based logic below
save_logprobs=False,
)
# GPU<->CPU sync happens in the loop below.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
@@ -974,10 +887,6 @@ def _sample(
modify_greedy_probs=modify_greedy_probs,
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""

vllm/model_executor/sampling_metadata.py

@@ -1,4 +1,3 @@
import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -8,15 +7,10 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
is_pin_memory_available, make_tensor_with_pad)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded until it is ready.
_USE_TRITON_SAMPLER = False
@dataclass
@@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int):
generator=None,
is_prompt=True,
prompt_logprob_indices=[],
sample_indices=[])
sample_indices=[],
)
class SamplingMetadataCache:
"""Used to cache SamplingMetadata objects between scheduler iterations
"""
"""Used to cache SamplingMetadata objects between scheduler iterations"""
def __init__(self):
self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
@@ -124,12 +118,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
def __init__(
@@ -165,16 +159,19 @@ class SamplingMetadata:
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device, generators, cache)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
selected_token_indices = async_tensor_h2d(
selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory,
)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory), 2, 2)
t: async_tensor_h2d(
seq_ids,
dtype=torch.int,
target_device=device,
pin_memory=pin_memory,
)
for t, seq_ids in categorized_sample_indices.items()
}
@@ -201,8 +198,8 @@ def _prepare_seq_groups(
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
List[int]], int, ]:
"""Prepare sequence groups and indices for sampling.
Args:
@@ -233,16 +230,13 @@ def _prepare_seq_groups(
# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
categorized_sample_indices: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
# Index of logits to compute logprob. Logits include both prompt logprob
# and sample logprob indices.
logit_idx = 0
# Index to sample from a sample tensor. It is used by triton sample kernel.
# See `_sample_with_triton_kernel` for more details.
sample_idx = 0
# Total number of prompts from given sequence groups.
num_prompts = 0
@@ -264,10 +258,10 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = \
sample_obj.prompt_logprob_indices if cache is not None else []
sample_indices: List[int] = \
sample_obj.sample_indices if cache is not None else []
prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
if cache is not None else [])
sample_indices: List[int] = (sample_obj.sample_indices
if cache is not None else [])
do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt:
@@ -333,11 +327,8 @@
if do_sample:
sample_indices.extend(range(logit_idx, logit_idx + sample_len))
categorized_sample_indices[sampling_params.sampling_type].extend(
list(
zip(range(logit_idx, logit_idx + sample_len),
range(sample_idx, sample_idx + sample_len))))
list(range(logit_idx, logit_idx + sample_len)))
logit_idx += sample_len
sample_idx += sample_len
if cache is not None:
sample_obj.sampling_params = sampling_params
@@ -356,7 +347,8 @@
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices))
sample_indices=list(sample_indices),
)
seq_groups.append(sample_obj)
@@ -378,9 +370,6 @@ class SamplingTensors:
presence_penalties: torch.Tensor
frequency_penalties: torch.Tensor
repetition_penalties: torch.Tensor
sampling_seeds: torch.Tensor
sample_indices: torch.Tensor
extra_seeds: Optional[torch.Tensor]
prompt_tokens: torch.Tensor
output_tokens: torch.Tensor
@@ -391,15 +380,7 @@
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
*,
extra_seeds_to_generate: int = 0,
extra_entropy: Optional[Tuple[int, ...]] = None
) -> Tuple["SamplingTensors", bool, bool, bool]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens: List[array] = []
output_tokens: List[array] = []
top_ks: List[int] = []
@@ -409,19 +390,10 @@
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
sampling_seeds: List[int] = []
sample_indices: List[int] = []
do_penalties = False
do_top_p_top_k = False
do_min_p = False
if _USE_TRITON_SAMPLER:
prompt_best_of: List[int] = []
# We need one base seed per Triton slice.
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
assert sampling_metadata.seq_groups is not None
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
@@ -452,7 +424,7 @@ class SamplingTensors:
do_penalties = True
is_prompt = seq_group.is_prompt
if (is_prompt and sampling_params.prompt_logprobs is not None):
if is_prompt and sampling_params.prompt_logprobs is not None:
# For tokens in the prompt that we only need to get
# their logprobs
query_len = seq_group.query_len
@@ -477,28 +449,6 @@
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
if _USE_TRITON_SAMPLER:
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
query_len = seq_group.query_len
assert query_len is not None
seed = sampling_params.seed
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
seq_data.get_len(),
*extra_entropy,
seq_id,
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.extend(seq_group.sample_indices)
if do_penalties:
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
@@ -518,23 +468,37 @@
output_tokens.append(seq_data.output_token_ids_array)
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
frequency_penalties, repetition_penalties, sampling_seeds,
sample_indices, prompt_tokens, output_tokens, vocab_size,
extra_seeds_to_generate, device, dtype)
temperatures,
top_ps,
top_ks,
min_ps,
presence_penalties,
frequency_penalties,
repetition_penalties,
prompt_tokens,
output_tokens,
vocab_size,
device,
dtype,
)
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
@classmethod
def from_lists(cls, temperatures: List[float], top_ps: List[float],
top_ks: List[int], min_ps: List[float],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
sampling_seeds: List[int], sample_indices: List[int],
prompt_tokens: List[array], output_tokens: List[array],
vocab_size: int, extra_seeds_to_generate: int,
device: torch.device,
dtype: torch.dtype) -> "SamplingTensors":
def from_lists(
cls,
temperatures: List[float],
top_ps: List[float],
top_ks: List[int],
min_ps: List[float],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
prompt_tokens: List[array],
output_tokens: List[array],
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
) -> "SamplingTensors":
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
@@ -603,34 +567,9 @@ class SamplingTensors:
dtype=torch.int,
pin_memory=pin_memory,
)
sample_indices_t = torch.tensor(
sample_indices,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t = torch.tensor(
sampling_seeds,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).t().contiguous()
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
sampling_seeds_gpu = sampling_seeds_t.to(device=device,
non_blocking=True)
extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
if not extra_seeds_gpu.numel():
extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -644,38 +583,4 @@
non_blocking=True),
prompt_tokens=prompt_t.to(device=device, non_blocking=True),
output_tokens=output_t.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device,
non_blocking=True),
extra_seeds=extra_seeds_gpu,
)
@staticmethod
def _get_sequence_seeds(
seed: int,
*extra_entropy: int,
seeds_to_generate: int,
is_greedy: bool,
):
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if not is_greedy:
if seed is None:
randint_fn = random.randint
else:
generator = random.Random(str((seed, ) + extra_entropy))
randint_fn = generator.randint
lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
# If the user/random sets seed = 0 but request should
# have sampling, we need to change it to something
# else. We use a constant in that case.
# This way we don't need to create and load a bool
# matrix in the sampling kernel, which reduces CPU
# overhead and latency.
seq_seeds = [
randint_fn(lo, hi) or _SEED_0_REPLACEMENT
for _ in range(seeds_to_generate)
]
else:
# For the kernel, seed == 0 means greedy decoding.
seq_seeds = [0] * seeds_to_generate
return seq_seeds
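The behavior of the removed `_get_sequence_seeds` fits in a few lines; a restatement for clarity (the `seed is None` branch, which falls back to nondeterministic `random.randint`, is omitted): child seeds are a pure function of the base seed plus extra entropy, a drawn 0 is swapped for `_SEED_0_REPLACEMENT`, and greedy requests always get 0.

```python
import random

def get_sequence_seeds(seed, *extra_entropy, seeds_to_generate, is_greedy):
    if is_greedy:
        return [0] * seeds_to_generate  # seed 0 == greedy for the kernel
    rng = random.Random(str((seed, ) + extra_entropy))
    lo, hi = -(2**63), 2**63 - 1        # torch.long range
    return [rng.randint(lo, hi) or 3403598558  # _SEED_0_REPLACEMENT
            for _ in range(seeds_to_generate)]

# Deterministic given the same inputs; extra entropy changes the seeds.
assert (get_sequence_seeds(1234, 5, seeds_to_generate=2, is_greedy=False) ==
        get_sequence_seeds(1234, 5, seeds_to_generate=2, is_greedy=False))
assert (get_sequence_seeds(1234, 6, seeds_to_generate=2, is_greedy=False) !=
        get_sequence_seeds(1234, 5, seeds_to_generate=2, is_greedy=False))
```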

vllm/triton_utils/sample.py (deleted)

@@ -1,13 +0,0 @@
import math
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072
def get_num_triton_sampler_splits(n_cols: int) -> int:
"""Get the number of splits to use for Triton sampling.
Triton has a limit on the number of columns it can handle, so we need to
split the tensor and call the kernel multiple times if it's too large.
"""
return math.ceil(n_cols / MAX_TRITON_N_COLS)
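A worked check of the split count (not part of the diff; the values follow directly from MAX_TRITON_N_COLS = 131072):

```python
assert get_num_triton_sampler_splits(32_000) == 1    # Llama-style vocab, one launch
assert get_num_triton_sampler_splits(131_073) == 2   # just over the limit
assert get_num_triton_sampler_splits(262_144) == 2   # ceil(262144 / 131072)
```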

vllm/utils.py

@@ -270,7 +270,7 @@ class LRUCache(Generic[T]):
class PyObjectCache:
"""Used to cache python objects to avoid object allocations
"""Used to cache python objects to avoid object allocations
across scheduler iterations.
"""
@@ -289,7 +289,7 @@ class PyObjectCache:
self._obj_cache.append(self._obj_builder())
def get_object(self):
"""Returns a pre-allocated cached object. If there is not enough
"""Returns a pre-allocated cached object. If there is not enough
objects, then the cache size will double.
"""
if self._index >= len(self._obj_cache):
@@ -837,15 +837,6 @@ def async_tensor_h2d(
return t.to(device=target_device, non_blocking=True)
def maybe_expand_dim(tensor: torch.Tensor,
target_dims: int,
size: int = 1) -> torch.Tensor:
"""Expand the tensor to the target_dims."""
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()
@@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless(
def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""
@@ -1136,10 +1127,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
def _pull_args_from_config(args: List[str]) -> List[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
The arguments in the config file will be inserted into
the argument list.
example:
```yaml
port: 12323
@@ -1150,21 +1141,21 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
--config config.yaml -tp 2
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
]
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
'-tp', '2'
]
```
Please note how the config args are inserted after the sub command.
This way the order of priorities is maintained when these args are
parsed by super().
"""
assert args.count(
@@ -1190,7 +1181,7 @@
@staticmethod
def _load_config_file(file_path: str) -> List[str]:
"""Loads a yaml file and returns the key value pairs as a
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
port: 12323
@@ -1201,7 +1192,7 @@
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split('.')[-1]