# vllm/vllm/worker/tpu_model_runner.py
# Commit 2f77b6cfec by Woosuk Kwon (2024-11-20 13:54:15 -08:00):
#   [TPU] Implement prefix caching for TPUs (#10307)
#   Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
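"""TPU model runner.

Prepares padded prompt/decode inputs, compiles the model with torch.compile
(openxla backend) plus XLA, and executes prefill, prefix-cached prefill, and
decode steps on TPU devices.
"""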


import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
# Here we rely on the behavior that out-of-bound indices are ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
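    # Prefill in which a prefix of the prompt is already cached in the KV
    # cache (prefix caching).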
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
n: List[int]
seq_groups: List[List[int]]
is_first_multi_step: bool = True
is_last_step: bool = True
virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
"n": self.n,
"seq_groups": self.seq_groups,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
self.block_size)
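        # Host-side (NumPy) scratch buffer used to assemble the per-batch
        # block tables that are copied to the device as int32 tensors.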
self.block_tables = np.zeros(
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
dtype=np.int32)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
False,
)
self.cached_step_outputs: List[torch.Tensor] = []
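        # The Pallas attention kernel receives the block table through SMEM,
        # so warn if the int32 block table (4 bytes per entry) may not fit in
        # the assumed 512KB of SMEM.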
smem_size = 512 * 1024
block_table_size = 4 * self.block_tables.size
if block_table_size >= smem_size:
logger.warning(
"The max_model_len (%d) is too large. This may degrade the "
"performance due to the insufficient smem size. Consider "
"setting --max-model-len to a smaller value.",
self.model_config.max_model_len)
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def _dummy_run(
self,
batch_size: int,
seq_len: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
exec_mode: ExecutionMode,
) -> None:
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
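            # Round the sequence length up to a multiple of 16, as required by
            # the Pallas FlashAttention kernel.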
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
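            # PREFIX_PREFILL additionally provides block tables, context lens,
            # and effective query lens so that the prefix-caching attention
            # path is compiled as well.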
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(batch_size, self.max_num_blocks_per_seq),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
)
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph still requires static shapes and needs to be
        # re-compiled for every different shape. This overhead is unavoidable
        # on the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
            # Prefill
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(input_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
def warmup_model(
self,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> None:
# Prefill
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefill done in %.2f s.", end - start)
# Prefix prefill
if self.cache_config.enable_prefix_caching:
logger.info("Compiling the model with different input shapes for "
"prefix prefill...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens >=
self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefix prefill done in %.2f s.",
end - start)
# Decode
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
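            # Mirror _get_padded_batch_size(): warm up batch sizes 8, 16, and
            # then multiples of 16.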
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info("Compilation for decode done in %.2f s.", end - start)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
context_lens: List[int] = []
slot_mapping: List[int] = []
for batch_idx, seq_group_metadata in enumerate(
seq_group_metadata_list):
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
# Could include output tokens when a request is preempted.
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
num_computed_tokens = num_computed_blocks * self.block_size
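            # With prefix caching, only the uncached suffix of the prompt is
            # fed as the query, while context_lens records the full sequence
            # length so that attention also covers the cached prefix.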
if num_computed_tokens > 0:
prompt_tokens = prompt_tokens[num_computed_tokens:]
context_lens.append(seq_len)
else:
context_lens.append(0)
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
input_positions.extend(range(num_computed_tokens, seq_len))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(num_computed_tokens, seq_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if num_computed_tokens > 0:
self.block_tables[batch_idx, :len(block_table)] = block_table
            # Pad EACH prompt to the smallest power of two (at least 16) that
            # is greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:num_prefills],
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
)
return input_tokens, input_positions, attn_metadata, prompt_lens
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1
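                # Compute the KV-cache slot that the newly generated token
                # will be written to.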
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device="cpu")
input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
)
return input_tokens, input_positions, attn_metadata, input_lens
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
n = []
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
t.append(sampling_params.temperature)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues.")
p.append(sampling_params.top_p)
if sampling_params.top_k != -1:
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.n > _MAX_NUM_SAMPLES:
raise NotImplementedError(
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
n.append(sampling_params.n)
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")
if sampling_params.prompt_logprobs is not None:
raise NotImplementedError(
"prompt_logprobs is not currently supported by the TPU "
"backend.")
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
n += [n[-1]] * (num_seqs - 1)
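        # Pad the sampling params to the padded batch size; outputs for the
        # padded entries are discarded later.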
num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, n
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that the scheduled sequence groups are either all
        # prompts or all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, n = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
seq_groups = [
list(metadata.seq_data.keys())
for metadata in seq_group_metadata_list
]
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, n, seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
return []
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=i == 0)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
return sampler_outputs
is_prompt = model_input.attn_metadata.num_prefills > 0
if is_prompt:
assert num_steps == 1
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
orig_slot_mapping = model_input.attn_metadata.slot_mapping
orig_block_tables = model_input.attn_metadata.block_tables
orig_context_lens = model_input.attn_metadata.context_lens
orig_effective_query_lens = \
model_input.attn_metadata.effective_query_lens
batch_size = model_input.input_lens.shape[0]
start_idx = 0
next_token_ids = []
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
self.device)
position_ids = model_input.position_ids[None,
start_idx:end_idx].to(
self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx].to(self.device)
if orig_context_lens[i].item() > 0:
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
self.device)
attn_metadata.block_tables = orig_block_tables[
i].unsqueeze(0).to(self.device)
attn_metadata.effective_query_lens = \
orig_effective_query_lens[i:i + 1].to(self.device)
else:
attn_metadata.context_lens = None
attn_metadata.block_tables = None
attn_metadata.effective_query_lens = None
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
if model_input.async_callback is not None:
model_input.async_callback()
            # Copy the outputs from the device to the CPU.
next_token_ids = [
output_token_ids.cpu().tolist()
for output_token_ids in next_token_ids
]
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support advanced sampling parameters such as logprobs.
zero_logprob = Logprob(0.0)
sampler_outputs = []
for i, seq_group in enumerate(model_input.seq_groups):
seq_ids = seq_group
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_outputs = []
for j in range(model_input.n[i]):
next_token_id = next_token_ids[i][j]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return [SamplerOutput(sampler_outputs)]
else:
token_ids = model_input.token_ids.to(self.device)
position_ids = model_input.position_ids.to(self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
self.device)
attn_metadata.block_tables = attn_metadata.block_tables.to(
self.device)
attn_metadata.context_lens = attn_metadata.context_lens.to(
self.device)
t = model_input.t.to(self.device)
p = model_input.p.to(self.device)
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, p,
model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)
if i < num_steps - 1:
# Prepare the inputs for the next step.
token_ids = output_token_ids.unsqueeze(dim=1).int()
position_ids = position_ids + 1
attn_metadata.context_lens = attn_metadata.context_lens + 1
block_tables = attn_metadata.block_tables
block_number = block_tables.gather(
1,
position_ids.long() // self.block_size)
block_offset = position_ids % self.block_size
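                    # Keep padded sequences mapped to the dummy slot so that
                    # their KV-cache writes remain out of bounds and are
                    # ignored.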
is_padding = slot_mapping == _PAD_SLOT_ID
slot_mapping = block_number * self.block_size + block_offset
slot_mapping = slot_mapping.long()
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
slot_mapping)
attn_metadata.slot_mapping = slot_mapping
if model_input.async_callback is not None:
model_input.async_callback()
if num_steps > 1:
return []
        # Copy the outputs from the device to the CPU.
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
return [sampler_output]
class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. During the memory profiling
                at initialization, these are empty placeholder tensors.
"""
batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
sampling_metadata = SamplingMetadata(
seq_groups=[],
selected_token_indices=logits_indices,
categorized_sample_indices={},
num_prompts=attn_metadata.num_prefills,
)
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
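        # Flatten [batch_size, seq_len, hidden_size] into
        # [batch_size * seq_len, hidden_size] so that logits_indices (computed
        # over the flattened layout) picks the last real token of each
        # sequence.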
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
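        # Repeat the greedy tokens so that the greedy and sampled outputs
        # share the same [batch_size, num_samples] shape for the final
        # torch.where.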
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
if num_samples == 1:
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids
def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The Pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length up to the next
    # power of two (at least 16), which satisfies this requirement and also
    # limits the number of distinct shapes that need to be compiled.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimum batch
    # size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
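    # Keep the smallest set of tokens whose cumulative probability reaches p;
    # tokens whose logits fall below the per-row cutoff are masked to -inf.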
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
return logits
def _make_decode_output(
next_token_ids: List[int],
seq_groups: List[List[int]],
) -> SamplerOutput:
zero_logprob = Logprob(0.0)
sampler_outputs = []
batch_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group
seq_outputs = []
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(CompletionSequenceGroupOutput(
seq_outputs, None))
return SamplerOutput(sampler_outputs)