diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index fc2aab0261576..f98d8e758ccce 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional, Tuple
 
+import numpy as np
 import jax
 import jax.numpy as jnp
@@ -11,12 +12,13 @@
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import pad_to_max_length
 
 # DELETE
-# from jax_smi import initialise_tracking
-# initialise_tracking()
+from jax_smi import initialise_tracking
+initialise_tracking()
 
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = -1
+_MAX_NUM_SEQS = 256
 
 
 class TPUModelRunner:
@@ -41,6 +43,8 @@ class TPUModelRunner:
         self.model = None
         self.block_size = None
         self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
+        # FIXME(woosuk): Avoid hard-coding the block table shape.
+        self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
 
     def load_model(self) -> None:
         from huggingface_hub import snapshot_download
@@ -91,10 +95,7 @@
         max_prompt_len = max(prompt_lens)
         assert max_prompt_len > 0
-        # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
-        # length to be a multiple of 16. We pad the prompt length to the nearest
-        # multiple of 16. This is also good for performance.
-        max_prompt_len = (max_prompt_len + 15) // 16 * 16
+        max_prompt_len = _get_padded_prefill_len(max_prompt_len)
 
         input_tokens = _make_array_with_pad(input_tokens,
                                             max_prompt_len,
@@ -119,10 +120,11 @@
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
-        block_tables: List[List[int]] = []
         context_lens: List[int] = []
+        num_seq_groups = len(seq_group_metadata_list)
+        batch_size = _get_padded_batch_size(num_seq_groups)
 
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             assert not seq_group_metadata.is_prompt
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -139,19 +141,26 @@
                 assert seq_group_metadata.block_tables is not None
                 block_table = seq_group_metadata.block_tables[seq_id]
-                block_tables.append(block_table)
+                self.block_tables[i, :len(block_table)] = block_table
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
+        num_paddings = batch_size - num_seq_groups
+        input_tokens = input_tokens + [[0]] * num_paddings
+        input_positions = input_positions + [[0]] * num_paddings
+        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
+        context_lens = context_lens + [0] * num_paddings
+
         input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
         input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
         slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
-        block_tables = _make_array_with_pad(block_tables, max_len=32, pad=0, dtype=jnp.int32)
         context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
-        input_lens = jnp.asarray([1] * len(input_tokens), dtype=jnp.int32)
+
+        block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
+        input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
 
         return (input_tokens, input_positions, slot_mapping, block_tables,
                 context_lens, input_lens)
@@ -234,6 +243,24 @@
     return jnp.asarray(padded_x, dtype)
 
 
+def _get_padded_prefill_len(x: int) -> int:
+    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
+    # length to be a multiple of 16. We pad the prompt length to the next
+    # power of two (with a floor of 16), which satisfies that constraint.
+    # This is also good for performance.
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
+
+
+def _get_padded_batch_size(batch_size: int) -> int:
+    if batch_size <= 2:
+        return batch_size
+    elif batch_size <= 4:
+        return 4
+    else:
+        return ((batch_size + 7) // 8) * 8
+
+
 import functools
 from typing import Any, Mapping
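Note on the padding helpers: the last hunk replaces the inline multiple-of-16 rounding with two bucketing helpers. The standalone sketch below restates them with hand-computed expected values (the asserts are mine, not from vLLM's tests). _get_padded_prefill_len rounds a prompt length up to the next power of two with a floor of 16, and _get_padded_batch_size rounds a decode batch up to 1, 2, 4, or the next multiple of 8, which keeps the set of input shapes seen by jax.jit small.

def _get_padded_prefill_len(x: int) -> int:
    # Next power of two, floored at 16; any power of two >= 16 is a
    # multiple of 16, as the Pallas FlashAttention kernel requires.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
    # 1 and 2 pass through, 3-4 round up to 4, and larger batches
    # round up to the next multiple of 8.
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        return ((batch_size + 7) // 8) * 8


assert [_get_padded_prefill_len(n) for n in (1, 16, 17, 100)] == [16, 16, 32, 128]
assert [_get_padded_batch_size(n) for n in (1, 3, 5, 9)] == [1, 4, 8, 16]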
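Note on the persistent block-table buffer: the preallocated self.block_tables array replaces the per-step _make_array_with_pad call in _prepare_decode. Below is a minimal standalone sketch of that pattern, with toy block-table values made up for illustration: each sequence's block table is written into a fixed-shape numpy buffer, and only the first batch_size (padded) rows are converted to a device array, so the jitted step always receives an array with one of a few fixed shapes instead of recompiling per batch.

import numpy as np
import jax.numpy as jnp

_MAX_NUM_SEQS = 256
block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)

# Toy per-sequence block tables; the real ones come from
# seq_group_metadata.block_tables in _prepare_decode.
seq_block_tables = [[3, 7], [1], [9, 2, 5]]
for i, bt in enumerate(seq_block_tables):
    block_tables[i, :len(bt)] = bt

batch_size = 4  # _get_padded_batch_size(3) == 4
device_block_tables = jnp.asarray(block_tables[:batch_size], dtype=jnp.int32)
assert device_block_tables.shape == (4, 512)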