mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-18 16:37:21 +08:00

Add warmup + formatting

This commit is contained in:
parent 4ea41d01a9
commit 5ae2f81c2b
@@ -14,6 +14,7 @@ from vllm.utils import pad_to_max_length
 
+# DELETE
 from jax_smi import initialise_tracking
 
 initialise_tracking()
 
 logger = init_logger(__name__)
@@ -43,11 +44,11 @@ class TPUModelRunner:
                 "The model will run without sliding window.")
         self.model = None
         self.block_size = None
-        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
+        self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
         # FIXME(woosuk)
         self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
 
-    def load_model(self) -> None:
+    def load_model(self) -> None:
         from huggingface_hub import snapshot_download
 
         from vllm.model_executor.models.jax.gemma import Transformer
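Note on the donate_argnums=(7,) in the hunk above (the edit itself only touches spacing): argument index 7 of _execute_step is tpu_caches, and donating it tells XLA it may reuse the cache buffers for the outputs instead of allocating a second cache-sized copy. A minimal sketch of buffer donation, independent of vLLM:

import jax
import jax.numpy as jnp

def append_token(cache, x):
    # Stand-in for a KV-cache update: write x into row 0.
    return cache.at[0].set(x)

# Donating argument 0 lets XLA write the result into the cache's own
# device buffer (in place on TPU/GPU; on CPU donation is ignored).
append_jit = jax.jit(append_token, donate_argnums=(0,))

cache = jnp.zeros((4, 8), dtype=jnp.float32)
cache = append_jit(cache, jnp.ones((8,), dtype=jnp.float32))
# The pre-donation buffer is now invalid; only the returned array is usable.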
@@ -60,6 +61,51 @@ class TPUModelRunner:
         params = load_and_format_params(model_dir + "/7b/")["transformer"]
         self.params = {"params": params}
 
+    def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
+                                                  jax.Array]]) -> None:
+        # Prefill
+        logger.info("Warming up the model...")
+        start = time.time()
+        for batch_size in [1]:
+            for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
+                if batch_size * seq_len > 8192:
+                    continue
+                token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+                position_ids = jnp.zeros((batch_size, seq_len),
+                                         dtype=jnp.int32)
+                slot_mapping = jnp.zeros((batch_size, seq_len),
+                                         dtype=jnp.int32)
+                block_tables = None
+                context_lens = None
+                prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+
+                # Dummy run.
+                _, tpu_caches = self.compiled_fn(self.params, token_ids,
+                                                 position_ids, slot_mapping,
+                                                 block_tables, context_lens,
+                                                 prompt_lens, tpu_caches)
+        end = time.time()
+        logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
+
+        # Decode
+        start = time.time()
+        for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
+            seq_len = 1
+            token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
+            block_tables = jnp.asarray(self.block_tables[:batch_size],
+                                       dtype=jnp.int32)
+            context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+            prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
+
+            _, tpu_caches = self.compiled_fn(self.params, token_ids,
+                                             position_ids, slot_mapping,
+                                             block_tables, context_lens,
+                                             prompt_lens, tpu_caches)
+        end = time.time()
+        logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
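Why a warmup pass helps here: jax.jit compiles one executable per distinct input shape, so any (batch_size, seq_len) bucket first seen while serving would stall that request on compilation. Running dummy prefill and decode batches over the bucketed shapes above pays that cost once at startup. A standalone sketch of the shape-specialization behavior:

import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: (x * 2).sum())

for seq_len in (16, 32, 64):
    x = jnp.zeros((1, seq_len), dtype=jnp.int32)
    t0 = time.time()
    f(x).block_until_ready()  # first call per shape triggers compilation
    t1 = time.time()
    f(x).block_until_ready()  # second call reuses the cached executable
    t2 = time.time()
    print(f"seq_len={seq_len}: cold {t1 - t0:.4f}s, warm {t2 - t1:.4f}s")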
@@ -107,9 +153,9 @@ class TPUModelRunner:
                                             pad=0,
                                             dtype=jnp.int32)
         slot_mapping = _make_array_with_pad(slot_mapping,
-                                            max_prompt_len,
-                                            pad=_PAD_SLOT_ID,
-                                            dtype=jnp.int32)
+                                            max_prompt_len,
+                                            pad=_PAD_SLOT_ID,
+                                            dtype=jnp.int32)
         prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
         return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
@@ -160,7 +206,8 @@ class TPUModelRunner:
         slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
         context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
 
-        block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
+        block_tables = jnp.asarray(self.block_tables[:batch_size],
+                                   dtype=jnp.int32)
         input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
         return (input_tokens, input_positions, slot_mapping, block_tables,
                 context_lens, input_lens)
@@ -210,25 +257,26 @@ class TPUModelRunner:
     def execute_model(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        kv_caches: List[jax.Array],
-    ) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
+        kv_caches: List[Tuple[jax.Array, jax.Array]],
+    ) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
         from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
 
         start = time.time()
         inputs = self.prepare_input_arrays(seq_group_metadata_list)
         end = time.time()
-        print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
+        # print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
 
         start = time.time()
-        next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
+        next_token_ids, new_kv_caches = self.compiled_fn(
+            self.params, *inputs, kv_caches)
         next_token_ids.block_until_ready()
         end = time.time()
-        print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
+        # print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
 
         start = time.time()
         next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
         end = time.time()
-        print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
+        # print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
 
         next_token_ids = next_token_ids.tolist()
         i = 0
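On the (now commented-out) timing prints above: JAX dispatches device work asynchronously, so time.time() deltas around a jitted call measure dispatch latency unless the result is synchronized first, which is what next_token_ids.block_until_ready() does. A small standalone illustration:

import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda a: a @ a)
x = jnp.ones((1024, 1024))
f(x).block_until_ready()  # warm up so we time execution, not compilation

t0 = time.time()
y = f(x)                   # returns almost immediately: dispatch is async
t1 = time.time()
y.block_until_ready()      # wait for the device computation to finish
t2 = time.time()
print(f"dispatch {1000 * (t1 - t0):.2f} ms, execution {1000 * (t2 - t1):.2f} ms")

# Moving the result to host, as execute_model does before .tolist():
y_cpu = jax.device_put(y, jax.devices("cpu")[0])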
@@ -240,7 +288,9 @@ class TPUModelRunner:
             seq_ids = list(seq_group_metadata.seq_data.keys())
             for seq_id in seq_ids:
                 next_token_id = next_token_ids[i]
-                seq_outputs.append(SequenceOutput(seq_id, next_token_id, {next_token_id: Logprob(0.0)}))
+                seq_outputs.append(
+                    SequenceOutput(seq_id, next_token_id,
+                                   {next_token_id: Logprob(0.0)}))
                 i += 1
 
             sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
@@ -284,24 +334,24 @@ Params = Mapping[str, Any]
 
 
 def load_and_format_params(path: str) -> Params:
-  """Loads parameters and formats them for compatibility."""
-  params = load_params(path)
-  param_state = jax.tree_util.tree_map(jnp.array, params)
-  remapped_params = param_remapper(param_state)
-  nested_params = nest_params(remapped_params)
-  return nested_params
+    """Loads parameters and formats them for compatibility."""
+    params = load_params(path)
+    param_state = jax.tree_util.tree_map(jnp.array, params)
+    remapped_params = param_remapper(param_state)
+    nested_params = nest_params(remapped_params)
+    return nested_params
 
 
 @functools.cache
 def load_params(path: str) -> Params:
-  """Loads parameters from a checkpoint path."""
-  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-  params = checkpointer.restore(path)
-  return params
+    """Loads parameters from a checkpoint path."""
+    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+    params = checkpointer.restore(path)
+    return params
 
 
 def param_remapper(orig_params: Params) -> Params:
-  """Remaps params to new module layout.
+    """Remaps params to new module layout.
 
   This is needed here because the model definition does not have a separate
   `mlp` module.
@@ -312,26 +362,26 @@ def param_remapper(orig_params: Params) -> Params:
   Returns:
     dict of params with different names.
   """
-  new_params = {}
-  for k, v in orig_params.items():
-    if 'mlp/' in k:
-      layer_name, param = k.rsplit('/', maxsplit=1)
-      if layer_name not in new_params:
-        new_params[layer_name] = {}
-      if 'w' in v:
-        new_params[layer_name][param] = v['w']
-    else:
-      new_params[k] = v
-  return new_params
+    new_params = {}
+    for k, v in orig_params.items():
+        if 'mlp/' in k:
+            layer_name, param = k.rsplit('/', maxsplit=1)
+            if layer_name not in new_params:
+                new_params[layer_name] = {}
+            if 'w' in v:
+                new_params[layer_name][param] = v['w']
+        else:
+            new_params[k] = v
+    return new_params
 
 
 def nest_params(params: Params) -> Params:
-  """Nests params as a dict of dicts rather than a flat dict."""
-  nested_params = {}
-  for path, param in params.items():
-    *path, leaf = path.split('/')
-    subdict = nested_params
-    for key in path:
-      subdict = subdict.setdefault(key, {})
-    subdict[leaf] = param
-  return nested_params
+    """Nests params as a dict of dicts rather than a flat dict."""
+    nested_params = {}
+    for path, param in params.items():
+        *path, leaf = path.split('/')
+        subdict = nested_params
+        for key in path:
+            subdict = subdict.setdefault(key, {})
+        subdict[leaf] = param
+    return nested_params
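A toy trace of the two helpers above, using made-up keys shaped like the flat Gemma checkpoint layout (the real checkpoint has many more entries):

flat = {
    'transformer/layer_0/mlp/linear': {'w': [[1.0]]},
    'transformer/layer_0/attn/qkv': [[2.0]],
}
remapped = param_remapper(flat)
# {'transformer/layer_0/mlp': {'linear': [[1.0]]},
#  'transformer/layer_0/attn/qkv': [[2.0]]}
nested = nest_params(remapped)
# {'transformer': {'layer_0': {'mlp': {'linear': [[1.0]]},
#                              'attn': {'qkv': [[2.0]]}}}}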
@@ -92,8 +92,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         self._warmup_model()
 
     def _warmup_model(self) -> None:
-        # self.model_runner.warmup_model(self.tpu_cache)
-        pass
+        self.model_runner.warmup_model(self.tpu_cache)
 
     def get_cache_block_size_bytes(self) -> int:
         head_size = self.model_config.get_head_size()
@@ -129,8 +128,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         if num_seq_groups == 0:
             return {}
 
-        output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list,
-                                                            self.tpu_cache)
+        output, kv_caches = self.model_runner.execute_model(
+            seq_group_metadata_list, self.tpu_cache)
         self.tpu_cache = kv_caches
         return output
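The reassignment self.tpu_cache = kv_caches above is load-bearing: because the caches are donated to the jitted step function, the old buffers are invalid after each call, and the worker must always hold on to the freshly returned ones. A schematic sketch with hypothetical names (init_caches, runner, request_stream are illustrative, not vLLM APIs):

caches = init_caches()              # hypothetical cache initializer
for step_inputs in request_stream:  # hypothetical request iterator
    # execute_model returns both the sampler output and the new caches;
    # dropping the returned caches would leave only invalidated buffers.
    output, caches = runner.execute_model(step_inputs, caches)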