diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index bbf0db31ee383..8a8b412db6731 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,6 @@
 import time
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -12,10 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
-                           SamplerOutput, SequenceGroupMetadata,
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           Logprob, SamplerOutput, SequenceGroupMetadata,
                            SequenceOutput)
-from vllm.utils import make_tensor_with_pad
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase,
+    _add_attn_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 logger = init_logger(__name__)
 
@@ -27,7 +34,44 @@ _ENABLE_TOP_P = False
 _MAX_NUM_SAMPLES = 128
 
 
-class TPUModelRunner:
+@dataclass(frozen=True)
+class ModelInputForTPU(ModelRunnerInputBase):
+    token_ids: torch.Tensor
+    position_ids: torch.Tensor
+    attn_metadata: AttentionMetadata
+    input_lens: torch.Tensor
+    t: torch.Tensor
+    p: torch.Tensor
+    num_samples: int
+    best_of: List[int]
+    seq_groups: List[List[int]]
+
+    def as_broadcastable_tensor_dict(
+            self) -> Dict[str, Union[int, torch.Tensor]]:
+        tensor_dict = {
+            "token_ids": self.token_ids,
+            "position_ids": self.position_ids,
+            "input_lens": self.input_lens,
+            "t": self.t,
+            "p": self.p,
+            "num_samples": self.num_samples,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls: Type["ModelInputForTPU"],
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "ModelInputForTPU":
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
+
+
+class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
 
     def __init__(
         self,
@@ -79,6 +123,7 @@ class TPUModelRunner:
             multimodal_config=self.multimodal_config,
             lora_config=None,
         )
+        model = model.eval()
         xm.wait_device_ops()
 
         model = ModelWrapper(model)
@@ -147,8 +192,8 @@ class TPUModelRunner:
 
         # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
-        self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, t, p, num_samples)
+        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
+                   num_samples, kv_caches)
 
     def warmup_model(
         self,
@@ -177,7 +222,7 @@ class TPUModelRunner:
         # Decode
         start = time.time()
         seq_len = 1
-        batch_size = 1
+        batch_size = 8  # Must be in sync with _get_padded_batch_size()
         while True:
             self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
             xm.wait_device_ops()
@@ -195,10 +240,10 @@ class TPUModelRunner:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
-        input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
+        input_tokens: List[int] = []
+        input_positions: List[int] = []
         prompt_lens: List[int] = []
-        slot_mapping: List[List[int]] = []
+        slot_mapping: List[int] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -212,50 +257,46 @@ class TPUModelRunner:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
-            input_tokens.append(prompt_tokens)
-            input_positions.append(list(range(prompt_len)))
+            input_tokens.extend(prompt_tokens)
+            input_positions.extend(list(range(prompt_len)))
 
             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
-            slot_mapping.append([])
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
+                slot_mapping.append(slot)
+
+            # Add paddings to EACH prompt to the smallest power of 2 that is
+            # greater than or equal to the prompt length.
+            # We pad the seq_len to reduce the compilation overhead.
+            # We execute each prompt individually (i.e., with batch_size 1)
+            # because the FlashAttention kernel does not support ragged inputs.
+            # TODO(woosuk): Use SplashAttention to support ragged inputs.
+            padded_prompt_len = _get_padded_prefill_len(prompt_len)
+            num_paddings = padded_prompt_len - prompt_len
+            input_tokens += [0] * num_paddings
+            input_positions += [0] * num_paddings
+            slot_mapping += [_PAD_SLOT_ID] * num_paddings
 
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
-        num_prefill_tokens = sum(prompt_lens)
-
-        # Add paddings to make the shape [batch_size, max_prompt_len] where
-        # max_prompt_len is smallest power of 2 that is greater than or equal
-        # to the maximum prompt length.
-        # We need the 2D input shape because the Pallas FlashAttention kernel
-        # does not support packed 1D inputs.
-        # We pad the seq_len to powers of 2 to reduce the compilation overhead.
-        max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
-        input_tokens = make_tensor_with_pad(input_tokens,
-                                            max_prompt_len,
-                                            pad=0,
-                                            dtype=torch.int32,
-                                            device=self.device)
-        input_positions = make_tensor_with_pad(input_positions,
-                                               max_prompt_len,
-                                               pad=0,
-                                               dtype=torch.int32,
-                                               device=self.device)
-        slot_mapping = make_tensor_with_pad(slot_mapping,
-                                            max_prompt_len,
-                                            pad=_PAD_SLOT_ID,
-                                            dtype=torch.int64,
-                                            device=self.device)
+        input_tokens = torch.tensor(input_tokens,
+                                    dtype=torch.int32,
+                                    device="cpu")
+        input_positions = torch.tensor(input_positions,
+                                       dtype=torch.int32,
+                                       device="cpu")
+        slot_mapping = torch.tensor(slot_mapping,
+                                    dtype=torch.int64,
+                                    device="cpu")
         prompt_lens = torch.tensor(prompt_lens,
                                    dtype=torch.int32,
-                                   device=self.device)
+                                   device="cpu")
         attn_metadata = self.attn_backend.make_metadata(
             num_prefills=num_prefills,
-            num_prefill_tokens=num_prefill_tokens,  # NOTE: This is not used.
+            num_prefill_tokens=0,  # NOTE: This is not used.
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             block_tables=None,
@@ -306,22 +347,22 @@ class TPUModelRunner:
 
         input_tokens = torch.tensor(input_tokens,
                                     dtype=torch.int32,
-                                    device=self.device)
+                                    device="cpu")
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.int32,
-                                       device=self.device)
+                                       device="cpu")
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.int64,
-                                    device=self.device)
+                                    device="cpu")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int32,
-                                    device=self.device)
+                                    device="cpu")
         block_tables = torch.tensor(self.block_tables[:batch_size],
                                     dtype=torch.int32,
-                                    device=self.device)
+                                    device="cpu")
         input_lens = torch.tensor([1] * batch_size,
                                   dtype=torch.int32,
-                                  device=self.device)
+                                  device="cpu")
         attn_metadata = self.attn_backend.make_metadata(
             num_prefills=0,
             num_prefill_tokens=0,
@@ -382,16 +423,18 @@ class TPUModelRunner:
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
 
-        t = torch.tensor(t, dtype=torch.float32, device=self.device)
-        p = torch.tensor(p, dtype=torch.float32, device=self.device)
+        t = torch.tensor(t, dtype=torch.float32, device="cpu")
+        p = torch.tensor(p, dtype=torch.float32, device="cpu")
         return t, p, best_of
 
-    def _execute_model(
+    def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> List[CompletionSequenceGroupOutput]:
-        # Prepare inputs.
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None,
+    ) -> ModelInputForTPU:
+        del finished_requests_ids  # Unused.
+        assert virtual_engine == 0
         assert len(seq_group_metadata_list) > 0
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
@@ -400,16 +443,104 @@ class TPUModelRunner:
             inputs = self._prepare_prompt(seq_group_metadata_list)
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
-        padded_batch_size = inputs[0].shape[0]
+        input_tokens, input_positions, attn_metadata, input_lens = inputs
+        padded_batch_size = input_tokens.shape[0]
         t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                              padded_batch_size)
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
 
-        # Execute the model.
-        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:], t, p, num_samples)
-        # Retrieve the outputs to CPU.
-        next_token_ids = next_token_ids.cpu().tolist()
+        seq_groups = [
+            list(metadata.seq_data.keys())
+            for metadata in seq_group_metadata_list
+        ]
+        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
+                                input_lens, t, p, num_samples, best_of,
+                                seq_groups)
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
+        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=self.attn_backend)
+        return model_input
+
+    def execute_model(
+        self,
+        model_input: ModelInputForTPU,
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> List[SamplerOutput]:
+        assert intermediate_tensors is None
+        if num_steps > 1:
+            raise ValueError(
+                "TPUModelRunner does not support multi-step execution.")
+
+        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
+            """Move input args from CPU to device and execute the model."""
+
+            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
+                if clone:
+                    # When x is a slice of a CPU tensor, XLA may copy the whole
+                    # original tensor to TPU instead of only copying x.
+                    # To avoid this, we copy x after cloning.
+                    x = x.clone()
+                return x.to(self.device)
+
+            new_args = []
+            for arg in args:
+                if isinstance(arg, torch.Tensor):
+                    arg = _copy_to_device(arg)
+                elif isinstance(arg, AttentionMetadata):
+                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
+                    if getattr(arg, "block_tables", None) is not None:
+                        arg.block_tables = _copy_to_device(arg.block_tables)
+                    if getattr(arg, "context_lens", None) is not None:
+                        arg.context_lens = _copy_to_device(arg.context_lens)
+                new_args.append(arg)
+            return self.model(*new_args)
+
+        num_prefills = model_input.attn_metadata.num_prefills
+        is_prompt = num_prefills > 0
+        if is_prompt:
+            # NOTE(woosuk): Since the FlashAttention kernel does not support
+            # ragged inputs, we split the prompts into different batches and
+            # process them separately. This is a temporary hack that should be
+            # optimized by using SplashAttention.
+            next_token_ids = []
+            orig_slot_mapping = model_input.attn_metadata.slot_mapping
+            batch_size = model_input.input_lens.shape[0]
+            start_idx = 0
+            for i in range(batch_size):
+                # Get the actual prefill_len.
+                prefill_len = model_input.input_lens[i:i + 1].item()
+                prefill_len = _get_padded_prefill_len(prefill_len)
+                end_idx = start_idx + prefill_len
+
+                model_input.attn_metadata.slot_mapping = orig_slot_mapping[
+                    None, start_idx:end_idx]
+                model_input.attn_metadata.num_prefills = 1
+                output_token_ids = _execute_model(
+                    model_input.token_ids[None, start_idx:end_idx],
+                    model_input.position_ids[None, start_idx:end_idx],
+                    model_input.attn_metadata,
+                    model_input.input_lens[i:i + 1],
+                    model_input.t[i:i + 1],
+                    model_input.p[i:i + 1],
+                    model_input.num_samples,
+                    kv_caches,
+                    clone=True)
+                # Retrieve the outputs to CPU.
+                next_token_ids += output_token_ids.cpu().tolist()
+                start_idx = end_idx
+        else:
+            # Execute the model.
+            output_token_ids = _execute_model(
+                model_input.token_ids, model_input.position_ids,
+                model_input.attn_metadata, model_input.input_lens,
+                model_input.t, model_input.p, model_input.num_samples,
+                kv_caches)
+            # Retrieve the outputs to CPU.
+            next_token_ids = output_token_ids.cpu().tolist()
 
         # NOTE(woosuk): Minimal code to construct the sampler outputs.
         # The TPU backend does not reuse the sampler, since the TPU backend
@@ -417,13 +548,13 @@ class TPUModelRunner:
         zero_logprob = Logprob(0.0)
         batch_idx = 0
         sampler_outputs = []
-        for seq_group_metadata in seq_group_metadata_list:
+        for seq_group in model_input.seq_groups:
+            seq_ids = seq_group
             seq_outputs = []
-            seq_ids = list(seq_group_metadata.seq_data.keys())
             if is_prompt:
                 assert len(seq_ids) == 1
                 seq_id = seq_ids[0]
-                for i in range(best_of[batch_idx]):
+                for i in range(model_input.best_of[batch_idx]):
                     next_token_id = next_token_ids[batch_idx][i]
                     seq_outputs.append(
                         SequenceOutput(seq_id, next_token_id,
@@ -438,35 +569,6 @@ class TPUModelRunner:
                 batch_idx += 1
             sampler_outputs.append(
                 CompletionSequenceGroupOutput(seq_outputs, None))
-        return sampler_outputs
-
-    def execute_model(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-        num_steps: int = 1,
-    ) -> List[SamplerOutput]:
-        if num_steps > 1:
-            raise ValueError(
-                "TPUModelRunner does not support multi-step execution.")
-
-        assert seq_group_metadata_list is not None
-        assert len(seq_group_metadata_list) > 0
-        if seq_group_metadata_list[0].is_prompt:
-            # NOTE(woosuk): To reduce the compilation time, we only compile the
-            # prefill inputs with batch size 1. Because the scheduler is not
-            # aware of this limitation, we need to handle batch size > 1
-            # internally by calling the model multiple times and concatenating
-            # the outputs.
-            # FIXME(woosuk): This is a temporary hack to not change the existing
-            # scheduler. We need to fix this in the future.
-            sampler_outputs = []
-            for seq_group_metadata in seq_group_metadata_list:
-                sampler_outputs += self._execute_model([seq_group_metadata],
-                                                       kv_caches)
-        else:
-            sampler_outputs = self._execute_model(seq_group_metadata_list,
-                                                  kv_caches)
         return [SamplerOutput(sampler_outputs)]
 
 
@@ -474,36 +576,37 @@ class ModelWrapper(nn.Module):
 
     def __init__(self, model: nn.Module):
         super().__init__()
-        self.model = model.eval()
+        self.model = model
 
     def forward(
         self,
         token_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
+        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
         Args:
             token_ids: The input token IDs of shape [batch_size, seq_len].
             position_ids: The input position IDs of shape [batch_size, seq_len].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
             t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
+            num_samples: Number of samples to draw from each logits vector.
+            kv_caches: The key and value caches. They can be None during the
+                memory profiling at initialization.
         """
         batch_size, seq_len = token_ids.shape
         # Calculate the positions to sample from.
-        base_indicies = torch.arange(
+        start_indicies = torch.arange(
             batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
-        logits_indices = base_indicies + input_lens - 1
+        logits_indices = start_indicies + input_lens - 1
 
         # FIXME(woosuk): This is a temporary hack to avoid using the existing
         # sampler and sampling metadata.
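
Reviewer note (not part of the patch): the rewritten _prepare_prompt above flattens all prompts into 1-D tensors and pads each prompt separately up to a power-of-two length, so XLA only ever compiles a small, fixed set of prefill shapes. Below is a minimal standalone sketch of that bucketing idea; next_power_of_two, pad_prompt, and the floor of 16 are my own illustrative choices, and the repo's _get_padded_prefill_len and _PAD_SLOT_ID may use different rounding or sentinel values.

    from typing import List, Tuple

    import torch

    PAD_SLOT_ID = -1  # Placeholder sentinel; the backend defines the real _PAD_SLOT_ID.


    def next_power_of_two(x: int, minimum: int = 16) -> int:
        """Round x up to the next power of two, with a floor of `minimum` (assumed)."""
        if x <= minimum:
            return minimum
        return 1 << (x - 1).bit_length()


    def pad_prompt(prompt_tokens: List[int],
                   slot_mapping: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad one prompt so the compiled graph only ever sees bucketed lengths."""
        prompt_len = len(prompt_tokens)
        num_pads = next_power_of_two(prompt_len) - prompt_len
        token_ids = torch.tensor(prompt_tokens + [0] * num_pads, dtype=torch.int32)
        slots = torch.tensor(slot_mapping + [PAD_SLOT_ID] * num_pads,
                             dtype=torch.int64)
        return token_ids, slots


    if __name__ == "__main__":
        tokens, slots = pad_prompt(list(range(37)), list(range(37)))
        assert tokens.shape[0] == 64 and slots.shape[0] == 64  # 37 tokens -> 64

Padding to a handful of bucket sizes wastes a few token slots per prompt but keeps the number of XLA recompilations bounded, which matches the "reduce the compilation overhead" comment in the hunk above.
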
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 9bf764f0ff23a..03011e03058d8 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
 from vllm.worker.tpu_model_runner import TPUModelRunner
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
+                                     LoraNotSupportedWorkerBase, WorkerInput)
 
 logger = init_logger(__name__)
 
 
-class TPUWorker(LoraNotSupportedWorkerBase):
+class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
 
     def __init__(
         self,
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             self.cache_config.cache_dtype]
 
-        self.model_runner = TPUModelRunner(model_config,
-                                           parallel_config,
-                                           scheduler_config,
-                                           device_config,
-                                           cache_config,
-                                           load_config,
-                                           multimodal_config,
-                                           is_driver_worker=is_driver_worker)
+        self.model_runner: TPUModelRunner = TPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            cache_config,
+            load_config,
+            multimodal_config,
+            is_driver_worker=is_driver_worker)
 
     def init_device(self) -> None:
         os.environ["PJRT_DEVICE"] = "TPU"
@@ -196,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         dtype_size = get_dtype_size(self.cache_dtype)
         return dtype_size * total
 
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> List[SamplerOutput]:
-        if not self.is_driver_worker:
-            self._execute_model_non_driver()
-            return []
-        assert execute_model_req is not None
-        # Issue cache operations.
-        self.cache_swap(
-            execute_model_req.blocks_to_swap_in,
-            execute_model_req.blocks_to_swap_out,
-            execute_model_req.blocks_to_copy,
-        )
-        # Run the model.
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-        assert len(seq_group_metadata_list) > 0
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.tpu_cache)
-        return output
+    @property
+    def do_metadata_broadcast(self) -> bool:
+        # TODO(woosuk): Support TP.
+        return False
 
-    def cache_swap(
+    @property
+    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
+        # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
+        # parallelism.
+        return [self.tpu_cache]
+
+    def prepare_worker_input(
         self,
-        blocks_to_swap_in: List[Tuple[int, int]],
-        blocks_to_swap_out: List[Tuple[int, int]],
-        blocks_to_copy: List[Tuple[int, int]],
-    ) -> None:
+        execute_model_req: ExecuteModelRequest,
+    ) -> WorkerInput:
+        virtual_engine = execute_model_req.virtual_engine
+        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
+        blocks_to_swap_in = _make_src_to_dst(
+            execute_model_req.blocks_to_swap_in, "cpu", self.device)
+        blocks_to_swap_out = _make_src_to_dst(
+            execute_model_req.blocks_to_swap_out, self.device, "cpu")
+        blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
+                                          self.device, self.device)
+        return WorkerInput(
+            num_seq_groups=num_seq_groups,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            virtual_engine=virtual_engine,
+        )
+
+    def execute_worker(self, worker_input: WorkerInput) -> None:
+        virtual_engine = worker_input.virtual_engine
+        assert virtual_engine == 0
         attn_backend = self.model_runner.attn_backend
         num_layers = self.model_config.get_num_layers(self.parallel_config)
 
-        if blocks_to_swap_in:
-            # Swap from CPU to TPU.
-            src_indices, dst_indices = _make_src_to_dst(
-                blocks_to_swap_in, "cpu", self.device)
-            for i in range(num_layers):
-                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
-                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
-                k = cpu_k_cache[:, src_indices].to(self.device)
-                v = cpu_v_cache[:, src_indices].to(self.device)
-                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+        # Issue cache operations.
+        if worker_input.blocks_to_swap_in is not None:
+            src_indices, dst_indices = worker_input.blocks_to_swap_in
+            if src_indices.numel() > 0:
+                # Swap from CPU to TPU.
+                for i in range(num_layers):
+                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                    k = cpu_k_cache[:, src_indices].to(self.device)
+                    v = cpu_v_cache[:, src_indices].to(self.device)
+                    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
 
-        if blocks_to_swap_out:
-            # Swap from TPU to CPU.
-            src_indices, dst_indices = _make_src_to_dst(
-                blocks_to_swap_out, self.device, "cpu")
-            for i in range(num_layers):
-                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
-                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
-                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
-                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+        if worker_input.blocks_to_swap_out is not None:
+            src_indices, dst_indices = worker_input.blocks_to_swap_out
+            if src_indices.numel() > 0:
+                # Swap from TPU to CPU.
+                for i in range(num_layers):
+                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                    cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
+                    cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
 
-        if blocks_to_copy:
-            src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
-                                          self.device)
-            attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
-
-    def start_worker_execution_loop(self) -> None:
-        while self._execute_model_non_driver():
-            pass
-
-    def _execute_model_non_driver(self) -> bool:
-        self.model_runner.execute_model(None, self.tpu_cache)
-        return True
+        if worker_input.blocks_to_copy is not None:
+            src_indices, dst_indices = worker_input.blocks_to_copy
+            if src_indices.numel() > 0:
+                attn_backend.copy_blocks(self.tpu_cache,
+                                         (src_indices, dst_indices))
 
 
 def _make_src_to_dst(
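
Reviewer note (not part of the patch): execute_worker now consumes pre-built (src_indices, dst_indices) tensor pairs produced by prepare_worker_input and guards every cache operation with src_indices.numel() > 0. The toy below only illustrates the gather/scatter indexing pattern of the swap-in branch; the cache shapes are invented for the example, and plain advanced indexing stands in for the _insert_kv helper, whose body is not shown in this diff.

    import torch

    # Toy per-layer key cache with the block dimension at dim 1, as suggested by
    # the cpu_k_cache[:, src_indices] indexing above (shapes are assumptions).
    num_heads, num_blocks, block_size, head_dim = 2, 8, 4, 16
    cpu_k_cache = torch.randn(num_heads, num_blocks, block_size, head_dim)
    tpu_k_cache = torch.zeros_like(cpu_k_cache)  # stands in for the device-side cache

    src_indices = torch.tensor([3, 7])  # source blocks to swap in
    dst_indices = torch.tensor([0, 1])  # destination blocks on the device

    if src_indices.numel() > 0:
        k = cpu_k_cache[:, src_indices]  # gather the selected source blocks
        tpu_k_cache[:, dst_indices] = k  # scatter them into the device cache

    assert torch.equal(tpu_k_cache[:, 0], cpu_k_cache[:, 3])

Moving the index construction into prepare_worker_input keeps execute_worker free of Python-side list handling, which fits the LocalOrDistributedWorkerBase split where WorkerInput is the piece that would be broadcast to other workers once tensor parallelism is supported (the do_metadata_broadcast property above still returns False).
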