Add precompilation step

This commit is contained in:
Woosuk Kwon 2024-04-26 05:43:08 +00:00
parent d1591f0f1f
commit b15db234ba
2 changed files with 11 additions and 6 deletions

View File

@@ -61,10 +61,12 @@ class TPUModelRunner:
params = load_and_format_params(model_dir + "/7b/")["transformer"]
self.params = {"params": params}
def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
jax.Array]]) -> None:
def warmup_model(
self,
tpu_caches: List[Tuple[jax.Array, jax.Array]],
) -> List[Tuple[jax.Array, jax.Array]]:
# Prefill
logger.info("Warming up the model...")
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
@@ -85,7 +87,7 @@ class TPUModelRunner:
block_tables, context_lens,
prompt_lens, tpu_caches)
end = time.time()
logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
logger.info(f"Compilation for prefill done in {(end - start):.2f} s.")
# Decode
start = time.time()
@@ -104,7 +106,8 @@ class TPUModelRunner:
block_tables, context_lens,
prompt_lens, tpu_caches)
end = time.time()
logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
logger.info(f"Compilation for decode done in {(end - start):.2f} s.")
return tpu_caches
def _prepare_prompt(
self,

View File

@@ -92,7 +92,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self._warmup_model()
def _warmup_model(self) -> None:
self.model_runner.warmup_model(self.tpu_cache)
# NOTE(woosuk): Because of buffer donation, the reference to the cache
# should be updated after the warmup.
self.tpu_cache = self.model_runner.warmup_model(self.tpu_cache)
def get_cache_block_size_bytes(self) -> int:
head_size = self.model_config.get_head_size()