[Misc] Consolidate and optimize logic for building padded tensors (#6541)

commit 9042d68362
parent 3f8d42c81f
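
This commit replaces padding logic that was duplicated across the test harness, the attention metadata builders, the sampler, and the model runners with two shared helpers in vllm.utils. The signature change that drives every hunk below (both signatures appear verbatim in the vllm.utils hunk further down):

# Old: max_len was a required positional argument.
#     make_tensor_with_pad(x, max_len, pad, dtype, device)
# New: options are keyword-only; max_len is inferred when omitted,
# and the result can be allocated in pinned host memory.
#     make_tensor_with_pad(x, pad, dtype, *,
#                          max_len=None, device=None, pin_memory=False)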
@@ -21,7 +21,8 @@ from vllm.distributed import (destroy_distributed_environment,
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.sequence import SampleLogprobs
-from vllm.utils import cuda_device_count_stateless, is_cpu
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
+                        is_cpu)
 
 logger = init_logger(__name__)
@@ -124,12 +125,6 @@ def image_assets() -> _ImageAssets:
     return IMAGE_ASSETS
 
 
-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-}
-
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
 
 
@@ -151,8 +146,7 @@ class HfRunner:
         is_vision_model: bool = False,
         is_sparseml_model: bool = False,
     ) -> None:
-        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
         self.model_name = model_name
 
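
With the test-local dict gone, HfRunner reads the shared mapping directly. One behavioral nuance: the deleted assert means an unsupported dtype string now surfaces as a KeyError rather than an AssertionError. A minimal sketch (the "bfloat16" key is taken from the deleted dict):

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

# Unknown dtype strings now raise KeyError instead of tripping the old assert.
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE["bfloat16"]  # -> torch.bfloat16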
@@ -306,11 +306,8 @@ class FlashAttentionMetadataBuilder(
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
@@ -344,11 +344,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                              cuda_graph_pad_size)
             self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
@@ -182,11 +182,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
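
The same three-line deletion appears in each of the three builders above: because make_tensor_with_pad now infers the padded width, call sites drop the max(...) scan. A before/after sketch under illustrative inputs (block_tables_list is hypothetical):

import torch
from vllm.utils import make_tensor_with_pad

block_tables_list = [[0, 1, 2], [3]]  # illustrative ragged block tables

# Before this change, every call site computed the width itself:
#     max_block_table_len = max(len(bt) for bt in block_tables_list)
#     make_tensor_with_pad(block_tables_list, max_len=max_block_table_len, ...)
# Now max_len defaults to the longest inner list:
block_tables = make_tensor_with_pad(block_tables_list,
                                    pad=0,
                                    dtype=torch.int)
assert block_tables.shape == (2, 3)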
@@ -2,14 +2,13 @@ import random
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 
 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
-                        maybe_expand_dim)
+                        make_tensor_with_pad, maybe_expand_dim)
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
@@ -466,22 +465,24 @@ class SamplingTensors:
         do_penalties = prompt_tokens or output_tokens
 
         if do_penalties:
-            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
-                                 default=0)
-            prompt_padded_tokens = np.full(
-                (len(prompt_tokens), prompt_max_len),
-                vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(prompt_tokens):
-                prompt_padded_tokens[i, :len(tokens)] = tokens
-            output_max_len = max([len(tokens) for tokens in output_tokens],
-                                 default=0)
-            output_padded_tokens = np.full(
-                (len(output_tokens), output_max_len),
-                vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(output_tokens):
-                output_padded_tokens[i, :len(tokens)] = tokens
+            prompt_t = make_tensor_with_pad(
+                prompt_tokens,
+                vocab_size,
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+            output_t = make_tensor_with_pad(
+                output_tokens,
+                vocab_size,
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+        else:
+            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
+            prompt_t = empty_tensor
+            output_t = empty_tensor
 
         temperatures_t = torch.tensor(
             temperatures,
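
The sampler previously padded with np.full here and converted to (optionally pinned) tensors in a separate block further down; the new code gets a pinned CPU tensor in one call and leaves only the device copy for later. A rough sketch of the resulting flow, assuming a CUDA host (pin_memory requires one); vocab_size is the pad value from the hunk, safe because real token ids stop at vocab_size - 1:

import torch
from vllm.utils import make_tensor_with_pad

vocab_size = 32000                          # illustrative
prompt_tokens = [[101, 7592, 2088], [101]]  # illustrative ragged prompts

prompt_t = make_tensor_with_pad(
    prompt_tokens,
    vocab_size,          # pad id one past the valid token range
    device="cpu",
    dtype=torch.int64,
    pin_memory=True,
)
# Later (see the @@ -583 hunk) a single async host-to-device copy suffices:
prompt_gpu = prompt_t.to(device="cuda", non_blocking=True)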
@@ -531,15 +532,6 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        if do_penalties:
-            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
-            output_tensor = torch.from_numpy(output_padded_tokens)
-            if pin_memory:
-                prompt_tensor = prompt_tensor.pin_memory()
-                output_tensor = output_tensor.pin_memory()
-        else:
-            prompt_tensor = None
-            output_tensor = None
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -562,16 +554,6 @@ class SamplingTensors:
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
 
-        if do_penalties:
-            prompt_tokens_gpu = prompt_tensor.to(device=device,
-                                                 non_blocking=True)
-            output_tokens_gpu = output_tensor.to(device=device,
-                                                 non_blocking=True)
-        else:
-            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
-            prompt_tokens_gpu = empty_tensor
-            output_tokens_gpu = empty_tensor
-
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -583,8 +565,8 @@ class SamplingTensors:
                                                       non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_tokens_gpu,
-            output_tokens=output_tokens_gpu,
+            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
+            output_tokens=output_t.to(device=device, non_blocking=True),
             sampling_seeds=sampling_seeds_gpu,
             sample_indices=sample_indices_t.to(device=device,
                                                non_blocking=True),
@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
                     Union)
 
 import numpy as np
+import numpy.typing as npt
 import psutil
 import torch
 import torch.types
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "fp8_e5m2": torch.uint8,
 }
 
+TORCH_DTYPE_TO_NUMPY_DTYPE = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+    torch.uint8: np.uint8,
+    torch.int32: np.int32,
+    torch.int64: np.int64,
+}
+
 P = ParamSpec('P')
 K = TypeVar("K")
 T = TypeVar("T")
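
The new TORCH_DTYPE_TO_NUMPY_DTYPE table exists because the padding now happens in NumPy: callers hand the helper a torch.dtype, np.full needs the NumPy equivalent, and torch.from_numpy then restores the torch dtype on a zero-copy view. A quick illustration (values are hypothetical, the mapping is from the hunk above):

import numpy as np
import torch
from vllm.utils import TORCH_DTYPE_TO_NUMPY_DTYPE

np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[torch.int64]  # np.int64
arr = np.full((2, 3), 0, dtype=np_dtype)
t = torch.from_numpy(arr)                           # zero-copy view
assert t.dtype == torch.int64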
@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
                          f"(e.g., 1, 2, 3). Given input: {s}") from e
 
 
-def make_tensor_with_pad(
-    x: List[List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Optional[Union[str, torch.device]],
-) -> torch.Tensor:
-    """Make a padded tensor of a 2D inputs.
+def make_ndarray_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len: Optional[int] = None,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.
 
     The padding is applied to the end of each inner list until it reaches
     `max_len`.
     """
-    padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
+    if max_len is None:
+        # Unlike for most functions, map is faster than a genexpr over `len`
+        max_len = max(map(len, x), default=0)
+
+    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
     for ind, blocktb in enumerate(x):
         assert len(blocktb) <= max_len
         padded_x[ind, :len(blocktb)] = blocktb
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+    return padded_x
+
+
+def make_tensor_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: torch.dtype,
+    *,
+    max_len: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
+) -> torch.Tensor:
+    """
+    Make a padded tensor from 2D inputs.
+
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory()
+
+    return tensor
 
 
 def async_tensor_h2d(
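
A usage sketch of the two consolidated helpers (inputs illustrative, expected results in comments):

import numpy as np
import torch
from vllm.utils import make_ndarray_with_pad, make_tensor_with_pad

x = [[1, 2, 3], [4], []]

arr = make_ndarray_with_pad(x, pad=0, dtype=np.int32)
# [[1 2 3]
#  [4 0 0]
#  [0 0 0]]  (max_len inferred as 3)

t = make_tensor_with_pad(x, pad=0, dtype=torch.int32)
# same values as a torch.int32 tensor on the default device

Passing pin_memory=True additionally pins the host allocation so that later .to(device, non_blocking=True) copies can overlap with compute, which is what the sampling path above relies on.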
@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
                                        dtype=torch.int,
                                        device=self.device)
 
-        max_block_table_len = max(
-            len(block_table) for block_table in block_tables)
         block_tables = make_tensor_with_pad(
             block_tables,
-            max_len=max_block_table_len,
             pad=0,
             dtype=torch.int,
             device=self.device,
@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_seq_len,
                                             pad=0,
+                                            max_len=max_seq_len,
                                             dtype=torch.long,
                                             device=self.device)
         input_positions = make_tensor_with_pad(input_positions,
-                                               max_seq_len,
                                                pad=0,
+                                               max_len=max_seq_len,
                                                dtype=torch.long,
                                                device=self.device)
         input_block_ids = torch.tensor(input_block_ids,
@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             input_block_ids.append(block_table[0])
 
         input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_len=1,
                                             pad=0,
+                                            max_len=1,
                                             dtype=torch.long,
                                             device=self.device)
         input_positions = make_tensor_with_pad(input_positions,
-                                               max_len=1,
                                                pad=0,
+                                               max_len=1,
                                                dtype=torch.long,
                                                device=self.device)
         context_lens = torch.tensor(context_lens,
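
Both Neuron hunks are pure argument reordering to match the new keyword-only signature. max_len=1 stays explicit in the decode path: each sequence contributes exactly one new token, and spelling the width out guarantees a (batch_size, 1) shape even for degenerate inputs, where inference would yield width 0. A tiny check of that reading (hypothetical values):

import torch
from vllm.utils import make_tensor_with_pad

decode_tokens = [[7], [8], [9]]  # one new token per sequence
t = make_tensor_with_pad(decode_tokens, pad=0, max_len=1, dtype=torch.long)
assert t.shape == (3, 1)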
@@ -335,11 +335,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                                        dtype=torch.int,
                                        device=self.device)
 
-        max_block_table_len = max(
-            len(block_table) for block_table in block_tables)
         block_tables = make_tensor_with_pad(
             block_tables,
-            max_len=max_block_table_len,
             pad=0,
             dtype=torch.int,
             device=self.device,