Add timers around prepare_input_arrays, compiled_fn, and jax.device_put in TPUModelRunner

This commit is contained in:
Woosuk Kwon 2024-04-25 05:06:11 +00:00
parent 81b8b813f1
commit 98eda57899

View File

@ -1,3 +1,4 @@
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
@ -213,9 +214,22 @@ class TPUModelRunner:
) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
start = time.time()
inputs = self.prepare_input_arrays(seq_group_metadata_list)
end = time.time()
print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
start = time.time()
next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
next_token_ids.block_until_ready()
end = time.time()
print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
start = time.time()
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
end = time.time()
print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
next_token_ids = next_token_ids.tolist()
i = 0