mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 06:42:20 +08:00
Add warmup + formatting
This commit is contained in:
parent
4ea41d01a9
commit
5ae2f81c2b
@ -14,6 +14,7 @@ from vllm.utils import pad_to_max_length
|
|||||||
|
|
||||||
# DELETE
|
# DELETE
|
||||||
from jax_smi import initialise_tracking
|
from jax_smi import initialise_tracking
|
||||||
|
|
||||||
initialise_tracking()
|
initialise_tracking()
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -43,11 +44,11 @@ class TPUModelRunner:
|
|||||||
"The model will run without sliding window.")
|
"The model will run without sliding window.")
|
||||||
self.model = None
|
self.model = None
|
||||||
self.block_size = None
|
self.block_size = None
|
||||||
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7,))
|
self.compiled_fn = jax.jit(self._execute_step, donate_argnums=(7, ))
|
||||||
# FIXME(woosuk)
|
# FIXME(woosuk)
|
||||||
self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
|
self.block_tables = np.zeros((_MAX_NUM_SEQS, 512), dtype=np.int32)
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from vllm.model_executor.models.jax.gemma import Transformer
|
from vllm.model_executor.models.jax.gemma import Transformer
|
||||||
@ -60,6 +61,51 @@ class TPUModelRunner:
|
|||||||
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
params = load_and_format_params(model_dir + "/7b/")["transformer"]
|
||||||
self.params = {"params": params}
|
self.params = {"params": params}
|
||||||
|
|
||||||
|
def warmup_model(self, tpu_caches: List[Tuple[jax.Array,
|
||||||
|
jax.Array]]) -> None:
|
||||||
|
# Prefill
|
||||||
|
logger.info("Warming up the model...")
|
||||||
|
start = time.time()
|
||||||
|
for batch_size in [1]:
|
||||||
|
for seq_len in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]:
|
||||||
|
if batch_size * seq_len > 8192:
|
||||||
|
continue
|
||||||
|
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
position_ids = jnp.zeros((batch_size, seq_len),
|
||||||
|
dtype=jnp.int32)
|
||||||
|
slot_mapping = jnp.zeros((batch_size, seq_len),
|
||||||
|
dtype=jnp.int32)
|
||||||
|
block_tables = None
|
||||||
|
context_lens = None
|
||||||
|
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||||
|
|
||||||
|
# Dummy run.
|
||||||
|
_, tpu_caches = self.compiled_fn(self.params, token_ids,
|
||||||
|
position_ids, slot_mapping,
|
||||||
|
block_tables, context_lens,
|
||||||
|
prompt_lens, tpu_caches)
|
||||||
|
end = time.time()
|
||||||
|
logger.info(f"Prefill warmup done in {(end - start):.2f} seconds.")
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
start = time.time()
|
||||||
|
for batch_size in [1, 2, 4] + [8 * i for i in range(1, 17)]:
|
||||||
|
seq_len = 1
|
||||||
|
token_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
slot_mapping = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
|
||||||
|
block_tables = jnp.asarray(self.block_tables[:batch_size],
|
||||||
|
dtype=jnp.int32)
|
||||||
|
context_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||||
|
prompt_lens = jnp.ones((batch_size, ), dtype=jnp.int32)
|
||||||
|
|
||||||
|
_, tpu_caches = self.compiled_fn(self.params, token_ids,
|
||||||
|
position_ids, slot_mapping,
|
||||||
|
block_tables, context_lens,
|
||||||
|
prompt_lens, tpu_caches)
|
||||||
|
end = time.time()
|
||||||
|
logger.info(f"Decode warmup done in {(end - start):.2f} seconds.")
|
||||||
|
|
||||||
def _prepare_prompt(
|
def _prepare_prompt(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
@ -107,9 +153,9 @@ class TPUModelRunner:
|
|||||||
pad=0,
|
pad=0,
|
||||||
dtype=jnp.int32)
|
dtype=jnp.int32)
|
||||||
slot_mapping = _make_array_with_pad(slot_mapping,
|
slot_mapping = _make_array_with_pad(slot_mapping,
|
||||||
max_prompt_len,
|
max_prompt_len,
|
||||||
pad=_PAD_SLOT_ID,
|
pad=_PAD_SLOT_ID,
|
||||||
dtype=jnp.int32)
|
dtype=jnp.int32)
|
||||||
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
|
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
|
||||||
return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
|
return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
|
||||||
|
|
||||||
@ -160,7 +206,8 @@ class TPUModelRunner:
|
|||||||
slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
|
slot_mapping = jnp.asarray(slot_mapping, dtype=jnp.int32)
|
||||||
context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
|
context_lens = jnp.asarray(context_lens, dtype=jnp.int32)
|
||||||
|
|
||||||
block_tables = jnp.asarray(self.block_tables[:batch_size], dtype=jnp.int32)
|
block_tables = jnp.asarray(self.block_tables[:batch_size],
|
||||||
|
dtype=jnp.int32)
|
||||||
input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
|
input_lens = jnp.asarray([1] * batch_size, dtype=jnp.int32)
|
||||||
return (input_tokens, input_positions, slot_mapping, block_tables,
|
return (input_tokens, input_positions, slot_mapping, block_tables,
|
||||||
context_lens, input_lens)
|
context_lens, input_lens)
|
||||||
@ -210,25 +257,26 @@ class TPUModelRunner:
|
|||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
kv_caches: List[jax.Array],
|
kv_caches: List[Tuple[jax.Array, jax.Array]],
|
||||||
) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
|
) -> Tuple[Optional[SamplerOutput], List[Tuple[jax.Array, jax.Array]]]:
|
||||||
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
|
# print(f"prepare_input_arrays: {(end - start) * 1000:.2f} ms")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
|
next_token_ids, new_kv_caches = self.compiled_fn(
|
||||||
|
self.params, *inputs, kv_caches)
|
||||||
next_token_ids.block_until_ready()
|
next_token_ids.block_until_ready()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
|
# print(f"compiled_fn: {(end - start) * 1000:.2f} ms")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
|
# print(f"jax.device_put: {(end - start) * 1000:.2f} ms")
|
||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
i = 0
|
i = 0
|
||||||
@ -240,7 +288,9 @@ class TPUModelRunner:
|
|||||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
for seq_id in seq_ids:
|
for seq_id in seq_ids:
|
||||||
next_token_id = next_token_ids[i]
|
next_token_id = next_token_ids[i]
|
||||||
seq_outputs.append(SequenceOutput(seq_id, next_token_id, {next_token_id: Logprob(0.0)}))
|
seq_outputs.append(
|
||||||
|
SequenceOutput(seq_id, next_token_id,
|
||||||
|
{next_token_id: Logprob(0.0)}))
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
|
sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
|
||||||
@ -284,24 +334,24 @@ Params = Mapping[str, Any]
|
|||||||
|
|
||||||
|
|
||||||
def load_and_format_params(path: str) -> Params:
|
def load_and_format_params(path: str) -> Params:
|
||||||
"""Loads parameters and formats them for compatibility."""
|
"""Loads parameters and formats them for compatibility."""
|
||||||
params = load_params(path)
|
params = load_params(path)
|
||||||
param_state = jax.tree_util.tree_map(jnp.array, params)
|
param_state = jax.tree_util.tree_map(jnp.array, params)
|
||||||
remapped_params = param_remapper(param_state)
|
remapped_params = param_remapper(param_state)
|
||||||
nested_params = nest_params(remapped_params)
|
nested_params = nest_params(remapped_params)
|
||||||
return nested_params
|
return nested_params
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def load_params(path: str) -> Params:
|
def load_params(path: str) -> Params:
|
||||||
"""Loads parameters from a checkpoint path."""
|
"""Loads parameters from a checkpoint path."""
|
||||||
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
||||||
params = checkpointer.restore(path)
|
params = checkpointer.restore(path)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def param_remapper(orig_params: Params) -> Params:
|
def param_remapper(orig_params: Params) -> Params:
|
||||||
"""Remaps params to new module layout.
|
"""Remaps params to new module layout.
|
||||||
|
|
||||||
This is needed here because the model definition does not have a separate
|
This is needed here because the model definition does not have a separate
|
||||||
`mlp` module.
|
`mlp` module.
|
||||||
@ -312,26 +362,26 @@ def param_remapper(orig_params: Params) -> Params:
|
|||||||
Returns:
|
Returns:
|
||||||
dict of params with different names.
|
dict of params with different names.
|
||||||
"""
|
"""
|
||||||
new_params = {}
|
new_params = {}
|
||||||
for k, v in orig_params.items():
|
for k, v in orig_params.items():
|
||||||
if 'mlp/' in k:
|
if 'mlp/' in k:
|
||||||
layer_name, param = k.rsplit('/', maxsplit=1)
|
layer_name, param = k.rsplit('/', maxsplit=1)
|
||||||
if layer_name not in new_params:
|
if layer_name not in new_params:
|
||||||
new_params[layer_name] = {}
|
new_params[layer_name] = {}
|
||||||
if 'w' in v:
|
if 'w' in v:
|
||||||
new_params[layer_name][param] = v['w']
|
new_params[layer_name][param] = v['w']
|
||||||
else:
|
else:
|
||||||
new_params[k] = v
|
new_params[k] = v
|
||||||
return new_params
|
return new_params
|
||||||
|
|
||||||
|
|
||||||
def nest_params(params: Params) -> Params:
|
def nest_params(params: Params) -> Params:
|
||||||
"""Nests params as a dict of dicts rather than a flat dict."""
|
"""Nests params as a dict of dicts rather than a flat dict."""
|
||||||
nested_params = {}
|
nested_params = {}
|
||||||
for path, param in params.items():
|
for path, param in params.items():
|
||||||
*path, leaf = path.split('/')
|
*path, leaf = path.split('/')
|
||||||
subdict = nested_params
|
subdict = nested_params
|
||||||
for key in path:
|
for key in path:
|
||||||
subdict = subdict.setdefault(key, {})
|
subdict = subdict.setdefault(key, {})
|
||||||
subdict[leaf] = param
|
subdict[leaf] = param
|
||||||
return nested_params
|
return nested_params
|
||||||
|
|||||||
@ -92,8 +92,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
self._warmup_model()
|
self._warmup_model()
|
||||||
|
|
||||||
def _warmup_model(self) -> None:
|
def _warmup_model(self) -> None:
|
||||||
# self.model_runner.warmup_model(self.tpu_cache)
|
self.model_runner.warmup_model(self.tpu_cache)
|
||||||
pass
|
|
||||||
|
|
||||||
def get_cache_block_size_bytes(self) -> int:
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
head_size = self.model_config.get_head_size()
|
head_size = self.model_config.get_head_size()
|
||||||
@ -129,8 +128,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
|||||||
if num_seq_groups == 0:
|
if num_seq_groups == 0:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list,
|
output, kv_caches = self.model_runner.execute_model(
|
||||||
self.tpu_cache)
|
seq_group_metadata_list, self.tpu_cache)
|
||||||
self.tpu_cache = kv_caches
|
self.tpu_cache = kv_caches
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user