Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-03-16 18:17:17 +08:00
Pad to avoid recompilation
This commit is contained in:
parent e2c7dedb3a
commit 81b8b813f1
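The commit carries no description beyond its title, so here is a brief, non-authoritative sketch of why the padding matters. jax.jit specializes the compiled XLA program on the concrete shapes of its inputs, so a step function whose batch size, prompt length, or block-table width changes from call to call gets recompiled again and again. Rounding those dimensions up to a small, fixed set of buckets, as the diff below does, bounds the number of distinct shapes and therefore the number of compilations. A toy illustration (toy_step and pad_to are made-up names for this sketch, not part of vLLM):

import jax
import jax.numpy as jnp

@jax.jit
def toy_step(token_ids):
    # Stand-in for the real model step; jit caches one program per input shape.
    return token_ids * 2

def pad_to(x, target_len, pad_value=0):
    # Right-pad a 1-D array to a fixed bucket length.
    return jnp.pad(x, (0, target_len - x.shape[0]), constant_values=pad_value)

# Unpadded: lengths 3, 5, and 7 are three distinct shapes -> three compilations.
for n in (3, 5, 7):
    toy_step(jnp.zeros((n,), dtype=jnp.int32))

# Padded to one bucket of 8: a single shape -> one compilation, reused afterwards.
for n in (3, 5, 7):
    toy_step(pad_to(jnp.zeros((n,), dtype=jnp.int32), 8))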
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional, Tuple
 
+import numpy as np
 import jax
 import jax.numpy as jnp
 
@@ -11,12 +12,13 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import pad_to_max_length
 
-# DELETE
-# from jax_smi import initialise_tracking
-# initialise_tracking()
+from jax_smi import initialise_tracking
+initialise_tracking()
 
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = -1
+_MAX_NUM_SEQS = 256
 
 
 class TPUModelRunner:
@@ -41,6 +43,8 @@ class TPUModelRunner:
         self.model = None
         self.block_size = None
         self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
+        # FIXME(woosuk)
+        self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
 
     def load_model(self) -> None:
         from huggingface_hub import snapshot_download
@@ -91,10 +95,7 @@
 
         max_prompt_len = max(prompt_lens)
         assert max_prompt_len > 0
-        # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
-        # length to be a multiple of 16. We pad the prompt length to the nearest
-        # multiple of 16. This is also good for performance.
-        max_prompt_len = (max_prompt_len + 15) // 16 * 16
+        max_prompt_len = _get_padded_prefill_len(max_prompt_len)
 
         input_tokens = _make_array_with_pad(input_tokens,
                                             max_prompt_len,
@@ -119,10 +120,11 @@
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
-        block_tables: List[List[int]] = []
         context_lens: List[int] = []
+        num_seq_groups = len(seq_group_metadata_list)
+        batch_size = _get_padded_batch_size(num_seq_groups)
 
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             assert not seq_group_metadata.is_prompt
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -139,19 +141,26 @@
 
                 assert seq_group_metadata.block_tables is not None
                 block_table = seq_group_metadata.block_tables[seq_id]
-                block_tables.append(block_table)
+                self.block_tables[i, :len(block_table)] = block_table
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
+        num_paddings = batch_size - num_seq_groups
+        input_tokens = input_tokens + [[0]] * num_paddings
+        input_positions = input_positions + [[0]] * num_paddings
+        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
+        context_lens = context_lens + [0] * num_paddings
+
         input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
         input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
         slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
-        block_tables = _make_array_with_pad(block_tables, max_len=32, pad=0, dtype=jnp.int32)
         context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
-        input_lens = jnp.asarray([1] * len(input_tokens), dtype=jnp.int32)
 
+        block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
+        input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
         return (input_tokens, input_positions, slot_mapping, block_tables,
                 context_lens, input_lens)
@@ -234,6 +243,24 @@ def _make_array_with_pad(
     return jnp.asarray(padded_x, dtype)
 
 
+def _get_padded_prefill_len(x: int) -> int:
+    # NOTE(woosuk): The Pallas FlashAttention kernel requires the sequence
+    # length to be a multiple of 16. We pad the prompt length up to the next
+    # power of two (at least 16), which satisfies this and limits recompilation.
+    if x <= 16:
+        return 16
+    return 1 << (x - 1).bit_length()
+
+
+def _get_padded_batch_size(batch_size: int) -> int:
+    if batch_size <= 2:
+        return batch_size
+    elif batch_size <= 4:
+        return 4
+    else:
+        return ((batch_size + 7) // 8) * 8
+
+
 import functools
 from typing import Any, Mapping
 
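For reference, a standalone sketch of the bucketing implemented by the two helpers added above; the function bodies are copied from the diff so the snippet runs on its own, and the sample lengths are arbitrary.

def _get_padded_prefill_len(x: int) -> int:
    # Round a prompt length up to the next power of two, with a floor of 16.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

def _get_padded_batch_size(batch_size: int) -> int:
    # Round a decode batch size up to 1, 2, 4, or the next multiple of 8.
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    return ((batch_size + 7) // 8) * 8

for n in (1, 7, 16, 17, 100, 1000):
    print("prompt", n, "->", _get_padded_prefill_len(n))    # 16, 16, 16, 32, 128, 1024
for b in (1, 3, 5, 9, 17):
    print("batch", b, "->", _get_padded_batch_size(b))      # 1, 4, 8, 16, 24

The buckets grow coarsely on purpose: a handful of possible shapes means a handful of compiled programs, at the cost of some wasted work on the padding entries (slots filled with _PAD_SLOT_ID and context lengths of 0 in the decode path above).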