Add model runner
parent 5cb213c85e
commit e4377dd698
280
vllm/worker/tpu_model_runner.py
Normal file
@@ -0,0 +1,280 @@
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import pad_to_max_length

logger = init_logger(__name__)


class TPUModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        vision_language_config: Optional[VisionLanguageConfig],
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on TPU. "
                           "The model will run without sliding window.")
        self.model = None
        self.block_size = None
        # FIXME
        # self.compiled_fn = jax.jit(self._execute_step)
        self.compiled_fn = self._execute_step

    def load_model(self) -> None:
        from vllm.model_executor.models.jax.gemma import Transformer

        self.model = Transformer(self.model_config.hf_config)
        self.params = load_and_format_params(
            "/home/woosukk/.cache/huggingface/hub/models--google--gemma-7b-flax/snapshots/255139998d76ac69e797fd4b4e8c4b562dc3c75f/7b")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            slot_mapping.append([])
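            # The loop below maps each prompt token to a physical KV-cache
            # slot: slot = block_table[i // block_size] * block_size
            #              + i % block_size.
            # Illustrative example (block_size of 16 is an assumed value):
            # token index i = 35 falls in logical block 2 at offset 3, so with
            # block_table = [7, 1, 4, ...] its slot is 4 * 16 + 3 = 67.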
            for i in range(prompt_len):
                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        assert max_prompt_len > 0

        input_tokens = _make_array_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=jnp.int32)
        input_positions = _make_array_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=jnp.int32)
        slot_mapping = _make_array_with_pad(slot_mapping,
                                            max_prompt_len,
                                            pad=0,  # FIXME
                                            dtype=jnp.int32)
        prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
        return input_tokens, input_positions, slot_mapping, None, None, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        block_tables: List[List[int]] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                block_tables.append(block_table)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
        input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
        slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
        block_tables = _make_array_with_pad(block_tables,
                                            max_len=32,
                                            pad=0,
                                            dtype=jnp.int32)
        context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
        input_lens = jnp.asarray([1] * len(input_tokens), dtype=jnp.int32)
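        # At this point every decode step contributes exactly one token, so
        # input_tokens / input_positions / slot_mapping all have shape
        # [batch_size, 1] and input_lens is all ones. block_tables is padded
        # to 32 block entries per sequence (the hardcoded max_len above).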
        return (input_tokens, input_positions, slot_mapping, block_tables,
                context_lens, input_lens)

    def prepare_input_arrays(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ):
        # NOTE: We assume that all sequences in the group are either all
        # prompts or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            return self._prepare_prompt(seq_group_metadata_list)
        else:
            return self._prepare_decode(seq_group_metadata_list)

    def _execute_step(
        self,
        token_ids: jax.Array,
        position_ids: jax.Array,
        slot_mapping: jax.Array,
        block_tables: Optional[jax.Array],
        context_lens: Optional[jax.Array],
        input_lens: jax.Array,
        kv_caches: List[jax.Array],
    ) -> Tuple[jax.Array, List[jax.Array]]:
        batch_size, seq_len = token_ids.shape
        base_indices = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
        logits_indices = base_indices + input_lens - 1
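        # logits_indices picks the last real (non-padding) token of each
        # sequence out of the flattened [batch_size * seq_len] outputs.
        # Illustrative example with assumed values: batch_size = 2, padded
        # seq_len = 8, input_lens = [5, 8] gives base_indices = [0, 8] and
        # logits_indices = [4, 15].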

        logits, kv_caches = self.model.apply(
            {"params": self.params["transformer"]},
            token_ids,
            position_ids,
            slot_mapping,
            block_tables,
            context_lens,
            kv_caches,
            logits_indices,
        )
        return logits, kv_caches

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[jax.Array],
    ) -> Optional[SamplerOutput]:
        from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob

        inputs = self.prepare_input_arrays(seq_group_metadata_list)
        logits, _ = self.compiled_fn(*inputs, kv_caches)
        next_token_ids = jnp.argmax(logits, axis=-1)
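        # NOTE: The argmax above implements greedy decoding only; the
        # per-sequence SamplingParams (temperature, top-p, etc.) are not
        # applied here yet.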
        next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
        next_token_ids = next_token_ids.tolist()
        i = 0

        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                next_token_id = next_token_ids[i]
                seq_outputs.append(
                    SequenceOutput(seq_id, next_token_id,
                                   {next_token_id: Logprob(0.0)}))
                i += 1

            sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
        return SamplerOutput(sampler_outputs)


def _make_array_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: jnp.dtype,
) -> jax.Array:
    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
    return jnp.asarray(padded_x, dtype)
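
# Illustrative usage of _make_array_with_pad, assuming pad_to_max_length
# right-pads each list with `pad` up to `max_len`:
#   _make_array_with_pad([[1, 2, 3], [4]], max_len=4, pad=0, dtype=jnp.int32)
#   -> [[1, 2, 3, 0],
#       [4, 0, 0, 0]]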

import functools
from typing import Any, Mapping

import orbax.checkpoint

Params = Mapping[str, Any]


def load_and_format_params(path: str) -> Params:
    """Loads parameters and formats them for compatibility."""
    params = load_params(path)
    param_state = jax.tree_util.tree_map(jnp.array, params)
    remapped_params = param_remapper(param_state)
    nested_params = nest_params(remapped_params)
    return nested_params


@functools.cache
def load_params(path: str) -> Params:
    """Loads parameters from a checkpoint path."""
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    params = checkpointer.restore(path)
    return params
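
# load_params is memoized via functools.cache, so repeated calls with the
# same path reuse the already-restored checkpoint pytree instead of reading
# it from disk again.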

def param_remapper(orig_params: Params) -> Params:
    """Remaps params to new module layout.

    This is needed here because the model definition does not have a separate
    `mlp` module.

    Args:
      orig_params: original dict of parameters in Gemma format.

    Returns:
      dict of params with different names.
    """
    new_params = {}
    for k, v in orig_params.items():
        if 'mlp/' in k:
            layer_name, param = k.rsplit('/', maxsplit=1)
            if layer_name not in new_params:
                new_params[layer_name] = {}
            if 'w' in v:
                new_params[layer_name][param] = v['w']
        else:
            new_params[k] = v
    return new_params
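
# Illustrative example (key names are assumed): a flat Gemma entry such as
#   {'transformer/layer_0/mlp/linear': {'w': W}}
# is remapped to
#   {'transformer/layer_0/mlp': {'linear': W}}
# while non-MLP entries are passed through unchanged.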

def nest_params(params: Params) -> Params:
    """Nests params as a dict of dicts rather than a flat dict."""
    nested_params = {}
    for path, param in params.items():
        *path, leaf = path.split('/')
        subdict = nested_params
        for key in path:
            subdict = subdict.setdefault(key, {})
        subdict[leaf] = param
    return nested_params
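
# Illustrative example (assumed key name):
#   nest_params({'layer_0/attn/w': W}) returns {'layer_0': {'attn': {'w': W}}}.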