mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 17:47:04 +08:00
Add precompilation step
This commit is contained in:
parent
d1591f0f1f
commit
b15db234ba
@ -61,10 +61,12 @@ class TPUModelRunner:
|
||||
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
||||
self.params = {"params": params}
|
||||
|
||||
def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
|
||||
jax.Array]]) -> None:
|
||||
def warmup_model(
|
||||
self,
|
||||
tpu_caches: List[Tuple[jax.Array, jax.Array]],
|
||||
) -> List[Tuple[jax.Array, jax.Array]]:
|
||||
# Prefill
|
||||
logger.info("Warming up the model...")
|
||||
logger.info("Compiling the model with different input shapes...")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
|
||||
@ -85,7 +87,7 @@ class TPUModelRunner:
|
||||
block_tables, context_lens,
|
||||
prompt_lens, tpu_caches)
|
||||
end = time.time()
|
||||
logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
|
||||
logger.info(f"Compilation for prefill done in {(end - start):.2f} s.")
|
||||
|
||||
# Decode
|
||||
start = time.time()
|
||||
@ -104,7 +106,8 @@ class TPUModelRunner:
|
||||
block_tables, context_lens,
|
||||
prompt_lens, tpu_caches)
|
||||
end = time.time()
|
||||
logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
|
||||
logger.info(f"Compilation for decode done in {(end - start):.2f} s.")
|
||||
return tpu_caches
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
|
||||
@ -92,7 +92,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
self._warmup_model()
|
||||
|
||||
def _warmup_model(self) -> None:
|
||||
self.model_runner.warmup_model(self.tpu_cache)
|
||||
# NOTE(woosuk): Because of buffer donation, the reference to the cache
|
||||
# should be updated after the warmup.
|
||||
self.tpu_cache = self.model_runner.warmup_model(self.tpu_cache)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
head_size = self.model_config.get_head_size()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user