Pad to avoid recompilation

Woosuk Kwon 2024-04-25 04:43:33 +00:00
parent e2c7dedb3a
commit 81b8b813f1
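For context on why the padding below helps: jax.jit specializes the compiled XLA program on the shapes of its inputs, so every new prompt length or decode batch size would otherwise trigger a fresh compilation. A minimal standalone sketch of that effect (illustrative only, not part of the commit):

import jax
import jax.numpy as jnp

@jax.jit
def step(token_ids):
    # The compiled program is specialized on token_ids.shape, so every
    # new shape seen here triggers another XLA compilation.
    return jnp.cumsum(token_ids, axis=-1)

# Unpadded: three distinct prompt lengths -> three compilations.
for seq_len in (5, 6, 7):
    step(jnp.zeros((1, seq_len), dtype=jnp.int32))

# Padded to a shared bucket (next multiple of 16 here): the first call
# compiles, the other two reuse the cached executable.
for seq_len in (5, 6, 7):
    padded_len = (seq_len + 15) // 16 * 16
    step(jnp.zeros((1, padded_len), dtype=jnp.int32))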

View File

@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional, Tuple

+import numpy as np
 import jax
 import jax.numpy as jnp
@@ -11,12 +12,13 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import pad_to_max_length

 # DELETE
-# from jax_smi import initialise_tracking
-# initialise_tracking()
+from jax_smi import initialise_tracking
+initialise_tracking()

 logger = init_logger(__name__)

 _PAD_SLOT_ID = -1
+_MAX_NUM_SEQS = 256


 class TPUModelRunner:
@@ -41,6 +43,8 @@ class TPUModelRunner:
         self.model = None
         self.block_size = None
         self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
+        # FIXME(woosuk)
+        self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)

     def load_model(self) -> None:
         from huggingface_hub import snapshot_download
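The preallocated table added above is a host-side NumPy buffer with a fixed shape that is filled in place each decode step and then sliced to the padded batch size. A rough sketch of the pattern; treating the hard-coded 512 as a maximum number of blocks per sequence is my assumption (the FIXME suggests it is a placeholder):

import numpy as np

_MAX_NUM_SEQS = 256
_MAX_BLOCKS_PER_SEQ = 512  # assumed meaning of the hard-coded 512

block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_BLOCKS_PER_SEQ), dtype=np.int32)

# Each step writes every running sequence's (variable-length) block table
# into a fixed-shape row, then slices the first `batch_size` rows, so the
# array handed to the device depends only on the padded batch size.
for i, block_table in enumerate([[3, 8, 1], [5], [2, 9]]):
    block_tables[i, :len(block_table)] = block_table

batch_size = 4  # padded batch size for this step
step_block_tables = block_tables[:batch_size]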
@@ -91,10 +95,7 @@
         max_prompt_len = max(prompt_lens)
         assert max_prompt_len > 0
-        # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
-        # length to be a multiple of 16. We pad the prompt length to the nearest
-        # multiple of 16. This is also good for performance.
-        max_prompt_len = (max_prompt_len + 15) // 16 * 16
+        max_prompt_len = _get_padded_prefill_len(max_prompt_len)

         input_tokens = _make_array_with_pad(input_tokens,
                                             max_prompt_len,
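The replaced rounding is worth spelling out: padding to multiples of 16 still leaves roughly max_len / 16 possible padded lengths, while the power-of-two rounding in _get_padded_prefill_len (added further down) leaves only about log2(max_len) buckets, so far fewer distinct shapes ever reach the jitted function. A small comparison, with the arithmetic worked out here rather than taken from the commit:

def round_to_16(x: int) -> int:
    return (x + 15) // 16 * 16

def round_to_pow2(x: int) -> int:
    return 16 if x <= 16 else 1 << (x - 1).bit_length()

# Distinct padded lengths seen for prompts of up to 2048 tokens:
lens = range(1, 2049)
print(len({round_to_16(x) for x in lens}))    # 128 buckets
print(len({round_to_pow2(x) for x in lens}))  # 8 buckets (16, 32, ..., 2048)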
@@ -119,10 +120,11 @@
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
-        block_tables: List[List[int]] = []
         context_lens: List[int] = []

-        for seq_group_metadata in seq_group_metadata_list:
+        num_seq_groups = len(seq_group_metadata_list)
+        batch_size = _get_padded_batch_size(num_seq_groups)
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             assert not seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -139,19 +141,26 @@
                 assert seq_group_metadata.block_tables is not None
                 block_table = seq_group_metadata.block_tables[seq_id]
-                block_tables.append(block_table)
+                self.block_tables[i, :len(block_table)] = block_table

                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])

+        num_paddings = batch_size - num_seq_groups
+        input_tokens = input_tokens + [[0]] * num_paddings
+        input_positions = input_positions + [[0]] * num_paddings
+        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
+        context_lens = context_lens + [0] * num_paddings
+
         input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
         input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
         slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
-        block_tables = _make_array_with_pad(block_tables, max_len=32, pad=0, dtype=jnp.int32)
         context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
-        input_lens = jnp.asarray([1] * len(input_tokens), dtype=jnp.int32)
+        block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
+        input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
         return (input_tokens, input_positions, slot_mapping, block_tables,
                 context_lens, input_lens)
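The decode-path changes above pad every per-sequence list out to the padded batch size, so the arrays passed to the jitted step always share the same leading dimension. A self-contained sketch of that pattern (function name and values are illustrative, not from the commit):

import numpy as np

_PAD_SLOT_ID = -1  # same sentinel the commit uses for padded slots

def pad_decode_inputs(slot_mapping, context_lens, batch_size):
    """Pad ragged per-sequence decode inputs to a fixed batch size."""
    num_paddings = batch_size - len(slot_mapping)
    slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
    context_lens = context_lens + [0] * num_paddings
    return (np.asarray(slot_mapping, dtype=np.int32),
            np.asarray(context_lens, dtype=np.int32))

slots, ctx = pad_decode_inputs([[3], [7], [42]], [9, 17, 33], batch_size=4)
print(slots.shape, ctx.shape)  # (4, 1) (4,)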
@@ -234,6 +243,24 @@ def _make_array_with_pad(
     return jnp.asarray(padded_x, dtype)


+def _get_padded_prefill_len(x: int) -> int:
+    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
+    # length to be a multiple of 16. We pad the prompt length to the nearest
+    # multiple of 16. This is also good for performance.
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
+
+
+def _get_padded_batch_size(batch_size: int) -> int:
+    if batch_size <= 2:
+        return batch_size
+    elif batch_size <= 4:
+        return 4
+    else:
+        return ((batch_size + 7) // 8) * 8
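A quick sanity check of the bucketing the two new helpers produce, with the expected values worked out by hand (assumes the helpers above are in scope):

assert _get_padded_prefill_len(1) == 16     # short prompts snap to 16
assert _get_padded_prefill_len(16) == 16
assert _get_padded_prefill_len(17) == 32    # then to the next power of two
assert _get_padded_prefill_len(100) == 128

assert _get_padded_batch_size(2) == 2       # 1 and 2 kept as-is
assert _get_padded_batch_size(3) == 4
assert _get_padded_batch_size(5) == 8       # beyond 4, round up to a multiple of 8
assert _get_padded_batch_size(17) == 24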