"""TPU model runner for vLLM, implemented with JAX (Gemma-only prototype).

Reconstructed from a git format-patch ("Add model runner", Woosuk Kwon,
2024-04-17) creating ``vllm/worker/tpu_model_runner.py``.

The runner converts vLLM ``SequenceGroupMetadata`` into padded JAX arrays,
runs a Flax Gemma model step, and greedily samples the next token.
NOTE(review): sampling ignores ``SamplingParams`` entirely (pure argmax) and
the JIT compilation of the step function is disabled — both are marked as
known gaps below.
"""
import functools
from typing import Any, Dict, List, Mapping, Optional, Tuple

import jax
import jax.numpy as jnp
import orbax.checkpoint

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceOutput)
from vllm.utils import pad_to_max_length

logger = init_logger(__name__)

# A (possibly nested) pytree of model parameters restored from an Orbax
# checkpoint.
Params = Mapping[str, Any]

# Maximum number of physical KV-cache blocks per sequence when padding decode
# block tables. NOTE(review): was an inline magic constant (32) — verify it
# matches the scheduler's max context length / block size.
_MAX_NUM_BLOCKS_PER_SEQ = 32


class TPUModelRunner:
    """Prepares inputs and executes a JAX Gemma model on TPU."""

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        vision_language_config: Optional[VisionLanguageConfig],
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on TPU. "
                           "The model will run without sliding window.")
        # Populated later: `load_model()` sets `self.model`; the worker is
        # expected to set `self.block_size` before input preparation.
        self.model = None
        self.block_size = None
        # FIXME: jax.jit is disabled; the step runs eagerly.
        # self.compiled_fn = jax.jit(self._execute_step)
        self.compiled_fn = self._execute_step

    def load_model(self) -> None:
        """Build the Flax Gemma transformer and load its checkpoint weights."""
        from vllm.model_executor.models.jax.gemma import Transformer

        self.model = Transformer(self.model_config.hf_config)
        # FIXME(review): user-specific hard-coded checkpoint path — must be
        # derived from the model config / download path before this can run
        # anywhere else.
        self.params = load_and_format_params(
            "/home/woosukk/.cache/huggingface/hub/models--google--gemma-7b-flax/snapshots/255139998d76ac69e797fd4b4e8c4b562dc3c75f/7b")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        """Build padded prompt-phase input arrays.

        Returns ``(input_tokens, input_positions, slot_mapping, None, None,
        prompt_lens)`` — the two ``None`` entries stand in for the
        block-table and context-length arrays that only decode needs.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            # Prompt groups carry exactly one sequence.
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            # Map each token position to its flat KV-cache slot:
            # slot = physical_block * block_size + offset_in_block.
            slot_mapping.append([])
            for i in range(prompt_len):
                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        assert max_prompt_len > 0

        input_tokens = _make_array_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=jnp.int32)
        input_positions = _make_array_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=jnp.int32)
        # FIXME: padding slots with 0 aliases block 0 / offset 0; padded
        # positions may overwrite a real slot — needs a dedicated pad slot.
        slot_mapping = _make_array_with_pad(slot_mapping,
                                            max_prompt_len,
                                            pad=0,  # FIXME
                                            dtype=jnp.int32)
        prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
        return input_tokens, input_positions, slot_mapping, None, None, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        """Build single-token decode-phase input arrays.

        Each running sequence contributes one token (its last generated one)
        at position ``seq_len - 1``. Returns ``(input_tokens,
        input_positions, slot_mapping, block_tables, context_lens,
        input_lens)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        block_tables: List[List[int]] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                block_tables.append(block_table)

                # Slot for the single new token being written this step.
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        input_tokens = jnp.asarray(input_tokens, dtype=jnp.int32)
        input_positions = jnp.asarray(input_positions, dtype=jnp.int32)
        slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
        block_tables = _make_array_with_pad(block_tables,
                                            max_len=_MAX_NUM_BLOCKS_PER_SEQ,
                                            pad=0,
                                            dtype=jnp.int32)
        context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
        # Decode always feeds exactly one token per sequence.
        input_lens = jnp.asarray([1] * len(input_tokens), dtype=jnp.int32)
        return (input_tokens, input_positions, slot_mapping, block_tables,
                context_lens, input_lens)

    def prepare_input_arrays(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ):
        """Dispatch to prompt or decode preparation.

        NOTE: We assume that all sequences in the batch are either all
        prompts or all decodes.
        """
        # Guard: the list is typed Optional but was dereferenced unchecked.
        assert seq_group_metadata_list is not None and seq_group_metadata_list
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            return self._prepare_prompt(seq_group_metadata_list)
        else:
            return self._prepare_decode(seq_group_metadata_list)

    def _execute_step(
        self,
        token_ids: jax.Array,
        position_ids: jax.Array,
        slot_mapping: jax.Array,
        block_tables: Optional[jax.Array],
        context_lens: Optional[jax.Array],
        input_lens: jax.Array,
        kv_caches: List[jax.Array],
    ) -> tuple[jax.Array, list[jax.Array]]:
        """Run one model forward pass; returns (logits, updated kv_caches).

        ``logits_indices`` selects, in the flattened (batch * seq_len)
        token stream, the last real (unpadded) token of each sequence.
        """
        batch_size, seq_len = token_ids.shape
        base_indices = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
        logits_indices = base_indices + input_lens - 1

        logits, kv_caches = self.model.apply(
            {"params": self.params["transformer"]},
            token_ids,
            position_ids,
            slot_mapping,
            block_tables,
            context_lens,
            kv_caches,
            logits_indices,
        )
        return logits, kv_caches

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[jax.Array],
    ) -> Optional[SamplerOutput]:
        """Run one step and greedily sample one next token per sequence.

        NOTE(review): ignores per-request ``SamplingParams`` (always argmax)
        and reports a dummy logprob of 0.0 for the chosen token.
        """
        inputs = self.prepare_input_arrays(seq_group_metadata_list)
        logits, _ = self.compiled_fn(*inputs, kv_caches)
        next_token_ids = jnp.argmax(logits, axis=-1)
        # Pull results back to host memory before building Python outputs.
        next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
        next_token_ids = next_token_ids.tolist()

        # Walk the flat token list in the same order the sequences were
        # packed by prepare_input_arrays.
        token_idx = 0
        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []
            for seq_id in seq_group_metadata.seq_data.keys():
                next_token_id = next_token_ids[token_idx]
                seq_outputs.append(
                    SequenceOutput(seq_id, next_token_id,
                                   {next_token_id: Logprob(0.0)}))
                token_idx += 1
            sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
        return SamplerOutput(sampler_outputs)


def _make_array_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: jnp.dtype,
) -> jax.Array:
    """Right-pad each row of `x` to `max_len` with `pad`; return a 2-D array."""
    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
    return jnp.asarray(padded_x, dtype)


def load_and_format_params(path: str) -> Params:
    """Loads parameters and formats them for compatibility."""
    params = load_params(path)
    param_state = jax.tree_util.tree_map(jnp.array, params)
    remapped_params = param_remapper(param_state)
    nested_params = nest_params(remapped_params)
    return nested_params


@functools.cache
def load_params(path: str) -> Params:
    """Loads parameters from a checkpoint path (cached per path)."""
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    params = checkpointer.restore(path)
    return params


def param_remapper(orig_params: Params) -> Params:
    """Remaps params to new module layout.

    This is needed here because the model definition does not have a separate
    `mlp` module: entries under an ``mlp/`` prefix are regrouped by layer and
    their ``'w'`` sub-array is promoted to the parameter itself.

    Args:
        orig_params: original dict of parameters in Gemma format.

    Returns:
        dict of params with different names.
    """
    new_params = {}
    for k, v in orig_params.items():
        if 'mlp/' in k:
            layer_name, param = k.rsplit('/', maxsplit=1)
            if layer_name not in new_params:
                new_params[layer_name] = {}
            if 'w' in v:
                new_params[layer_name][param] = v['w']
        else:
            new_params[k] = v
    return new_params


def nest_params(params: Params) -> Params:
    """Nests params as a dict of dicts rather than a flat dict."""
    nested_params = {}
    for path, param in params.items():
        # Renamed local (was re-binding `path` via starred unpacking).
        *parents, leaf = path.split('/')
        subdict = nested_params
        for key in parents:
            subdict = subdict.setdefault(key, {})
        subdict[leaf] = param
    return nested_params