commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to python source files as
recommended to the project by the Linux Foundation. These headers
provide a concise way that is both human and machine readable for
communicating license information for each source file. It helps
avoid any ambiguity about the license of the code and can also be
easily used by tools to help manage license compliance.

The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that
tool do its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
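
As a rough illustration of the kind of check the second commit adds (a
minimal sketch, not the actual hook shipped in the repo; the script name and
the exact header text are assumptions), a pre-commit hook can verify that
every staged Python file starts with an SPDX identifier:

    #!/usr/bin/env python3
    # check_spdx.py: illustrative SPDX header check for pre-commit.
    # pre-commit passes the staged file paths as command-line arguments.
    import sys

    EXPECTED = "# SPDX-License-Identifier: Apache-2.0"

    def main(paths):
        missing = []
        for path in paths:
            with open(path, encoding="utf-8") as f:
                first_line = f.readline().strip()
            if first_line != EXPECTED:
                missing.append(path)
        for path in missing:
            print(f"{path}: missing SPDX license header")
        return 1 if missing else 0

    if __name__ == "__main__":
        sys.exit(main(sys.argv[1:]))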
# SPDX-License-Identifier: Apache-2.0

import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                           Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


class ExecutionMode(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)


@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    n: List[int]
    seq_groups: List[List[int]]
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "n": self.n,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

    def __init__(
        self,
        vllm_config: VllmConfig,
        is_driver_worker: bool = False,
    ):
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            False,
        )
        self.cached_step_outputs: List[torch.Tensor] = []

        smem_size = 512 * 1024
        block_table_size = 4 * self.block_tables.size
        if block_table_size >= smem_size:
            logger.warning(
                "The max_model_len (%d) is too large. This may degrade the "
                "performance due to the insufficient smem size. Consider "
                "setting --max-model-len to a smaller value, like %d.",
                self.model_config.max_model_len,
                self.model_config.max_model_len /
                (block_table_size / smem_size))
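        # Worked example of the check above (illustrative, assuming
        # block_size=16 and max_num_seqs=256): max_model_len=16384 gives
        # 1024 blocks per sequence, so the block table holds 256 * 1024
        # int32 entries = 1 MiB, which exceeds the 512 KiB smem budget and
        # triggers the warning with a suggested max_model_len of 8192.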

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally
        # assigned by the xm runtime. Therefore, there is a mismatch in the
        # rank assignment between the gloo (cpu) runtime and the xm (tpu)
        # runtime. This is not a problem in linear layers because all-reduce
        # is rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        xm.wait_device_ops()
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def get_model(self) -> nn.Module:
        return self.model.model

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        exec_mode = ExecutionMode(exec_mode)
        if exec_mode.is_prefill():
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            if exec_mode == ExecutionMode.PREFILL:
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=batch_size,
                    num_prefill_tokens=batch_size * seq_len,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=False,
                    block_tables=None,
                    context_lens=None,
                    effective_query_lens=None,
                )
            else:
                context_lens = torch.ones((batch_size, ),
                                          dtype=torch.int32,
                                          device=self.device)
                block_tables = torch.tensor(self.block_tables[:batch_size],
                                            dtype=torch.int32,
                                            device=self.device)
                effective_query_lens = torch.ones_like(context_lens)
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=batch_size,
                    num_prefill_tokens=batch_size * seq_len,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=False,
                    block_tables=block_tables,
                    context_lens=context_lens,
                    effective_query_lens=effective_query_lens,
                )
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=False,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph still requires static shapes and needs to be
        # re-compiled for every different shape. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (VLLM_XLA_CACHE_PATH).
        if exec_mode.is_prefill():
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        # Dummy run.
        with set_forward_context(attn_metadata, self.vllm_config, 0):
            self.model(token_ids, position_ids, attn_metadata, input_lens, t,
                       p, num_samples, kv_caches)
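    # Illustration of the two-stage compilation described above: with
    # mark_dynamic, prefill dummy runs of shapes (1, 16), (1, 32), (1, 64),
    # ... share a single torch.compile FX graph (dim 1 is dynamic), while XLA
    # still traces one executable per concrete shape and caches it under
    # VLLM_XLA_CACHE_PATH. For decode, dim 0 (the batch size) is the dynamic
    # dimension instead.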

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while seq_len <= self.model_config.max_model_len:
                self._dummy_run(batch_size,
                                seq_len,
                                kv_caches,
                                exec_mode=ExecutionMode.PREFILL)
                xm.wait_device_ops()
                logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info("Compilation for prefill done in %.2f s.", end - start)

        # Prefix prefill
        if self.cache_config.enable_prefix_caching:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill...")
            start = time.time()
            for batch_size in [1]:
                seq_len = 16
                while seq_len <= self.model_config.max_model_len:
                    self._dummy_run(batch_size,
                                    seq_len,
                                    kv_caches,
                                    exec_mode=ExecutionMode.PREFIX_PREFILL)
                    xm.wait_device_ops()
                    logger.info("batch_size: %d, seq_len: %d", batch_size,
                                seq_len)
                    num_tokens = batch_size * seq_len
                    if (num_tokens
                            >= self.scheduler_config.max_num_batched_tokens):
                        break
                    seq_len = seq_len * 2
            end = time.time()
            logger.info("Compilation for prefix prefill done in %.2f s.",
                        end - start)

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info("Compilation for decode done in %.2f s.", end - start)
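    # Shape schedule produced by the warmup above (illustrative): prefill
    # sweeps seq_len = 16, 32, 64, ... up to max_model_len (or until
    # batch_size * seq_len reaches max_num_batched_tokens), while decode
    # sweeps batch_size = 8, 16, 32, 48, 64, ... up to max_num_seqs,
    # matching the padded sizes returned by _get_padded_batch_size().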

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        context_lens: List[int] = []
        slot_mapping: List[int] = []

        for batch_idx, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)

            num_computed_blocks = len(seq_group_metadata.computed_block_nums)
            num_computed_tokens = num_computed_blocks * self.block_size
            if num_computed_tokens > 0:
                prompt_tokens = prompt_tokens[num_computed_tokens:]
                context_lens.append(seq_len)
            else:
                context_lens.append(0)

            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(range(num_computed_tokens, seq_len))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(num_computed_tokens, seq_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
            if num_computed_tokens > 0:
                self.block_tables[batch_idx, :len(block_table)] = block_table

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings
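            # Worked example (illustrative): a 100-token prompt with no
            # computed prefix pads to _get_padded_prefill_len(100) == 128,
            # so 28 zero tokens are appended and their slots point at
            # _PAD_SLOT_ID, an out-of-bounds index that the kernel ignores.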

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:num_prefills],
                                    dtype=torch.int32,
                                    device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
            effective_query_lens=prompt_lens,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens
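    # Worked example for the decode padding above (illustrative): with 3
    # running sequences, _get_padded_batch_size(3) == 8, so 5 padding entries
    # are appended with token 0, slot _PAD_SLOT_ID, and context_len 0 before
    # the tensors are built.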

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        n = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.n > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            n.append(sampling_params.n)
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            n += [n[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, n
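    # Note on the replication above (illustrative): a seq group holding 3
    # sequences contributes its temperature/top-p/n once via append() and is
    # then repeated twice by the `t += [t[-1]] * (num_seqs - 1)` lines, so
    # the per-sequence lists line up with the flattened decode batch.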

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        padded_batch_size = input_tokens.shape[0]
        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                       padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, n, seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)
        return model_input
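    # Round-trip sketch (illustrative only; `runner` is an assumed name):
    # the driver builds the model input and serializes it for broadcast,
    # and each worker rebuilds it with the method above:
    #   model_input = runner.prepare_model_input(seq_group_metadata_list)
    #   tensor_dict = model_input.as_broadcastable_tensor_dict()
    #   rebuilt = runner.make_model_input_from_broadcasted_tensor_dict(
    #       tensor_dict)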

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False,
                        is_first_step_output=i == 0)
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
        if is_prompt:
            assert num_steps == 1
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            orig_block_tables = model_input.attn_metadata.block_tables
            orig_context_lens = model_input.attn_metadata.context_lens
            orig_effective_query_lens = \
                model_input.attn_metadata.effective_query_lens
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            next_token_ids = []
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
                if orig_context_lens[i].item() > 0:
                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
                        self.device)
                    attn_metadata.block_tables = orig_block_tables[
                        i].unsqueeze(0).to(self.device)
                    attn_metadata.effective_query_lens = \
                        orig_effective_query_lens[i:i + 1].to(self.device)
                else:
                    attn_metadata.context_lens = None
                    attn_metadata.block_tables = None
                    attn_metadata.effective_query_lens = None
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  attn_metadata, input_lens, t,
                                                  p, model_input.num_samples,
                                                  kv_caches)
                next_token_ids.append(output_token_ids[0])
                start_idx = end_idx

            if model_input.async_callback is not None:
                model_input.async_callback()
            # Retrieve the outputs to CPU.
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                seq_outputs = []
                for j in range(model_input.n[i]):
                    next_token_id = next_token_ids[i][j]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  attn_metadata, input_lens, t,
                                                  p, model_input.num_samples,
                                                  kv_caches)
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]
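

# Multi-step decode flow (summary of execute_model above): the first call of
# a multi-step window runs all num_steps on device and caches the sampled
# token tensors in self.cached_step_outputs, returning []; the final call
# (is_first_multi_step=False, is_last_step=True) drains that cache into
# SamplerOutputs, optionally streaming intermediate steps through the async
# callback.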


class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape
                [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        start_indicies = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indicies + input_lens - 1

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indicies = torch.arange(0,
                                         num_kv_heads,
                                         device=slot_mapping.device,
                                         dtype=slot_mapping.dtype)
            head_indicies *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indicies.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping
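            # Worked example of the flattening above (illustrative): with
            # num_kv_heads=2, num_blocks=4, block_size=16, the per-head
            # stride is block_size * num_blocks == 64, so token slot 5
            # expands to flattened indices [5, 69], i.e., one destination
            # row per KV head in the flattened
            # [num_kv_heads * num_blocks * block_size, head_size] cache.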

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)
        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        next_token_ids = torch.where(t != 0, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
    # multiple of 16. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()
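
# Examples (illustrative): _get_padded_prefill_len(1) == 16,
# _get_padded_prefill_len(17) == 32, _get_padded_prefill_len(100) == 128;
# i.e., the next power of two, never below 16.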


def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16
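
# Examples (illustrative): _get_padded_batch_size(3) == 8,
# _get_padded_batch_size(9) == 16, _get_padded_batch_size(17) == 32.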


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
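
# Worked example for _apply_top_p above (illustrative): with p = 0.9 and
# softmax probabilities sorted as [0.5, 0.3, 0.15, 0.05], the cumulative sums
# are [0.5, 0.8, 0.95, 1.0]; two entries satisfy cum < p, so cutoff_index is
# 2, cutoff_logit is the third-largest logit, and every logit smaller than it
# is masked to -inf, keeping the top three tokens.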


def _make_decode_output(
        next_token_ids: List[int],
        seq_groups: List[List[int]],
) -> SamplerOutput:
    zero_logprob = Logprob(0.0)
    sampler_outputs = []
    batch_idx = 0
    for seq_group in seq_groups:
        seq_ids = seq_group
        seq_outputs = []
        for seq_id in seq_ids:
            next_token_id = next_token_ids[batch_idx]
            seq_outputs.append(
                SequenceOutput(seq_id, next_token_id,
                               {next_token_id: zero_logprob}))
            batch_idx += 1
        sampler_outputs.append(CompletionSequenceGroupOutput(
            seq_outputs, None))
    return SamplerOutput(sampler_outputs)