diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index f98d8e758ccce..77f4885deef64 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,3 +1,4 @@ +import time from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -213,9 +214,22 @@ class TPUModelRunner: ) -> Tuple[Optional[SamplerOutput], List[jax.Array]]: from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob + start = time.time() inputs = self.prepare_input_arrays(seq_group_metadata_list) + end = time.time() + print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms") + + start = time.time() next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches) + next_token_ids.block_until_ready() + end = time.time() + print(f"compiled_fn: {(end - start) * 1000:.2f} ms") + + start = time.time() next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) + end = time.time() + print(f"jax.device_put: {(end - start) * 1000:.2f} ms") + next_token_ids = next_token_ids.tolist() i = 0