[TPU] Refactor TPU worker & model runner (#6506)

This commit is contained in:
Woosuk Kwon 2024-07-18 01:34:16 -07:00 committed by GitHub
parent c8a7d51c49
commit 4634c8728b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 270 additions and 164 deletions

View File

@ -1,5 +1,6 @@
import time
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
@ -12,10 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SamplerOutput, SequenceGroupMetadata,
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
@ -27,7 +34,44 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES = 128
class TPUModelRunner:
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
best_of: List[int]
seq_groups: List[List[int]]
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
@ -79,6 +123,7 @@ class TPUModelRunner:
multimodal_config=self.multimodal_config,
lora_config=None,
)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
@ -147,8 +192,8 @@ class TPUModelRunner:
# Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata,
input_lens, t, p, num_samples)
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
num_samples, kv_caches)
def warmup_model(
self,
@ -177,7 +222,7 @@ class TPUModelRunner:
# Decode
start = time.time()
seq_len = 1
batch_size = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
xm.wait_device_ops()
@ -195,10 +240,10 @@ class TPUModelRunner:
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
slot_mapping: List[List[int]] = []
slot_mapping: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
@ -212,50 +257,46 @@ class TPUModelRunner:
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
input_tokens.extend(prompt_tokens)
input_positions.extend(list(range(prompt_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
slot_mapping.append([])
for i in range(prompt_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
slot_mapping.append(slot)
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
num_prefill_tokens = sum(prompt_lens)
# Add paddings to make the shape [batch_size, max_prompt_len] where
# max_prompt_len is smallest power of 2 that is greater than or equal
# to the maximum prompt length.
# We need the 2D input shape because the Pallas FlashAttention kernel
# does not support packed 1D inputs.
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.int32,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.int32,
device=self.device)
slot_mapping = make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.int64,
device=self.device)
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device=self.device)
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, # NOTE: This is not used.
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
block_tables=None,
@ -306,22 +347,22 @@ class TPUModelRunner:
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device=self.device)
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device=self.device)
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device=self.device)
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device=self.device)
device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
device="cpu")
input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32,
device=self.device)
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
@ -382,16 +423,18 @@ class TPUModelRunner:
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device=self.device)
p = torch.tensor(p, dtype=torch.float32, device=self.device)
t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, best_of
def _execute_model(
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> List[CompletionSequenceGroupOutput]:
# Prepare inputs.
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
@ -400,16 +443,104 @@ class TPUModelRunner:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0]
input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
# Execute the model.
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:], t, p, num_samples)
# Retrieve the outputs to CPU.
next_token_ids = next_token_ids.cpu().tolist()
seq_groups = [
list(metadata.seq_data.keys())
for metadata in seq_group_metadata_list
]
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, best_of,
seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
def _execute_model(*args, clone: bool = False) -> torch.Tensor:
"""Move input args from CPU to device and execute the model."""
def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
if clone:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x = x.clone()
return x.to(self.device)
new_args = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg = _copy_to_device(arg)
elif isinstance(arg, AttentionMetadata):
arg.slot_mapping = _copy_to_device(arg.slot_mapping)
if getattr(arg, "block_tables", None) is not None:
arg.block_tables = _copy_to_device(arg.block_tables)
if getattr(arg, "context_lens", None) is not None:
arg.context_lens = _copy_to_device(arg.context_lens)
new_args.append(arg)
return self.model(*new_args)
num_prefills = model_input.attn_metadata.num_prefills
is_prompt = num_prefills > 0
if is_prompt:
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids = []
orig_slot_mapping = model_input.attn_metadata.slot_mapping
batch_size = model_input.input_lens.shape[0]
start_idx = 0
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
model_input.attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx]
model_input.attn_metadata.num_prefills = 1
output_token_ids = _execute_model(
model_input.token_ids[None, start_idx:end_idx],
model_input.position_ids[None, start_idx:end_idx],
model_input.attn_metadata,
model_input.input_lens[i:i + 1],
model_input.t[i:i + 1],
model_input.p[i:i + 1],
model_input.num_samples,
kv_caches,
clone=True)
# Retrieve the outputs to CPU.
next_token_ids += output_token_ids.cpu().tolist()
start_idx = end_idx
else:
# Execute the model.
output_token_ids = _execute_model(
model_input.token_ids, model_input.position_ids,
model_input.attn_metadata, model_input.input_lens,
model_input.t, model_input.p, model_input.num_samples,
kv_caches)
# Retrieve the outputs to CPU.
next_token_ids = output_token_ids.cpu().tolist()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
@ -417,13 +548,13 @@ class TPUModelRunner:
zero_logprob = Logprob(0.0)
batch_idx = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
for seq_group in model_input.seq_groups:
seq_ids = seq_group
seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
if is_prompt:
assert len(seq_ids) == 1
seq_id = seq_ids[0]
for i in range(best_of[batch_idx]):
for i in range(model_input.best_of[batch_idx]):
next_token_id = next_token_ids[batch_idx][i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
@ -438,35 +569,6 @@ class TPUModelRunner:
batch_idx += 1
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return sampler_outputs
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
num_steps: int = 1,
) -> List[SamplerOutput]:
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
assert seq_group_metadata_list is not None
assert len(seq_group_metadata_list) > 0
if seq_group_metadata_list[0].is_prompt:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
sampler_outputs += self._execute_model([seq_group_metadata],
kv_caches)
else:
sampler_outputs = self._execute_model(seq_group_metadata_list,
kv_caches)
return [SamplerOutput(sampler_outputs)]
@ -474,36 +576,37 @@ class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model.eval()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
attn_metadata: AttentionMetadata,
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from.
base_indicies = torch.arange(
start_indicies = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = base_indicies + input_lens - 1
logits_indices = start_indicies + input_lens - 1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.

View File

@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.worker.tpu_model_runner import TPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
logger = init_logger(__name__)
class TPUWorker(LoraNotSupportedWorkerBase):
class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.model_runner = TPUModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
multimodal_config,
is_driver_worker=is_driver_worker)
self.model_runner: TPUModelRunner = TPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
multimodal_config,
is_driver_worker=is_driver_worker)
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
if not self.is_driver_worker:
self._execute_model_non_driver()
return []
assert execute_model_req is not None
# Issue cache operations.
self.cache_swap(
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy,
)
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
assert len(seq_group_metadata_list) > 0
output = self.model_runner.execute_model(seq_group_metadata_list,
self.tpu_cache)
return output
@property
def do_metadata_broadcast(self) -> bool:
# TODO(woosuk): Support TP.
return False
def cache_swap(
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
def prepare_worker_input(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
execute_model_req: ExecuteModelRequest,
) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
blocks_to_swap_in = _make_src_to_dst(
execute_model_req.blocks_to_swap_in, "cpu", self.device)
blocks_to_swap_out = _make_src_to_dst(
execute_model_req.blocks_to_swap_out, self.device, "cpu")
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
self.device, self.device)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
assert virtual_engine == 0
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)
if blocks_to_swap_in:
# Swap from CPU to TPU.
src_indices, dst_indices = _make_src_to_dst(
blocks_to_swap_in, "cpu", self.device)
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
# Issue cache operations.
if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_in
if src_indices.numel() > 0:
# Swap from CPU to TPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if blocks_to_swap_out:
# Swap from TPU to CPU.
src_indices, dst_indices = _make_src_to_dst(
blocks_to_swap_out, self.device, "cpu")
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_out
if src_indices.numel() > 0:
# Swap from TPU to CPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
self.device)
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
def start_worker_execution_loop(self) -> None:
while self._execute_model_non_driver():
pass
def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True
if worker_input.blocks_to_copy is not None:
src_indices, dst_indices = worker_input.blocks_to_copy
if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
(src_indices, dst_indices))
def _make_src_to_dst(