Add warmup + formatting

This commit is contained in:
Woosuk Kwon 2024-04-26 05:28:09 +00:00
parent 4ea41d01a9
commit 5ae2f81c2b
2 changed files with 97 additions and 48 deletions

View File

@ -14,6 +14,7 @@ from vllm.utils import pad_to_max_length
# DELETE # DELETE
from jax_smi import initialise_tracking from jax_smi import initialise_tracking
initialise_tracking() initialise_tracking()
logger = init_logger(__name__) logger = init_logger(__name__)
@ -43,11 +44,11 @@ class TPUModelRunner:
"The model will run without sliding window.") "The model will run without sliding window.")
self.model = None self.model = None
self.block_size = None self.block_size = None
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,)) self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
# FIXME(woosuk) # FIXME(woosuk)
self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32) self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
def load_model(self) -> None: def load_model(self) -> None:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from vllm.model_executor.models.jax.gemma import Transformer from vllm.model_executor.models.jax.gemma import Transformer
@ -60,6 +61,51 @@ class TPUModelRunner:
params = load_and_format_params(model_dir + "/7b/")["transformer"] params = load_and_format_params(model_dir + "/7b/")["transformer"]
self.params = {"params": params} self.params = {"params": params}
def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
                                              jax.Array]]) -> None:
    """Run dummy executions to trigger XLA compilation of ``self.compiled_fn``.

    One dummy step is executed for every (batch_size, seq_len) shape the
    runner expects at serving time — first prefill shapes, then decode
    shapes — so compilation latency is paid up front rather than on the
    first real request.  ``tpu_caches`` is re-bound after every call
    because the jitted step returns the KV caches (and donates the cache
    argument via ``donate_argnums``, so the old arrays must not be reused).

    Args:
        tpu_caches: Per-layer (key, value) KV-cache arrays to thread
            through the dummy runs.
    """
    # Prefill
    logger.info("Warming up the model...")
    start = time.time()
    for batch_size in [1]:
        for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
            # Skip shapes over the token budget (8192 appears to be the
            # max number of batched tokens — confirm against scheduler).
            if batch_size * seq_len > 8192:
                continue
            # All-zero dummy inputs; only the shapes matter for compilation.
            token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
            position_ids = jnp.zeros((batch_size, seq_len),
                                     dtype=jnp.int32)
            slot_mapping = jnp.zeros((batch_size, seq_len),
                                     dtype=jnp.int32)
            # Prefill passes no block tables / context lens (mirrors
            # _prepare_prompt, which returns None for both).
            block_tables = None
            context_lens = None
            prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
            # Dummy run.
            _, tpu_caches = self.compiled_fn(self.params, token_ids,
                                             position_ids, slot_mapping,
                                             block_tables, context_lens,
                                             prompt_lens, tpu_caches)
    end = time.time()
    logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")

    # Decode
    start = time.time()
    for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
        # Decode processes exactly one new token per sequence.
        seq_len = 1
        token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
        position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
        slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
        # Decode reads the KV cache, so real-shaped block tables and
        # context lengths are required.
        block_tables = jnp.asarray(self.block_tables[:batch_size],
                                   dtype=jnp.int32)
        context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
        prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
        _, tpu_caches = self.compiled_fn(self.params, token_ids,
                                         position_ids, slot_mapping,
                                         block_tables, context_lens,
                                         prompt_lens, tpu_caches)
    end = time.time()
    logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
@ -107,9 +153,9 @@ class TPUModelRunner:
pad=0, pad=0,
dtype=jnp.int32) dtype=jnp.int32)
slot_mapping = _make_array_with_pad(slot_mapping, slot_mapping = _make_array_with_pad(slot_mapping,
max_prompt_len, max_prompt_len,
pad=_PAD_SLOT_ID, pad=_PAD_SLOT_ID,
dtype=jnp.int32) dtype=jnp.int32)
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32) prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
return input_tokens, input_positions, slot_mapping, None, None, prompt_lens return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
@ -160,7 +206,8 @@ class TPUModelRunner:
slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32) slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
context_lens = jnp.asarray(context_lens, dtype=jnp.int32) context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32) block_tables = jnp.asarray(self.block_tables[:batch_size],
dtype=jnp.int32)
input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32) input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
return (input_tokens, input_positions, slot_mapping, block_tables, return (input_tokens, input_positions, slot_mapping, block_tables,
context_lens, input_lens) context_lens, input_lens)
@ -210,25 +257,26 @@ class TPUModelRunner:
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[jax.Array], kv_caches: List[Tuple[jax.Array, jax.Array]],
) -> Tuple[Optional[SamplerOutput], List[jax.Array]]: ) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
start = time.time() start = time.time()
inputs = self.prepare_input_arrays(seq_group_metadata_list) inputs = self.prepare_input_arrays(seq_group_metadata_list)
end = time.time() end = time.time()
print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms") # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
start = time.time() start = time.time()
next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches) next_token_ids, new_kv_caches = self.compiled_fn(
self.params, *inputs, kv_caches)
next_token_ids.block_until_ready() next_token_ids.block_until_ready()
end = time.time() end = time.time()
print(f"compiled_fn: {(end - start) * 1000:.2f} ms") # print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
start = time.time() start = time.time()
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
end = time.time() end = time.time()
print(f"jax.device_put: {(end - start) * 1000:.2f} ms") # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
i = 0 i = 0
@ -240,7 +288,9 @@ class TPUModelRunner:
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids: for seq_id in seq_ids:
next_token_id = next_token_ids[i] next_token_id = next_token_ids[i]
seq_outputs.append(SequenceOutput(seq_id, next_token_id, {next_token_id: Logprob(0.0)})) seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: Logprob(0.0)}))
i += 1 i += 1
sampler_outputs.append(SequenceGroupOutput(seq_outputs, None)) sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
@ -284,24 +334,24 @@ Params = Mapping[str, Any]
def load_and_format_params(path: str) -> Params:
    """Load checkpoint parameters and format them for compatibility.

    Restores the raw checkpoint, converts every leaf to a ``jnp`` array,
    remaps the ``mlp`` entries to the expected module layout, and nests
    the flat slash-delimited keys into a dict of dicts.
    """
    raw = load_params(path)
    as_arrays = jax.tree_util.tree_map(jnp.array, raw)
    return nest_params(param_remapper(as_arrays))
@functools.cache
def load_params(path: str) -> Params:
    """Restore the raw parameter tree from the checkpoint at ``path``.

    Cached, so repeated calls with the same path hit disk only once.
    """
    return orbax.checkpoint.PyTreeCheckpointer().restore(path)
def param_remapper(orig_params: Params) -> Params:
    """Remaps params to new module layout.

    This is needed here because the model definition does not have a separate
    `mlp` module: each ``.../mlp/<name>`` entry is collapsed one level, with
    its ``w`` leaf stored under the parent ``.../mlp`` key.

    Returns:
        dict of params with different names.
    """
    remapped = {}
    for key, value in orig_params.items():
        if 'mlp/' not in key:
            # Non-MLP entries are carried over unchanged.
            remapped[key] = value
            continue
        parent, leaf = key.rsplit('/', maxsplit=1)
        # Create the parent bucket even when no `w` leaf exists, matching
        # the original layout exactly.
        bucket = remapped.setdefault(parent, {})
        if 'w' in value:
            bucket[leaf] = value['w']
    return remapped
def nest_params(params: Params) -> Params:
    """Nests params as a dict of dicts rather than a flat dict.

    Keys are split on '/'; every segment but the last becomes a nesting
    level, and the final segment maps to the parameter itself.
    """
    tree = {}
    for flat_key, value in params.items():
        *branches, leaf = flat_key.split('/')
        # Walk/create the chain of intermediate dicts down to the parent.
        parent = functools.reduce(
            lambda node, key: node.setdefault(key, {}), branches, tree)
        parent[leaf] = value
    return tree

View File

@ -92,8 +92,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self._warmup_model() self._warmup_model()
def _warmup_model(self) -> None: def _warmup_model(self) -> None:
# self.model_runner.warmup_model(self.tpu_cache) self.model_runner.warmup_model(self.tpu_cache)
pass
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
head_size = self.model_config.get_head_size() head_size = self.model_config.get_head_size()
@ -129,8 +128,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
if num_seq_groups == 0: if num_seq_groups == 0:
return {} return {}
output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list, output, kv_caches = self.model_runner.execute_model(
self.tpu_cache) seq_group_metadata_list, self.tpu_cache)
self.tpu_cache = kv_caches self.tpu_cache = kv_caches
return output return output