mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 02:17:53 +08:00
Include argmax to jit
This commit is contained in:
parent
620e7646d3
commit
f42b4c27d8
@@ -10,6 +10,10 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import pad_to_max_length
 
+# DELETE
+# from jax_smi import initialise_tracking
+# initialise_tracking()
+
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = -1
@@ -189,7 +193,9 @@ class TPUModelRunner:
             kv_caches,
             logits_indices,
         )
-        return logits, new_kv_caches
+        # TODO
+        next_token_ids = jnp.argmax(logits, axis=-1)
+        return next_token_ids, new_kv_caches
 
     def execute_model(
         self,
@@ -199,8 +205,7 @@ class TPUModelRunner:
         from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
 
         inputs = self.prepare_input_arrays(seq_group_metadata_list)
-        logits, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
-        next_token_ids = jnp.argmax(logits, axis=-1)
+        next_token_ids, new_kv_caches = self.compiled_fn(self.params, *inputs, kv_caches)
         next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
         next_token_ids = next_token_ids.tolist()
         i = 0
Loading…
x
Reference in New Issue
Block a user