Fix for binary cache

This commit is contained in:
Woosuk Kwon 2024-04-26 08:56:12 +00:00
parent 408ff4950c
commit 3f6288cc89

View File

@ -16,6 +16,7 @@ logger = init_logger(__name__)
_PAD_SLOT_ID = -1
_MAX_NUM_SEQS = 256
_MAX_NUM_BLOCKS_PER_SEQ = 8192 // 16
class TPUModelRunner:
@ -41,7 +42,8 @@ class TPUModelRunner:
self.block_size = None
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
# FIXME(woosuk)
self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
self.block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_NUM_BLOCKS_PER_SEQ),
dtype=np.int32)
def load_model(self) -> None:
from huggingface_hub import snapshot_download
@ -55,6 +57,7 @@ class TPUModelRunner:
model_dir = snapshot_download(model_name)
params = load_and_format_params(model_dir + "/7b/")["transformer"]
self.params = {"params": params}
self.cpu_device = jax.devices("cpu")[0]
def warmup_model(
self,
@ -91,8 +94,8 @@ class TPUModelRunner:
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
block_tables = jnp.asarray(self.block_tables[:batch_size],
dtype=jnp.int32)
block_tables = jnp.zeros((batch_size, _MAX_NUM_BLOCKS_PER_SEQ),
dtype=jnp.int32)
context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
@ -263,6 +266,7 @@ class TPUModelRunner:
start = time.time()
inputs = self.prepare_input_arrays(seq_group_metadata_list)
end = time.time()
# print(inputs[0].shape)
# print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
start = time.time()
@ -273,7 +277,7 @@ class TPUModelRunner:
# print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
start = time.time()
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
next_token_ids = jax.device_put(next_token_ids, self.cpu_device)
end = time.time()
# print(f"jax.device_put: {(end - start) * 1000:.2f} ms")