mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 06:17:07 +08:00
Add timer
This commit is contained in:
parent
81b8b813f1
commit
98eda57899
@ -1,3 +1,4 @@
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -213,9 +214,22 @@ class TPUModelRunner:
|
||||
) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
|
||||
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
||||
|
||||
start = time.time()
|
||||
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
||||
end = time.time()
|
||||
print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
|
||||
|
||||
start = time.time()
|
||||
next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
|
||||
next_token_ids.block_until_ready()
|
||||
end = time.time()
|
||||
print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
|
||||
|
||||
start = time.time()
|
||||
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
||||
end = time.time()
|
||||
print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
|
||||
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
i = 0
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user