[Hardware][TPU] Support parallel sampling & Swapping (#5855)

This commit is contained in:
Woosuk Kwon 2024-06-26 11:07:49 -07:00 committed by GitHub
parent c54269d967
commit cbc53b6b8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 147 additions and 56 deletions

View File

@ -28,21 +28,35 @@ class PallasAttentionBackend(AttentionBackend):
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)
@torch.compile(backend="openxla")
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
src_to_dst: Tuple[torch.Tensor, torch.Tensor],
) -> None:
raise NotImplementedError("swap_blocks is not implemented.")
src_k_cache, src_v_cache = src_kv_cache
dst_k_cache, dst_v_cache = dst_kv_cache
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
device = dst_k_cache.device
src_indices, dst_indices = src_to_dst
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
# TODO(woosuk): Implement this.
raise NotImplementedError("copy_blocks is not implemented.")
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@dataclass

View File

@ -22,6 +22,9 @@ logger = init_logger(__name__)
_PAD_SLOT_ID = 0 # FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
class TPUModelRunner:
@ -143,8 +146,9 @@ class TPUModelRunner:
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
# Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata,
input_lens, t, p)
input_lens, t, p, num_samples)
def warmup_model(
self,
@ -268,14 +272,11 @@ class TPUModelRunner:
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
num_seq_groups = len(seq_group_metadata_list)
batch_size = _get_padded_batch_size(num_seq_groups)
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
@ -288,14 +289,16 @@ class TPUModelRunner:
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[i, :len(block_table)] = block_table
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
num_paddings = batch_size - num_seq_groups
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
@ -333,14 +336,13 @@ class TPUModelRunner:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
best_of = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.sampling_params is not None
sampling_params = seq_group_metadata.sampling_params
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t.append(sampling_params.temperature
@ -354,10 +356,11 @@ class TPUModelRunner:
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.best_of > 1:
if sampling_params.best_of > _MAX_NUM_SAMPLES:
raise NotImplementedError(
"best_of > 1 is not currently supported by the TPU "
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
best_of.append(sampling_params.best_of)
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
@ -369,13 +372,19 @@ class TPUModelRunner:
"prompt_logprobs is not currently supported by the TPU "
"backend.")
num_paddings = padded_batch_size - len(seq_group_metadata_list)
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
best_of += [best_of[-1]] * (num_seqs - 1)
num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device=self.device)
p = torch.tensor(p, dtype=torch.float32, device=self.device)
return t, p
return t, p, best_of
def _execute_model(
self,
@ -392,28 +401,41 @@ class TPUModelRunner:
else:
inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0]
t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
# Execute the model.
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:], t, p)
*inputs[2:], t, p, num_samples)
# Retrieve the outputs to CPU.
next_token_ids = next_token_ids.cpu().tolist()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
i = 0
zero_logprob = Logprob(0.0)
batch_idx = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
next_token_id = next_token_ids[i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: Logprob(0.0)}))
i += 1
if is_prompt:
assert len(seq_ids) == 1
seq_id = seq_ids[0]
for i in range(best_of[batch_idx]):
next_token_id = next_token_ids[batch_idx][i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
else:
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx][0]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return sampler_outputs
@ -458,6 +480,7 @@ class ModelWrapper(nn.Module):
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
@ -520,8 +543,9 @@ class ModelWrapper(nn.Module):
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
next_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
return next_token_ids

View File

@ -1,5 +1,5 @@
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
@ -117,19 +117,26 @@ class TPUWorker(LoraNotSupportedWorkerBase):
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
dtype_btyes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
head_size * num_kv_heads)
# Calculate the TPU KV cache size based on profiling.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
profiled = m["bytes_used"] # Weights + intermediate activations.
kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_btyes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
num_tpu_blocks = (kv_cache_bytes //
(dtype_btyes * block_size * num_layers * 2 *
head_size * num_kv_heads))
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, 0
# Calculate the CPU KV cache size based on the config.
num_cpu_blocks = (self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, num_cpu_blocks
def initialize_cache(
self,
@ -145,15 +152,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
self.cpu_cache = []
self.tpu_cache = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
key_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
value_cache = torch.zeros_like(key_cache)
self.tpu_cache.append((key_cache, value_cache))
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
self._warmup_model()
def _warmup_model(self) -> None:
@ -187,22 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
if not self.is_driver_worker:
self._execute_model_non_driver()
return []
assert execute_model_req is not None
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert len(execute_model_req.blocks_to_swap_in) == 0, (
"Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_swap_out) == 0, (
"Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_copy) == 0
# Issue cache operations.
self.cache_swap(
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy,
)
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
assert len(seq_group_metadata_list) > 0
output = self.model_runner.execute_model(seq_group_metadata_list,
self.tpu_cache)
return [output]
def cache_swap(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)
if blocks_to_swap_in:
# Swap from CPU to TPU.
src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
self.device)
for i in range(num_layers):
attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
src_to_dst)
if blocks_to_swap_out:
# Swap from TPU to CPU.
src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
"cpu")
for i in range(num_layers):
attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
src_to_dst)
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
self.device)
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
def start_worker_execution_loop(self) -> None:
while self._execute_model_non_driver():
pass
@ -210,3 +247,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True
def _make_src_to_dst(
mapping: List[Tuple[int, int]],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
src_indices = [i for i, _ in mapping]
dst_indices = [i for _, i in mapping]
src_indices = torch.tensor(src_indices,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_indices,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices