mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 21:43:09 +08:00
[Hardware][TPU] Support parallel sampling & Swapping (#5855)
This commit is contained in:
parent
c54269d967
commit
cbc53b6b8d
@ -28,21 +28,35 @@ class PallasAttentionBackend(AttentionBackend):
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||
|
||||
@torch.compile(backend="openxla")
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
src_to_dst: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> None:
|
||||
raise NotImplementedError("swap_blocks is not implemented.")
|
||||
src_k_cache, src_v_cache = src_kv_cache
|
||||
dst_k_cache, dst_v_cache = dst_kv_cache
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
|
||||
|
||||
device = dst_k_cache.device
|
||||
src_indices, dst_indices = src_to_dst
|
||||
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
|
||||
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
|
||||
|
||||
@torch.compile(backend="openxla")
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> None:
|
||||
# TODO(woosuk): Implement this.
|
||||
raise NotImplementedError("copy_blocks is not implemented.")
|
||||
src_indices, dst_indices = src_to_dists
|
||||
for k_cache, v_cache in kv_caches:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
|
||||
k_cache[:, dst_indices] = k_cache[:, src_indices]
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
|
||||
v_cache[:, dst_indices] = v_cache[:, src_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -22,6 +22,9 @@ logger = init_logger(__name__)
|
||||
_PAD_SLOT_ID = 0 # FIXME(woosuk)
|
||||
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
||||
_ENABLE_TOP_P = False
|
||||
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
||||
# This can significantly affect the performance if too large.
|
||||
_MAX_NUM_SAMPLES = 128
|
||||
|
||||
|
||||
class TPUModelRunner:
|
||||
@ -143,8 +146,9 @@ class TPUModelRunner:
|
||||
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
|
||||
# Dummy run.
|
||||
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||
self.model(token_ids, position_ids, kv_caches, attn_metadata,
|
||||
input_lens, t, p)
|
||||
input_lens, t, p, num_samples)
|
||||
|
||||
def warmup_model(
|
||||
self,
|
||||
@ -268,14 +272,11 @@ class TPUModelRunner:
|
||||
input_positions: List[List[int]] = []
|
||||
slot_mapping: List[List[int]] = []
|
||||
context_lens: List[int] = []
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
batch_size = _get_padded_batch_size(num_seq_groups)
|
||||
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
batch_idx = 0
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
@ -288,14 +289,16 @@ class TPUModelRunner:
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
self.block_tables[i, :len(block_table)] = block_table
|
||||
self.block_tables[batch_idx, :len(block_table)] = block_table
|
||||
batch_idx += 1
|
||||
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append([slot])
|
||||
|
||||
num_paddings = batch_size - num_seq_groups
|
||||
batch_size = _get_padded_batch_size(batch_idx)
|
||||
num_paddings = batch_size - batch_idx
|
||||
input_tokens = input_tokens + [[0]] * num_paddings
|
||||
input_positions = input_positions + [[0]] * num_paddings
|
||||
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
||||
@ -333,14 +336,13 @@ class TPUModelRunner:
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
padded_batch_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
t = []
|
||||
p = []
|
||||
best_of = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.sampling_params is not None
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
|
||||
# NOTE(woosuk): Here we mimic argmax sampling by applying a very
|
||||
# low temperature. This is not accurate.
|
||||
t.append(sampling_params.temperature
|
||||
@ -354,10 +356,11 @@ class TPUModelRunner:
|
||||
raise NotImplementedError(
|
||||
"Top-k sampling is currently disabled for the TPU backend "
|
||||
"due to performance issues.")
|
||||
if sampling_params.best_of > 1:
|
||||
if sampling_params.best_of > _MAX_NUM_SAMPLES:
|
||||
raise NotImplementedError(
|
||||
"best_of > 1 is not currently supported by the TPU "
|
||||
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
|
||||
"backend.")
|
||||
best_of.append(sampling_params.best_of)
|
||||
if sampling_params.use_beam_search:
|
||||
raise NotImplementedError(
|
||||
"Beam search is not supported by the TPU backend.")
|
||||
@ -369,13 +372,19 @@ class TPUModelRunner:
|
||||
"prompt_logprobs is not currently supported by the TPU "
|
||||
"backend.")
|
||||
|
||||
num_paddings = padded_batch_size - len(seq_group_metadata_list)
|
||||
# Repeat the sampling params if the seq group has multiple seqs.
|
||||
num_seqs = len(seq_group_metadata.seq_data)
|
||||
t += [t[-1]] * (num_seqs - 1)
|
||||
p += [p[-1]] * (num_seqs - 1)
|
||||
best_of += [best_of[-1]] * (num_seqs - 1)
|
||||
|
||||
num_paddings = padded_batch_size - len(t)
|
||||
t += [1.0] * num_paddings
|
||||
p += [1.0] * num_paddings
|
||||
|
||||
t = torch.tensor(t, dtype=torch.float32, device=self.device)
|
||||
p = torch.tensor(p, dtype=torch.float32, device=self.device)
|
||||
return t, p
|
||||
return t, p, best_of
|
||||
|
||||
def _execute_model(
|
||||
self,
|
||||
@ -392,28 +401,41 @@ class TPUModelRunner:
|
||||
else:
|
||||
inputs = self._prepare_decode(seq_group_metadata_list)
|
||||
padded_batch_size = inputs[0].shape[0]
|
||||
t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
|
||||
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
|
||||
padded_batch_size)
|
||||
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||
|
||||
# Execute the model.
|
||||
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
|
||||
*inputs[2:], t, p)
|
||||
*inputs[2:], t, p, num_samples)
|
||||
# Retrieve the outputs to CPU.
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
|
||||
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
||||
# The TPU backend does not reuse the sampler, since the TPU backend
|
||||
# does not support the advanced sampling parameters such as logprobs.
|
||||
i = 0
|
||||
zero_logprob = Logprob(0.0)
|
||||
batch_idx = 0
|
||||
sampler_outputs = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_outputs = []
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
for seq_id in seq_ids:
|
||||
next_token_id = next_token_ids[i]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: Logprob(0.0)}))
|
||||
i += 1
|
||||
if is_prompt:
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
for i in range(best_of[batch_idx]):
|
||||
next_token_id = next_token_ids[batch_idx][i]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: zero_logprob}))
|
||||
batch_idx += 1
|
||||
else:
|
||||
for seq_id in seq_ids:
|
||||
next_token_id = next_token_ids[batch_idx][0]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: zero_logprob}))
|
||||
batch_idx += 1
|
||||
sampler_outputs.append(
|
||||
CompletionSequenceGroupOutput(seq_outputs, None))
|
||||
return sampler_outputs
|
||||
@ -458,6 +480,7 @@ class ModelWrapper(nn.Module):
|
||||
input_lens: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
num_samples: int,
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
|
||||
@ -520,8 +543,9 @@ class ModelWrapper(nn.Module):
|
||||
if _ENABLE_TOP_P:
|
||||
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
# FIXME(woosuk): best_of > 1 is not supported.
|
||||
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
|
||||
next_token_ids = torch.multinomial(probs,
|
||||
num_samples,
|
||||
replacement=True)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
@ -117,19 +117,26 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
dtype_btyes = get_dtype_size(self.cache_dtype)
|
||||
block_size = self.cache_config.block_size
|
||||
block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
|
||||
head_size * num_kv_heads)
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
usable_memory_size = int(total_memory_size *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
profiled = m["bytes_used"] # Weights + intermediate activations.
|
||||
kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
dtype_btyes = get_dtype_size(self.cache_dtype)
|
||||
block_size = self.cache_config.block_size
|
||||
num_tpu_blocks = (kv_cache_bytes //
|
||||
(dtype_btyes * block_size * num_layers * 2 *
|
||||
head_size * num_kv_heads))
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
|
||||
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
|
||||
return num_tpu_blocks, 0
|
||||
|
||||
# Calculate the CPU KV cache size based on the config.
|
||||
num_cpu_blocks = (self.cache_config.swap_space_bytes //
|
||||
block_size_bytes)
|
||||
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
|
||||
return num_tpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
@ -145,15 +152,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
|
||||
self.cpu_cache = []
|
||||
self.tpu_cache = []
|
||||
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.zeros(tpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
value_cache = torch.zeros_like(key_cache)
|
||||
self.tpu_cache.append((key_cache, value_cache))
|
||||
tpu_k_cache = torch.zeros(tpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
|
||||
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
|
||||
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
|
||||
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
|
||||
self._warmup_model()
|
||||
|
||||
def _warmup_model(self) -> None:
|
||||
@ -187,22 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
if not self.is_driver_worker:
|
||||
self._execute_model_non_driver()
|
||||
return []
|
||||
|
||||
assert execute_model_req is not None
|
||||
# Currently, TPUWorker does not support swapping.
|
||||
# TODO(woosuk): Support block copying.
|
||||
assert len(execute_model_req.blocks_to_swap_in) == 0, (
|
||||
"Swapping is not supported for the TPU backend.")
|
||||
assert len(execute_model_req.blocks_to_swap_out) == 0, (
|
||||
"Swapping is not supported for the TPU backend.")
|
||||
assert len(execute_model_req.blocks_to_copy) == 0
|
||||
|
||||
# Issue cache operations.
|
||||
self.cache_swap(
|
||||
execute_model_req.blocks_to_swap_in,
|
||||
execute_model_req.blocks_to_swap_out,
|
||||
execute_model_req.blocks_to_copy,
|
||||
)
|
||||
# Run the model.
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
self.tpu_cache)
|
||||
return [output]
|
||||
|
||||
def cache_swap(
|
||||
self,
|
||||
blocks_to_swap_in: List[Tuple[int, int]],
|
||||
blocks_to_swap_out: List[Tuple[int, int]],
|
||||
blocks_to_copy: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
attn_backend = self.model_runner.attn_backend
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
|
||||
if blocks_to_swap_in:
|
||||
# Swap from CPU to TPU.
|
||||
src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
|
||||
self.device)
|
||||
for i in range(num_layers):
|
||||
attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
|
||||
src_to_dst)
|
||||
if blocks_to_swap_out:
|
||||
# Swap from TPU to CPU.
|
||||
src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
|
||||
"cpu")
|
||||
for i in range(num_layers):
|
||||
attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
|
||||
src_to_dst)
|
||||
if blocks_to_copy:
|
||||
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
|
||||
self.device)
|
||||
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
|
||||
|
||||
def start_worker_execution_loop(self) -> None:
|
||||
while self._execute_model_non_driver():
|
||||
pass
|
||||
@ -210,3 +247,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
def _execute_model_non_driver(self) -> bool:
|
||||
self.model_runner.execute_model(None, self.tpu_cache)
|
||||
return True
|
||||
|
||||
|
||||
def _make_src_to_dst(
|
||||
mapping: List[Tuple[int, int]],
|
||||
src_device: Union[torch.device, str],
|
||||
dst_device: Union[torch.device, str],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
src_indices = [i for i, _ in mapping]
|
||||
dst_indices = [i for _, i in mapping]
|
||||
src_indices = torch.tensor(src_indices,
|
||||
device=src_device,
|
||||
dtype=torch.int64)
|
||||
dst_indices = torch.tensor(dst_indices,
|
||||
device=dst_device,
|
||||
dtype=torch.int64)
|
||||
return src_indices, dst_indices
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user