[Misc] Consolidate and optimize logic for building padded tensors (#6541)

Cyrus Leung 2024-07-20 12:17:24 +08:00 committed by GitHub
parent 3f8d42c81f
commit 9042d68362
9 changed files with 77 additions and 75 deletions

tests/conftest.py

@@ -21,7 +21,8 @@ from vllm.distributed import (destroy_distributed_environment,
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.sequence import SampleLogprobs
-from vllm.utils import cuda_device_count_stateless, is_cpu
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
+                        is_cpu)

 logger = init_logger(__name__)
@@ -124,12 +125,6 @@ def image_assets() -> _ImageAssets:
     return IMAGE_ASSETS

-_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-}
-
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
@@ -151,8 +146,7 @@ class HfRunner:
         is_vision_model: bool = False,
         is_sparseml_model: bool = False,
     ) -> None:
-        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
-        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

         self.model_name = model_name
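Note: the test fixture previously kept its own three-entry dtype map; it now reuses `STR_DTYPE_TO_TORCH_DTYPE` from `vllm.utils` (visible in the utils hunk further down, which also carries fp8 entries). A minimal sketch of the lookup `HfRunner.__init__` now performs, with a made-up unknown key to show the changed failure mode:

    import torch
    from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE["half"]
    assert torch_dtype is torch.half

    # Unknown dtype strings now surface as KeyError rather than the old assert.
    try:
        STR_DTYPE_TO_TORCH_DTYPE["float128"]  # hypothetical bad input
    except KeyError:
        pass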

vllm/attention/backends/flash_attn.py

@@ -306,11 +306,8 @@ class FlashAttentionMetadataBuilder(
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,

vllm/attention/backends/flashinfer.py

@@ -344,11 +344,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     cuda_graph_pad_size)
             self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,

vllm/attention/backends/utils.py

@@ -182,11 +182,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device=device)
         else:
-            max_block_table_len = max(
-                len(block_table) for block_table in self.block_tables)
             block_tables = make_tensor_with_pad(
                 self.block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=device,
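Note: the three attention-backend hunks above (and the CPU/XPU model-runner hunks at the end of the diff) all delete the same explicit `max_len` computation; after this patch `make_tensor_with_pad` infers it when the argument is omitted. A minimal sketch of the equivalence, using made-up block tables:

    import torch

    # Illustrative ragged block tables, as the backends would pass them.
    block_tables = [[0, 1, 2], [3], [4, 5]]

    # What the deleted call-site code computed explicitly:
    max_block_table_len = max(len(bt) for bt in block_tables)  # 3

    # What make_tensor_with_pad now does internally when max_len is None:
    inferred = max(map(len, block_tables), default=0)
    assert inferred == max_block_table_len

    # Equivalent padded result for pad=0, dtype=torch.int:
    padded = torch.zeros(len(block_tables), inferred, dtype=torch.int)
    for i, bt in enumerate(block_tables):
        padded[i, :len(bt)] = torch.tensor(bt, dtype=torch.int)
    # tensor([[0, 1, 2],
    #         [3, 0, 0],
    #         [4, 5, 0]], dtype=torch.int32)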

vllm/model_executor/sampling_metadata.py

@@ -2,14 +2,13 @@ import random
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

-import numpy as np
 import torch

 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
-from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
-                        maybe_expand_dim)
+from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
+                        make_tensor_with_pad, maybe_expand_dim)

 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
@@ -466,22 +465,24 @@ class SamplingTensors:
         do_penalties = prompt_tokens or output_tokens

         if do_penalties:
-            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
-                                 default=0)
-            prompt_padded_tokens = np.full(
-                (len(prompt_tokens), prompt_max_len),
-                vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(prompt_tokens):
-                prompt_padded_tokens[i, :len(tokens)] = tokens
-            output_max_len = max([len(tokens) for tokens in output_tokens],
-                                 default=0)
-            output_padded_tokens = np.full(
-                (len(output_tokens), output_max_len),
-                vocab_size,
-                dtype=np.int64)
-            for i, tokens in enumerate(output_tokens):
-                output_padded_tokens[i, :len(tokens)] = tokens
+            prompt_t = make_tensor_with_pad(
+                prompt_tokens,
+                vocab_size,
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+            output_t = make_tensor_with_pad(
+                output_tokens,
+                vocab_size,
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+        else:
+            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
+            prompt_t = empty_tensor
+            output_t = empty_tensor

         temperatures_t = torch.tensor(
             temperatures,
@@ -531,15 +532,6 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        if do_penalties:
-            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
-            output_tensor = torch.from_numpy(output_padded_tokens)
-            if pin_memory:
-                prompt_tensor = prompt_tensor.pin_memory()
-                output_tensor = output_tensor.pin_memory()
-        else:
-            prompt_tensor = None
-            output_tensor = None
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -562,16 +554,6 @@ class SamplingTensors:
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

-        if do_penalties:
-            prompt_tokens_gpu = prompt_tensor.to(device=device,
-                                                 non_blocking=True)
-            output_tokens_gpu = output_tensor.to(device=device,
-                                                 non_blocking=True)
-        else:
-            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
-            prompt_tokens_gpu = empty_tensor
-            output_tokens_gpu = empty_tensor

         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -583,8 +565,8 @@ class SamplingTensors:
                                                non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_tokens_gpu,
-            output_tokens=output_tokens_gpu,
+            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
+            output_tokens=output_t.to(device=device, non_blocking=True),
             sampling_seeds=sampling_seeds_gpu,
             sample_indices=sample_indices_t.to(device=device,
                                                non_blocking=True),
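Note: the rewritten penalty path replaces hand-rolled numpy padding plus a second host-to-device branch with two `make_tensor_with_pad` calls, moving the empty-tensor fallback next to the padding code. A sketch of what the new calls produce, using a local stand-in for the helper and toy values:

    import numpy as np
    import torch

    def pad_to_tensor(x, pad, dtype):
        # CPU-only stand-in for vllm.utils.make_tensor_with_pad (illustrative).
        max_len = max(map(len, x), default=0)
        arr = np.full((len(x), max_len), pad, dtype=np.int64)
        for i, row in enumerate(x):
            arr[i, :len(row)] = row
        return torch.from_numpy(arr).to(dtype)

    vocab_size = 8  # toy value
    prompt_tokens = [[1, 2, 3], [4]]
    prompt_t = pad_to_tensor(prompt_tokens, vocab_size, torch.int64)
    # tensor([[1, 2, 3],
    #         [4, 8, 8]])

The pad value is `vocab_size`, one past the largest valid token id, so padded slots cannot be mistaken for real tokens downstream.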

vllm/utils.py

@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
                     Union)

 import numpy as np
+import numpy.typing as npt
 import psutil
 import torch
 import torch.types
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "fp8_e5m2": torch.uint8,
 }

+TORCH_DTYPE_TO_NUMPY_DTYPE = {
+    torch.float16: np.float16,
+    torch.float32: np.float32,
+    torch.float64: np.float64,
+    torch.uint8: np.uint8,
+    torch.int32: np.int32,
+    torch.int64: np.int64,
+}
+
 P = ParamSpec('P')
 K = TypeVar("K")
 T = TypeVar("T")
@@ -617,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
             f"(e.g., 1, 2, 3). Given input: {s}") from e


-def make_tensor_with_pad(
-    x: List[List[int]],
-    max_len: int,
-    pad: int,
-    dtype: torch.dtype,
-    device: Optional[Union[str, torch.device]],
-) -> torch.Tensor:
-    """Make a padded tensor of a 2D inputs.
+def make_ndarray_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len: Optional[int] = None,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.

     The padding is applied to the end of each inner list until it reaches
     `max_len`.
     """
-    padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
+    if max_len is None:
+        # Unlike for most functions, map is faster than a genexpr over `len`
+        max_len = max(map(len, x), default=0)
+
+    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
     for ind, blocktb in enumerate(x):
         assert len(blocktb) <= max_len
         padded_x[ind, :len(blocktb)] = blocktb
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+    return padded_x
+
+
+def make_tensor_with_pad(
+    x: List[List[T]],
+    pad: T,
+    dtype: torch.dtype,
+    *,
+    max_len: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
+) -> torch.Tensor:
+    """
+    Make a padded tensor from 2D inputs.
+
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory()
+
+    return tensor


 def async_tensor_h2d(
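Note: taken together, the consolidated helpers give a single entry point for padded-tensor construction. A usage sketch against the signatures introduced above, with toy rows (run on CPU; `pin_memory=True` additionally requires a CUDA-enabled build):

    import numpy as np
    import torch
    from vllm.utils import make_ndarray_with_pad, make_tensor_with_pad

    rows = [[1, 2, 3], [4, 5], []]

    # max_len is now optional and defaults to the longest inner list (3 here).
    arr = make_ndarray_with_pad(rows, pad=0, dtype=np.int64)
    assert arr.shape == (3, 3)

    # The tensor variant reuses the ndarray path, then converts and moves.
    t = make_tensor_with_pad(rows, pad=0, dtype=torch.int64, device="cpu")
    assert t.shape == (3, 3)

    # An explicit max_len is still accepted, but only as a keyword argument.
    t5 = make_tensor_with_pad(rows, pad=0, dtype=torch.int64, max_len=5)
    assert t5.shape == (3, 5)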

vllm/worker/cpu_model_runner.py

@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
                                        dtype=torch.int,
                                        device=self.device)

-            max_block_table_len = max(
-                len(block_table) for block_table in block_tables)
             block_tables = make_tensor_with_pad(
                 block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=self.device,

vllm/worker/neuron_model_runner.py

@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_seq_len,
                                             pad=0,
+                                            max_len=max_seq_len,
                                             dtype=torch.long,
                                             device=self.device)
         input_positions = make_tensor_with_pad(input_positions,
-                                               max_seq_len,
                                                pad=0,
+                                               max_len=max_seq_len,
                                                dtype=torch.long,
                                                device=self.device)
         input_block_ids = torch.tensor(input_block_ids,
@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             input_block_ids.append(block_table[0])

         input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_len=1,
                                             pad=0,
+                                            max_len=1,
                                             dtype=torch.long,
                                             device=self.device)
         input_positions = make_tensor_with_pad(input_positions,
-                                               max_len=1,
                                                pad=0,
+                                               max_len=1,
                                                dtype=torch.long,
                                                device=self.device)
         context_lens = torch.tensor(context_lens,
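Note: the Neuron runner changes are purely argument reordering: `max_len` is keyword-only in the new signature (note the `*` in the utils hunk), so the previously positional `max_seq_len` has to move behind `pad`. A quick illustration with toy inputs:

    import torch
    from vllm.utils import make_tensor_with_pad

    input_tokens = [[11, 12], [13]]
    max_seq_len = 4  # toy value; the runner uses max(seq_lens)

    # The pre-patch positional call, make_tensor_with_pad(input_tokens,
    # max_seq_len, pad=0, ...), now raises TypeError: max_seq_len would bind
    # to pad, which is then passed a second time by keyword.
    t = make_tensor_with_pad(input_tokens,
                             pad=0,
                             max_len=max_seq_len,
                             dtype=torch.long,
                             device="cpu")
    assert t.shape == (2, 4)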

vllm/worker/xpu_model_runner.py

@@ -335,11 +335,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                                        dtype=torch.int,
                                        device=self.device)

-            max_block_table_len = max(
-                len(block_table) for block_table in block_tables)
             block_tables = make_tensor_with_pad(
                 block_tables,
-                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device=self.device,