mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:27:19 +08:00
Fix for binary cache
This commit is contained in:
parent
408ff4950c
commit
3f6288cc89
@ -16,6 +16,7 @@ logger = init_logger(__name__)
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
_MAX_NUM_SEQS = 256
|
||||
_MAX_NUM_BLOCKS_PER_SEQ = 8192 // 16
|
||||
|
||||
|
||||
class TPUModelRunner:
|
||||
@ -41,7 +42,8 @@ class TPUModelRunner:
|
||||
self.block_size = None
|
||||
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
|
||||
# FIXME(woosuk)
|
||||
self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
|
||||
self.block_tables = np.zeros((_MAX_NUM_SEQS, _MAX_NUM_BLOCKS_PER_SEQ),
|
||||
dtype=np.int32)
|
||||
|
||||
def load_model(self) -> None:
|
||||
from huggingface_hub import snapshot_download
|
||||
@ -55,6 +57,7 @@ class TPUModelRunner:
|
||||
model_dir = snapshot_download(model_name)
|
||||
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
||||
self.params = {"params": params}
|
||||
self.cpu_device = jax.devices("cpu")[0]
|
||||
|
||||
def warmup_model(
|
||||
self,
|
||||
@ -91,8 +94,8 @@ class TPUModelRunner:
|
||||
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||
slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||
block_tables = jnp.asarray(self.block_tables[:batch_size],
|
||||
dtype=jnp.int32)
|
||||
block_tables = jnp.zeros((batch_size, _MAX_NUM_BLOCKS_PER_SEQ),
|
||||
dtype=jnp.int32)
|
||||
context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||
|
||||
@ -263,6 +266,7 @@ class TPUModelRunner:
|
||||
start = time.time()
|
||||
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
||||
end = time.time()
|
||||
# print(inputs[0].shape)
|
||||
# print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
|
||||
|
||||
start = time.time()
|
||||
@ -273,7 +277,7 @@ class TPUModelRunner:
|
||||
# print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
|
||||
|
||||
start = time.time()
|
||||
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
||||
next_token_ids = jax.device_put(next_token_ids, self.cpu_device)
|
||||
end = time.time()
|
||||
# print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user