mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-26 09:19:11 +08:00
Minor
This commit is contained in:
parent
743695f586
commit
84284302d8
@ -12,6 +12,8 @@ from vllm.utils import pad_to_max_length
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
class TPUModelRunner:
|
||||
|
||||
@ -35,7 +37,10 @@ class TPUModelRunner:
|
||||
self.model = None
|
||||
self.block_size = None
|
||||
# FIXME
|
||||
# self.compiled_fn = jax.jit(self._execute_step)
|
||||
# self.compiled_fn = jax.jit(
|
||||
# self._execute_step,
|
||||
# donate_argnums=6,
|
||||
# )
|
||||
self.compiled_fn = self._execute_step
|
||||
|
||||
def load_model(self) -> None:
|
||||
@ -96,7 +101,7 @@ class TPUModelRunner:
|
||||
dtype=jnp.int32)
|
||||
slot_mapping = _make_array_with_pad(slot_mapping,
|
||||
max_prompt_len,
|
||||
pad=0, # FIXME
|
||||
pad=_PAD_SLOT_ID,
|
||||
dtype=jnp.int32)
|
||||
prompt_lens = jnp.asarray(prompt_lens, dtype=jnp.int32)
|
||||
return input_tokens, input_positions, slot_mapping, None, None, prompt_lens
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user