From cbc53b6b8d87b29949ce13d504750f63714df532 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 26 Jun 2024 11:07:49 -0700
Subject: [PATCH] [Hardware][TPU] Support parallel sampling & Swapping (#5855)

---
 vllm/attention/backends/pallas.py | 30 +++++++---
 vllm/worker/tpu_model_runner.py   | 76 +++++++++++++++---------
 vllm/worker/tpu_worker.py         | 97 ++++++++++++++++++++++++-------
 3 files changed, 147 insertions(+), 56 deletions(-)

diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 62b4a144fc443..121ca9ec45205 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -28,21 +28,35 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
+    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
     ) -> None:
-        raise NotImplementedError("swap_blocks is not implemented.")
+        src_k_cache, src_v_cache = src_kv_cache
+        dst_k_cache, dst_v_cache = dst_kv_cache
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
+
+        device = dst_k_cache.device
+        src_indices, dst_indices = src_to_dst
+        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
+        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
 
+    @torch.compile(backend="openxla")
     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
-        # TODO(woosuk): Implement this.
-        raise NotImplementedError("copy_blocks is not implemented.")
+        src_indices, dst_indices = src_to_dists
+        for k_cache, v_cache in kv_caches:
+            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
+            k_cache[:, dst_indices] = k_cache[:, src_indices]
+            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
+            v_cache[:, dst_indices] = v_cache[:, src_indices]
 
 
 @dataclass
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 2c70c1f917a0d..c3ccbd025ed1b 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -22,6 +22,9 @@ logger = init_logger(__name__)
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
+# FIXME(woosuk): A temporary hack to support `n > 1`.
+# This can significantly affect the performance if too large.
+_MAX_NUM_SAMPLES = 128
 
 
 class TPUModelRunner:
@@ -143,8 +146,9 @@ class TPUModelRunner:
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
 
         # Dummy run.
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, t, p)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -268,14 +272,11 @@ class TPUModelRunner:
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        num_seq_groups = len(seq_group_metadata_list)
-        batch_size = _get_padded_batch_size(num_seq_groups)
 
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+        batch_idx = 0
+        for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
-
             seq_ids = list(seq_group_metadata.seq_data.keys())
-
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -288,14 +289,16 @@
 
                 assert seq_group_metadata.block_tables is not None
                 block_table = seq_group_metadata.block_tables[seq_id]
-                self.block_tables[i, :len(block_table)] = block_table
+                self.block_tables[batch_idx, :len(block_table)] = block_table
+                batch_idx += 1
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
-        num_paddings = batch_size - num_seq_groups
+        batch_size = _get_padded_batch_size(batch_idx)
+        num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
         input_positions = input_positions + [[0]] * num_paddings
         slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
@@ -333,14 +336,13 @@ class TPUModelRunner:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         padded_batch_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
         assert len(seq_group_metadata_list) > 0
         t = []
         p = []
+        best_of = []
         for seq_group_metadata in seq_group_metadata_list:
-            assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
-
             # NOTE(woosuk): Here we mimic argmax sampling by applying a very
             # low temperature. This is not accurate.
             t.append(sampling_params.temperature
@@ -354,10 +356,11 @@ class TPUModelRunner:
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
-            if sampling_params.best_of > 1:
+            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                 raise NotImplementedError(
-                    "best_of > 1 is not currently supported by the TPU "
+                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                     "backend.")
+            best_of.append(sampling_params.best_of)
             if sampling_params.use_beam_search:
                 raise NotImplementedError(
                     "Beam search is not supported by the TPU backend.")
@@ -369,13 +372,19 @@ class TPUModelRunner:
                     "prompt_logprobs is not currently supported by the TPU "
                     "backend.")
 
-        num_paddings = padded_batch_size - len(seq_group_metadata_list)
+            # Repeat the sampling params if the seq group has multiple seqs.
+            num_seqs = len(seq_group_metadata.seq_data)
+            t += [t[-1]] * (num_seqs - 1)
+            p += [p[-1]] * (num_seqs - 1)
+            best_of += [best_of[-1]] * (num_seqs - 1)
+
+        num_paddings = padded_batch_size - len(t)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
 
         t = torch.tensor(t, dtype=torch.float32, device=self.device)
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
-        return t, p
+        return t, p, best_of
 
     def _execute_model(
         self,
@@ -392,28 +401,41 @@ class TPUModelRunner:
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         padded_batch_size = inputs[0].shape[0]
-        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
+                                             padded_batch_size)
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
 
         # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:], t, p)
+                                    *inputs[2:], t, p, num_samples)
         # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
 
         # NOTE(woosuk): Minimal code to construct the sampler outputs.
         # The TPU backend does not reuse the sampler, since the TPU backend
         # does not support the advanced sampling parameters such as logprobs.
-        i = 0
+        zero_logprob = Logprob(0.0)
+        batch_idx = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
             seq_outputs = []
             seq_ids = list(seq_group_metadata.seq_data.keys())
-            for seq_id in seq_ids:
-                next_token_id = next_token_ids[i]
-                seq_outputs.append(
-                    SequenceOutput(seq_id, next_token_id,
-                                   {next_token_id: Logprob(0.0)}))
-                i += 1
+            if is_prompt:
+                assert len(seq_ids) == 1
+                seq_id = seq_ids[0]
+                for i in range(best_of[batch_idx]):
+                    next_token_id = next_token_ids[batch_idx][i]
+                    seq_outputs.append(
+                        SequenceOutput(seq_id, next_token_id,
+                                       {next_token_id: zero_logprob}))
+                batch_idx += 1
+            else:
+                for seq_id in seq_ids:
+                    next_token_id = next_token_ids[batch_idx][0]
+                    seq_outputs.append(
+                        SequenceOutput(seq_id, next_token_id,
+                                       {next_token_id: zero_logprob}))
+                    batch_idx += 1
             sampler_outputs.append(
                 CompletionSequenceGroupOutput(seq_outputs, None))
         return sampler_outputs
@@ -458,6 +480,7 @@ class ModelWrapper(nn.Module):
         input_lens: torch.Tensor,
         t: torch.Tensor,
         p: torch.Tensor,
+        num_samples: int,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
@@ -520,8 +543,9 @@ class ModelWrapper(nn.Module):
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        # FIXME(woosuk): best_of > 1 is not supported.
-        next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples,
+                                           replacement=True)
         return next_token_ids
 
 
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index cd72c71199090..c85bf6892fb28 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -1,5 +1,5 @@
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch_xla.core.xla_model as xm
@@ -117,19 +117,26 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
+        dtype_btyes = get_dtype_size(self.cache_dtype)
+        block_size = self.cache_config.block_size
+        block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
+                            head_size * num_kv_heads)
+
+        # Calculate the TPU KV cache size based on profiling.
         m = xm.get_memory_info(self.device)
         total_memory_size = m["bytes_limit"]
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         profiled = m["bytes_used"]  # Weights + intermediate activations.
-        kv_cache_bytes = max(usable_memory_size - profiled, 0)
-        dtype_btyes = get_dtype_size(self.cache_dtype)
-        block_size = self.cache_config.block_size
-        num_tpu_blocks = (kv_cache_bytes //
-                          (dtype_btyes * block_size * num_layers * 2 *
-                           head_size * num_kv_heads))
+        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
-        return num_tpu_blocks, 0
+
+        # Calculate the CPU KV cache size based on the config.
+        num_cpu_blocks = (self.cache_config.swap_space_bytes //
+                          block_size_bytes)
+        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
+        return num_tpu_blocks, num_cpu_blocks
 
     def initialize_cache(
         self,
@@ -145,15 +152,19 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
+        self.cpu_cache = []
         self.tpu_cache = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         for _ in range(num_layers):
-            key_cache = torch.zeros(tpu_cache_shape,
-                                    dtype=dtype,
-                                    device=self.device)
-            value_cache = torch.zeros_like(key_cache)
-            self.tpu_cache.append((key_cache, value_cache))
+            tpu_k_cache = torch.zeros(tpu_cache_shape,
+                                      dtype=dtype,
+                                      device=self.device)
+            tpu_v_cache = torch.zeros_like(tpu_k_cache)
+            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
+            cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
+            cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
+            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
         self._warmup_model()
 
     def _warmup_model(self) -> None:
@@ -187,22 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         if not self.is_driver_worker:
             self._execute_model_non_driver()
             return []
-
         assert execute_model_req is not None
-        # Currently, TPUWorker does not support swapping.
-        # TODO(woosuk): Support block copying.
-        assert len(execute_model_req.blocks_to_swap_in) == 0, (
-            "Swapping is not supported for the TPU backend.")
-        assert len(execute_model_req.blocks_to_swap_out) == 0, (
-            "Swapping is not supported for the TPU backend.")
-        assert len(execute_model_req.blocks_to_copy) == 0
-
+        # Issue cache operations.
+        self.cache_swap(
+            execute_model_req.blocks_to_swap_in,
+            execute_model_req.blocks_to_swap_out,
+            execute_model_req.blocks_to_copy,
+        )
+        # Run the model.
         seq_group_metadata_list = execute_model_req.seq_group_metadata_list
         assert len(seq_group_metadata_list) > 0
         output = self.model_runner.execute_model(seq_group_metadata_list,
                                                  self.tpu_cache)
         return [output]
 
+    def cache_swap(
+        self,
+        blocks_to_swap_in: List[Tuple[int, int]],
+        blocks_to_swap_out: List[Tuple[int, int]],
+        blocks_to_copy: List[Tuple[int, int]],
+    ) -> None:
+        attn_backend = self.model_runner.attn_backend
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+
+        if blocks_to_swap_in:
+            # Swap from CPU to TPU.
+            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
+                                          self.device)
+            for i in range(num_layers):
+                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
+                                         src_to_dst)
+        if blocks_to_swap_out:
+            # Swap from TPU to CPU.
+            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
+                                          "cpu")
+            for i in range(num_layers):
+                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
+                                         src_to_dst)
+        if blocks_to_copy:
+            src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
+                                          self.device)
+            attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
+
     def start_worker_execution_loop(self) -> None:
         while self._execute_model_non_driver():
             pass
@@ -210,3 +247,19 @@
     def _execute_model_non_driver(self) -> bool:
         self.model_runner.execute_model(None, self.tpu_cache)
         return True
+
+
+def _make_src_to_dst(
+    mapping: List[Tuple[int, int]],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    src_indices = [i for i, _ in mapping]
+    dst_indices = [i for _, i in mapping]
+    src_indices = torch.tensor(src_indices,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_indices,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
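
Note on the swap/copy kernels: swap_blocks and copy_blocks in this patch both reduce to the same gather/scatter over the block dimension of the (num_kv_heads, num_blocks, block_size, head_size) layout returned by get_kv_cache_shape, driven by the index tensors built in _make_src_to_dst. The sketch below is illustrative only and is not part of the patch: it reproduces that indexing with plain CPU torch tensors, omits the torch_xla buffer-donation calls and the openxla compile backend, and its make_src_to_dst helper is a simplified, hypothetical stand-in for the patch's _make_src_to_dst (both index tensors stay on the CPU).

    from typing import List, Tuple

    import torch


    def make_src_to_dst(
            mapping: List[Tuple[int, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Simplified CPU-only analogue of _make_src_to_dst: split
        # (src_block, dst_block) pairs into two int64 index tensors.
        src = torch.tensor([s for s, _ in mapping], dtype=torch.int64)
        dst = torch.tensor([d for _, d in mapping], dtype=torch.int64)
        return src, dst


    # Layout follows PallasAttentionBackend.get_kv_cache_shape:
    # (num_kv_heads, num_blocks, block_size, head_size).
    num_kv_heads, num_blocks, block_size, head_size = 2, 8, 4, 16
    shape = (num_kv_heads, num_blocks, block_size, head_size)

    src_k_cache = torch.randn(shape)  # stands in for one layer's CPU-side K cache
    dst_k_cache = torch.zeros(shape)  # stands in for the TPU-side K cache

    # Move source blocks 0 and 3 into destination blocks 5 and 1, using the same
    # "dst[:, dst_indices] = src[:, src_indices]" scatter as swap_blocks.
    src_indices, dst_indices = make_src_to_dst([(0, 5), (3, 1)])
    dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices]

    assert torch.equal(dst_k_cache[:, 5], src_k_cache[:, 0])
    assert torch.equal(dst_k_cache[:, 1], src_k_cache[:, 3])

copy_blocks applies the same scatter with the source and destination being the same cache tensor, which is why the patch marks the TPU buffers as donated before updating them in place.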