Pad to avoid recompilation

Woosuk Kwon 2024-04-25 04:43:33 +00:00
parent e2c7dedb3a
commit 81b8b813f1
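Why padding helps (a note, not part of the diff): jax.jit specializes the compiled program on the shapes and dtypes of its inputs, so every previously unseen prompt length or batch size triggers a fresh trace and compile. Padding inputs up to a small set of bucketed shapes keeps the jit cache small. A minimal sketch of that behavior, assuming nothing beyond standard jax.jit shape specialization:

import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def step(tokens):
    global trace_count
    trace_count += 1   # the Python body only runs when JAX traces a new shape
    return tokens * 2  # stand-in for the model forward pass

# Every distinct length is a new shape, hence a new trace/compile.
for n in (7, 13, 29):
    step(jnp.zeros((n,), dtype=jnp.int32))
print(trace_count)  # 3

# Padding all inputs to one bucketed length reuses a single compiled program.
def pad_to_bucket(x, bucket=32):
    return jnp.pad(x, (0, bucket - x.shape[0]))

for n in (7, 13, 29):
    step(pad_to_bucket(jnp.zeros((n,), dtype=jnp.int32)))
print(trace_count)  # 4 (only one extra trace, for the padded shape)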


@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import jax
import jax.numpy as jnp
@@ -11,12 +12,13 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import pad_to_max_length
# DELETE
# from jax_smi import initialise_tracking
# initialise_tracking()
from jax_smi import initialise_tracking
initialise_tracking()
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
_MAX_NUM_SEQS = 256
class TPUModelRunner:
@@ -41,6 +43,8 @@ class TPUModelRunner:
self.model = None
self.block_size = None
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
# FIXME(woosuk)
self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
def load_model(self) -> None:
from huggingface_hub import snapshot_download
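A hedged reading of the FIXME'd buffer above: preallocating self.block_tables as a fixed (_MAX_NUM_SEQS, 512) NumPy array gives every decode step a block table with a constant trailing dimension, so the jitted function never sees a new block-table shape. The 512 blocks-per-sequence cap is the assumption the FIXME flags, not a documented limit. A standalone sketch of the pattern:

import numpy as np

_MAX_NUM_SEQS = 256
block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)

# Variable-length per-sequence block tables are copied into fixed-width rows;
# slicing out the first batch_size rows later keeps the array shape bucketed.
for i, table in enumerate([[3, 7], [1, 4, 9, 2]]):
    block_tables[i, :len(table)] = table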
@@ -91,10 +95,7 @@ class TPUModelRunner:
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
max_prompt_len = (max_prompt_len + 15) // 16 * 16
max_prompt_len = _get_padded_prefill_len(max_prompt_len)
input_tokens = _make_array_with_pad(input_tokens,
max_prompt_len,
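The hunk above replaces the inline round-up-to-16 with _get_padded_prefill_len, which (per its definition at the end of this diff) rounds the maximum prompt length up to a power of two, with a floor of 16. A standalone sketch contrasting the two paddings:

def old_pad(x: int) -> int:           # previous behavior: next multiple of 16
    return (x + 15) // 16 * 16

def new_pad(x: int) -> int:           # mirrors _get_padded_prefill_len
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

for n in (5, 17, 100, 1000):
    print(n, old_pad(n), new_pad(n))
# 5 -> 16/16, 17 -> 32/32, 100 -> 112/128, 1000 -> 1008/1024:
# coarser power-of-two buckets mean far fewer distinct prefill shapes to compile.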
@@ -119,10 +120,11 @@ class TPUModelRunner:
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
block_tables: List[List[int]] = []
context_lens: List[int] = []
num_seq_groups = len(seq_group_metadata_list)
batch_size = _get_padded_batch_size(num_seq_groups)
for seq_group_metadata in seq_group_metadata_list:
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -139,19 +141,26 @@ class TPUModelRunner:
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
block_tables.append(block_table)
self.block_tables[i, :len(block_table)] = block_table
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
num_paddings = batch_size - num_seq_groups
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings
input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
block_tables = _make_array_with_pad(block_tables, max_len=32, pad=0, dtype=jnp.int32)
context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
input_lens = jnp.asarray([1] * len(input_tokens), dtype=jnp.int32)
block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
return (input_tokens, input_positions, slot_mapping, block_tables,
context_lens, input_lens)
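A hedged, self-contained sketch of the decode-side pattern in the hunk above: pad the per-sequence lists out to the bucketed batch size with dummy rows and slice the preallocated block-table buffer, so the jitted decode step only ever sees a handful of batch shapes. Names mirror the diff; _PAD_SLOT_ID and the block-table width are taken from it.

import numpy as np
import jax.numpy as jnp

_PAD_SLOT_ID = -1

def pad_decode_batch(input_tokens, slot_mapping, context_lens,
                     block_tables_buf, padded_batch_size):
    num_paddings = padded_batch_size - len(input_tokens)
    # Dummy entries so every array keeps a bucketed leading dimension.
    input_tokens = input_tokens + [[0]] * num_paddings
    slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
    context_lens = context_lens + [0] * num_paddings
    return (jnp.asarray(input_tokens, dtype=jnp.int32),
            jnp.asarray(slot_mapping, dtype=jnp.int32),
            jnp.asarray(context_lens, dtype=jnp.int32),
            jnp.asarray(block_tables_buf[:padded_batch_size], dtype=jnp.int32))

# e.g. 3 running sequences padded to a batch of 4:
buf = np.zeros((256, 512), dtype=np.int32)
out = pad_decode_batch([[11], [22], [33]], [[5], [9], [2]], [4, 8, 3], buf, 4)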
@@ -234,6 +243,24 @@ def _make_array_with_pad(
return jnp.asarray(padded_x, dtype)
def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length up to the next
# power of two (minimum 16), which satisfies that requirement while keeping
# the number of distinct compiled shapes small.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def _get_padded_batch_size(batch_size: int) -> int:
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return ((batch_size + 7) // 8) * 8
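For reference, a short note on the bucketing this helper produces: batch sizes collapse to 1, 2, 4, then multiples of 8, so at steady state only a few decode batch shapes ever get compiled. A sketch checking the buckets:

def _get_padded_batch_size(batch_size: int) -> int:  # copied from the diff
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8

assert [_get_padded_batch_size(n) for n in (1, 2, 3, 5, 12, 17)] == [1, 2, 4, 8, 16, 24]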