Add warmup + formatting

Woosuk Kwon 2024-04-26 05:28:09 +00:00
parent 4ea41d01a9
commit 5ae2f81c2b
2 changed files with 97 additions and 48 deletions

View File

@@ -14,6 +14,7 @@ from vllm.utils import pad_to_max_length
 # DELETE
 from jax_smi import initialise_tracking
 initialise_tracking()
+logger = init_logger(__name__)
@@ -43,11 +44,11 @@ class TPUModelRunner:
                 "The model will run without sliding window.")
         self.model = None
         self.block_size = None
-        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
+        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
         # FIXME(woosuk)
         self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)

     def load_model(self) -> None:
         from huggingface_hub import snapshot_download
         from vllm.model_executor.models.jax.gemma import Transformer
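A note on `donate_argnums=(7, )`: it tells XLA that the eighth positional argument of `_execute_step`, the KV caches (see the calls in `warmup_model` below), may have its buffers reused for the outputs, so updating the caches does not require a second cache-sized allocation. A minimal sketch of the semantics, using a hypothetical `update` function rather than this runner's step:

import jax
import jax.numpy as jnp

# Donating argument 0 lets XLA write the result into the input's buffer.
update = jax.jit(lambda cache: cache.at[0].set(1.0), donate_argnums=(0, ))

cache = jnp.zeros((1024, ))
cache = update(cache)  # rebind: the old buffer may have been consumed
# Reading a donated array afterwards is an error on TPU/GPU, which is
# why callers always thread the returned caches forward.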
@@ -60,6 +61,51 @@ class TPUModelRunner:
         params = load_and_format_params(model_dir + "/7b/")["transformer"]
         self.params = {"params": params}

+    def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
+                                                  jax.Array]]) -> None:
+        # Prefill
+        logger.info("Warming up the model...")
+        start = time.time()
+        for batch_size in [1]:
+            for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
+                if batch_size * seq_len > 8192:
+                    continue
+                token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+                position_ids = jnp.zeros((batch_size, seq_len),
+                                         dtype=jnp.int32)
+                slot_mapping = jnp.zeros((batch_size, seq_len),
+                                         dtype=jnp.int32)
+                block_tables = None
+                context_lens = None
+                prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+                # Dummy run.
+                _, tpu_caches = self.compiled_fn(self.params, token_ids,
+                                                 position_ids, slot_mapping,
+                                                 block_tables, context_lens,
+                                                 prompt_lens, tpu_caches)
+        end = time.time()
+        logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
+
+        # Decode
+        start = time.time()
+        for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
+            seq_len = 1
+            token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            block_tables = jnp.asarray(self.block_tables[:batch_size],
+                                       dtype=jnp.int32)
+            context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+            prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+            _, tpu_caches = self.compiled_fn(self.params, token_ids,
+                                             position_ids, slot_mapping,
+                                             block_tables, context_lens,
+                                             prompt_lens, tpu_caches)
+        end = time.time()
+        logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
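Why the warmup loop matters: `jax.jit` specializes on argument shapes, so every new (batch_size, seq_len) combination triggers a fresh XLA compilation. Running dummy inputs over the fixed prefill buckets (powers of two from 16 to 8192) and decode batch sizes pays that cost once at startup instead of on the first real request of each shape. A minimal sketch of the shape-specialization behavior, with a hypothetical function `f`:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x * 2

# Each distinct input shape compiles once and is cached afterwards.
for seq_len in [16, 32, 64, 128]:
    f(jnp.zeros((1, seq_len), dtype=jnp.int32))  # compiles on first call

f(jnp.zeros((1, 64), dtype=jnp.int32))  # cache hit: no recompilation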
@@ -107,9 +153,9 @@ class TPUModelRunner:
                                                pad=0,
                                                dtype=jnp.int32)
         slot_mapping = _make_array_with_pad(slot_mapping,
-                                      max_prompt_len,
-                                      pad=_PAD_SLOT_ID,
-                                      dtype=jnp.int32)
+                                            max_prompt_len,
+                                            pad=_PAD_SLOT_ID,
+                                            dtype=jnp.int32)
         prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
         return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
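`_make_array_with_pad` is defined elsewhere in this file and not shown in the diff; judging from the call sites, it right-pads each per-sequence list to a common length so the batch has a single static shape, with `_PAD_SLOT_ID` marking slots the KV-cache write should ignore. A plausible sketch under that assumption (the helper below is hypothetical, not the actual implementation):

from typing import List

import jax
import jax.numpy as jnp

def make_array_with_pad(x: List[List[int]], max_len: int, pad: int,
                        dtype) -> jax.Array:
    # Right-pad every row to max_len; a static shape per bucket keeps
    # jax.jit from recompiling on ragged batches.
    padded = [row + [pad] * (max_len - len(row)) for row in x]
    return jnp.asarray(padded, dtype=dtype)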
@@ -160,7 +206,8 @@ class TPUModelRunner:
         slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
         context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
-        block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
+        block_tables = jnp.asarray(self.block_tables[:batch_size],
+                                   dtype=jnp.int32)
         input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
         return (input_tokens, input_positions, slot_mapping, block_tables,
                 context_lens, input_lens)
@@ -210,25 +257,26 @@ class TPUModelRunner:
     def execute_model(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        kv_caches: List[jax.Array],
-    ) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
+        kv_caches: List[Tuple[jax.Array, jax.Array]],
+    ) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
         from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob

         start = time.time()
         inputs = self.prepare_input_arrays(seq_group_metadata_list)
         end = time.time()
-        print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
+        # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")

         start = time.time()
-        next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
+        next_token_ids, new_kv_caches = self.compiled_fn(
+            self.params, *inputs, kv_caches)
         next_token_ids.block_until_ready()
         end = time.time()
-        print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
+        # print(f"compiled_fn: {(end - start) * 1000:.2f} ms")

         start = time.time()
         next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
         end = time.time()
-        print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
+        # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
         next_token_ids = next_token_ids.tolist()

         i = 0
@@ -240,7 +288,9 @@ class TPUModelRunner:
             seq_ids = list(seq_group_metadata.seq_data.keys())
             for seq_id in seq_ids:
                 next_token_id = next_token_ids[i]
-                seq_outputs.append(SequenceOutput(seq_id, next_token_id, {next_token_id: Logprob(0.0)}))
+                seq_outputs.append(
+                    SequenceOutput(seq_id, next_token_id,
+                                   {next_token_id: Logprob(0.0)}))
                 i += 1
             sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
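The `next_token_ids.block_until_ready()` call above is what makes the surrounding timers meaningful: JAX dispatches device work asynchronously, so without it the timer would measure only dispatch, not execution. The standard pattern:

import time

import jax.numpy as jnp

x = jnp.ones((1024, 1024))

start = time.time()
y = jnp.dot(x, x)      # returns immediately; the matmul runs asynchronously
y.block_until_ready()  # wait for the device to finish
elapsed_ms = (time.time() - start) * 1000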
@@ -284,24 +334,24 @@ Params = Mapping[str, Any]
 def load_and_format_params(path: str) -> Params:
-  """Loads parameters and formats them for compatibility."""
-  params = load_params(path)
-  param_state = jax.tree_util.tree_map(jnp.array, params)
-  remapped_params = param_remapper(param_state)
-  nested_params = nest_params(remapped_params)
-  return nested_params
+    """Loads parameters and formats them for compatibility."""
+    params = load_params(path)
+    param_state = jax.tree_util.tree_map(jnp.array, params)
+    remapped_params = param_remapper(param_state)
+    nested_params = nest_params(remapped_params)
+    return nested_params


 @functools.cache
 def load_params(path: str) -> Params:
-  """Loads parameters from a checkpoint path."""
-  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-  params = checkpointer.restore(path)
-  return params
+    """Loads parameters from a checkpoint path."""
+    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+    params = checkpointer.restore(path)
+    return params


 def param_remapper(orig_params: Params) -> Params:
-  """Remaps params to new module layout.
+    """Remaps params to new module layout.

   This is needed here because the model definition does not have a separate
   `mlp` module.
@@ -312,26 +362,26 @@ def param_remapper(orig_params: Params) -> Params:
   Returns:
     dict of params with different names.
   """
-  new_params = {}
-  for k, v in orig_params.items():
-    if 'mlp/' in k:
-      layer_name, param = k.rsplit('/', maxsplit=1)
-      if layer_name not in new_params:
-        new_params[layer_name] = {}
-      if 'w' in v:
-        new_params[layer_name][param] = v['w']
-    else:
-      new_params[k] = v
-  return new_params
+    new_params = {}
+    for k, v in orig_params.items():
+        if 'mlp/' in k:
+            layer_name, param = k.rsplit('/', maxsplit=1)
+            if layer_name not in new_params:
+                new_params[layer_name] = {}
+            if 'w' in v:
+                new_params[layer_name][param] = v['w']
+        else:
+            new_params[k] = v
+    return new_params


 def nest_params(params: Params) -> Params:
-  """Nests params as a dict of dicts rather than a flat dict."""
-  nested_params = {}
-  for path, param in params.items():
-    *path, leaf = path.split('/')
-    subdict = nested_params
-    for key in path:
-      subdict = subdict.setdefault(key, {})
-    subdict[leaf] = param
-  return nested_params
+    """Nests params as a dict of dicts rather than a flat dict."""
+    nested_params = {}
+    for path, param in params.items():
+        *path, leaf = path.split('/')
+        subdict = nested_params
+        for key in path:
+            subdict = subdict.setdefault(key, {})
+        subdict[leaf] = param
+    return nested_params
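A tiny worked example of what `param_remapper` and `nest_params` do together, using hypothetical '/'-joined checkpoint keys and string stand-ins for the weight arrays:

flat = {
    'transformer/layer_0/mlp/linear': {'w': 'W0'},
    'transformer/layer_0/attn/qkv': 'Q0',
}

remapped = param_remapper(flat)
# {'transformer/layer_0/mlp': {'linear': 'W0'},
#  'transformer/layer_0/attn/qkv': 'Q0'}

nested = nest_params(remapped)
# {'transformer': {'layer_0': {'mlp': {'linear': 'W0'},
#                              'attn': {'qkv': 'Q0'}}}}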

View File

@@ -92,8 +92,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self._warmup_model()

     def _warmup_model(self) -> None:
-        # self.model_runner.warmup_model(self.tpu_cache)
-        pass
+        self.model_runner.warmup_model(self.tpu_cache)

     def get_cache_block_size_bytes(self) -> int:
         head_size = self.model_config.get_head_size()
@@ -129,8 +128,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         if num_seq_groups == 0:
             return {}
-        output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list,
-                                                            self.tpu_cache)
+        output, kv_caches = self.model_runner.execute_model(
+            seq_group_metadata_list, self.tpu_cache)
         self.tpu_cache = kv_caches
         return output
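Because the runner donates the KV caches to the jitted step (`donate_argnums=(7, )`), the `self.tpu_cache = kv_caches` rebinding is load-bearing: the buffers behind the old list may have been consumed, so each step must thread the returned caches into the next. A minimal sketch of the pattern with a stand-in step function:

import jax
import jax.numpy as jnp

def step(token_ids, caches):
    # Stand-in for the real model step: bump the caches, echo the tokens.
    new_caches = [(k + 1, v + 1) for k, v in caches]
    return token_ids, new_caches

compiled = jax.jit(step, donate_argnums=(1, ))

caches = [(jnp.zeros((4, )), jnp.zeros((4, )))]
out, caches = compiled(jnp.zeros((1, ), jnp.int32), caches)  # rebind caches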