From 5ae2f81c2b370ec4f262a1db69ac88daad5b2de4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 26 Apr 2024 05:28:09 +0000
Subject: [PATCH] Add warmup + formatting

---
 vllm/worker/tpu_model_runner.py | 138 ++++++++++++++++++++++----------
 vllm/worker/tpu_worker.py       |   7 +-
 2 files changed, 97 insertions(+), 48 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 77f4885deef64..89390538a4ab3 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -14,6 +14,7 @@
 from vllm.utils import pad_to_max_length
 
 # DELETE
 from jax_smi import initialise_tracking
+initialise_tracking()
 
 logger = init_logger(__name__)
@@ -43,11 +44,11 @@ class TPUModelRunner:
                 "The model will run without sliding window.")
         self.model = None
         self.block_size = None
-        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
+        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
         # FIXME(woosuk)
         self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
 
-    def load_model(self) -> None: 
+    def load_model(self) -> None:
         from huggingface_hub import snapshot_download
         from vllm.model_executor.models.jax.gemma import Transformer
 
@@ -60,6 +61,51 @@ class TPUModelRunner:
         params = load_and_format_params(model_dir + "/7b/")["transformer"]
         self.params = {"params": params}
 
+    def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
+                                                  jax.Array]]) -> None:
+        # Prefill
+        logger.info("Warming up the model...")
+        start = time.time()
+        for batch_size in [1]:
+            for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
+                if batch_size * seq_len > 8192:
+                    continue
+                token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+                position_ids = jnp.zeros((batch_size, seq_len),
+                                         dtype=jnp.int32)
+                slot_mapping = jnp.zeros((batch_size, seq_len),
+                                         dtype=jnp.int32)
+                block_tables = None
+                context_lens = None
+                prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+
+                # Dummy run.
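+                # Illustrative comment, not in the original patch: jax.jit
+                # specializes compiled_fn on input shapes, so each distinct
+                # (batch_size, seq_len) bucket is compiled once here rather
+                # than on the first live request that hits that shape.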
+                _, tpu_caches = self.compiled_fn(self.params, token_ids,
+                                                 position_ids, slot_mapping,
+                                                 block_tables, context_lens,
+                                                 prompt_lens, tpu_caches)
+        end = time.time()
+        logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
+
+        # Decode
+        start = time.time()
+        for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
+            seq_len = 1
+            token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            block_tables = jnp.asarray(self.block_tables[:batch_size],
+                                       dtype=jnp.int32)
+            context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+            prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+
+            _, tpu_caches = self.compiled_fn(self.params, token_ids,
+                                             position_ids, slot_mapping,
+                                             block_tables, context_lens,
+                                             prompt_lens, tpu_caches)
+        end = time.time()
+        logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -107,9 +153,9 @@
                                             pad=0,
                                             dtype=jnp.int32)
         slot_mapping = _make_array_with_pad(slot_mapping,
-                                           max_prompt_len,
-                                           pad=_PAD_SLOT_ID,
-                                           dtype=jnp.int32)
+                                            max_prompt_len,
+                                            pad=_PAD_SLOT_ID,
+                                            dtype=jnp.int32)
         prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
         return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
 
@@ -160,7 +206,8 @@
 
         slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
         context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
-        block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
+        block_tables = jnp.asarray(self.block_tables[:batch_size],
+                                   dtype=jnp.int32)
         input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
         return (input_tokens, input_positions, slot_mapping, block_tables,
                 context_lens, input_lens)
 
@@ -210,25 +257,26 @@
     def execute_model(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        kv_caches: List[jax.Array],
-    ) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
+        kv_caches: List[Tuple[jax.Array, jax.Array]],
+    ) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
         from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
 
         start = time.time()
         inputs = self.prepare_input_arrays(seq_group_metadata_list)
         end = time.time()
-        print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
+        # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
 
         start = time.time()
-        next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
+        next_token_ids, new_kv_caches = self.compiled_fn(
+            self.params, *inputs, kv_caches)
         next_token_ids.block_until_ready()
         end = time.time()
-        print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
+        # print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
 
         start = time.time()
         next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
         end = time.time()
-        print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
+        # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
         next_token_ids = next_token_ids.tolist()
 
         i = 0
@@ -240,7 +288,9 @@
             seq_ids = list(seq_group_metadata.seq_data.keys())
             for seq_id in seq_ids:
                 next_token_id = next_token_ids[i]
-                seq_outputs.append(SequenceOutput(seq_id, next_token_id, {next_token_id: Logprob(0.0)}))
+                seq_outputs.append(
+                    SequenceOutput(seq_id, next_token_id,
+                                   {next_token_id: Logprob(0.0)}))
                 i += 1
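+            # Illustrative comment, not in the original patch: next_token_ids
+            # is flat across every sequence in the batch, so the running index
+            # i pairs each sampled token with its seq_id; Logprob(0.0) is only
+            # a placeholder because real logprobs are not computed here.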
             sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
 
@@ -284,24 +334,24 @@ Params = Mapping[str, Any]
 
 
 def load_and_format_params(path: str) -> Params:
-  """Loads parameters and formats them for compatibility."""
-  params = load_params(path)
-  param_state = jax.tree_util.tree_map(jnp.array, params)
-  remapped_params = param_remapper(param_state)
-  nested_params = nest_params(remapped_params)
-  return nested_params
+    """Loads parameters and formats them for compatibility."""
+    params = load_params(path)
+    param_state = jax.tree_util.tree_map(jnp.array, params)
+    remapped_params = param_remapper(param_state)
+    nested_params = nest_params(remapped_params)
+    return nested_params
 
 
 @functools.cache
 def load_params(path: str) -> Params:
-  """Loads parameters from a checkpoint path."""
-  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-  params = checkpointer.restore(path)
-  return params
+    """Loads parameters from a checkpoint path."""
+    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+    params = checkpointer.restore(path)
+    return params
 
 
 def param_remapper(orig_params: Params) -> Params:
-  """Remaps params to new module layout.
+    """Remaps params to new module layout.
 
     This is needed here because the model definition does not have a separate
     `mlp` module.
 
     Args:
       orig_params: original dict of parameters in Gemma format.
 
     Returns:
       dict of params with different names.
     """
@@ -312,26 +362,26 @@
-  new_params = {}
-  for k, v in orig_params.items():
-    if 'mlp/' in k:
-      layer_name, param = k.rsplit('/', maxsplit=1)
-      if layer_name not in new_params:
-        new_params[layer_name] = {}
-      if 'w' in v:
-        new_params[layer_name][param] = v['w']
-    else:
-      new_params[k] = v
-  return new_params
+    new_params = {}
+    for k, v in orig_params.items():
+        if 'mlp/' in k:
+            layer_name, param = k.rsplit('/', maxsplit=1)
+            if layer_name not in new_params:
+                new_params[layer_name] = {}
+            if 'w' in v:
+                new_params[layer_name][param] = v['w']
+        else:
+            new_params[k] = v
+    return new_params
 
 
 def nest_params(params: Params) -> Params:
-  """Nests params as a dict of dicts rather than a flat dict."""
-  nested_params = {}
-  for path, param in params.items():
-    *path, leaf = path.split('/')
-    subdict = nested_params
-    for key in path:
-      subdict = subdict.setdefault(key, {})
-    subdict[leaf] = param
-  return nested_params
+    """Nests params as a dict of dicts rather than a flat dict."""
+    nested_params = {}
+    for path, param in params.items():
+        *path, leaf = path.split('/')
+        subdict = nested_params
+        for key in path:
+            subdict = subdict.setdefault(key, {})
+        subdict[leaf] = param
+    return nested_params
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 0fe632aa8923e..7f5d7efe57880 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -92,8 +92,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self._warmup_model()
 
     def _warmup_model(self) -> None:
-        # self.model_runner.warmup_model(self.tpu_cache)
-        pass
+        self.model_runner.warmup_model(self.tpu_cache)
 
     def get_cache_block_size_bytes(self) -> int:
         head_size = self.model_config.get_head_size()
@@ -129,8 +128,8 @@
         if num_seq_groups == 0:
             return {}
 
-        output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list,
-                                                            self.tpu_cache)
+        output, kv_caches = self.model_runner.execute_model(
+            seq_group_metadata_list, self.tpu_cache)
         self.tpu_cache = kv_caches
         return output
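
Illustrative sketch (not part of the patch above): nest_params() splits each
flat checkpoint key on '/' and nests the value under the resulting path, so
the flat mapping restored by Orbax becomes an ordinary nested params dict:

    >>> nest_params({'transformer/mlp/linear': 1})
    {'transformer': {'mlp': {'linear': 1}}}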