mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-15 01:17:04 +08:00
explicitly return new_kv_caches
This commit is contained in:
parent
ef762cb110
commit
186c88c497
@ -154,10 +154,8 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
# Write the incoming keys and values to KV cache.
|
||||
key_cache = cache[0]
|
||||
value_cache = cache[1]
|
||||
key_cache = write_to_cache(key_proj, key_cache, slot_mapping)
|
||||
value_cache = write_to_cache(value_proj, value_cache, slot_mapping)
|
||||
key_cache = write_to_cache(key_proj, cache[0], slot_mapping)
|
||||
value_cache = write_to_cache(value_proj, cache[1], slot_mapping)
|
||||
cache = jnp.stack([key_cache, value_cache])
|
||||
|
||||
if block_tables is None:
|
||||
@ -308,19 +306,20 @@ class Transformer(nn.Module):
|
||||
logits_indices: jax.Array,
|
||||
) -> tuple[jax.Array, list[jax.Array]]:
|
||||
x = self.embedder.encode(token_ids)
|
||||
new_caches = []
|
||||
for i, block in enumerate(self.blocks):
|
||||
layer_cache = kv_caches[i]
|
||||
x, layer_cache = block(
|
||||
x, new_cache = block(
|
||||
x,
|
||||
positions,
|
||||
slot_mapping,
|
||||
block_tables,
|
||||
context_lens,
|
||||
layer_cache,
|
||||
kv_caches[i],
|
||||
)
|
||||
kv_caches[i] = layer_cache
|
||||
new_caches.append(new_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
hidden_states = x[logits_indices]
|
||||
logits = self.embedder.decode(hidden_states)
|
||||
return logits, kv_caches
|
||||
return logits, new_caches
|
||||
|
||||
@ -168,7 +168,7 @@ class TPUModelRunner:
|
||||
base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len
|
||||
logits_indices = base_indicies + input_lens - 1
|
||||
|
||||
logits, kv_caches = self.model.apply(
|
||||
logits, new_kv_caches = self.model.apply(
|
||||
{"params": self.params["transformer"]},
|
||||
token_ids,
|
||||
position_ids,
|
||||
@ -178,17 +178,17 @@ class TPUModelRunner:
|
||||
kv_caches,
|
||||
logits_indices,
|
||||
)
|
||||
return logits, kv_caches
|
||||
return logits, new_kv_caches
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
kv_caches: List[jax.Array],
|
||||
) -> Optional[SamplerOutput]:
|
||||
) -> Tuple[Optional[SamplerOutput], List[jax.Array]]:
|
||||
from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob
|
||||
|
||||
inputs = self.prepare_input_arrays(seq_group_metadata_list)
|
||||
logits, _ = self.compiled_fn(*inputs, kv_caches)
|
||||
logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches)
|
||||
next_token_ids = jnp.argmax(logits, axis=-1)
|
||||
next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0])
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
@ -205,7 +205,7 @@ class TPUModelRunner:
|
||||
i += 1
|
||||
|
||||
sampler_outputs.append(SequenceGroupOutput(seq_outputs, None))
|
||||
return SamplerOutput(sampler_outputs)
|
||||
return SamplerOutput(sampler_outputs), new_kv_caches
|
||||
|
||||
|
||||
def _make_array_with_pad(
|
||||
|
||||
@ -121,8 +121,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
if num_seq_groups == 0:
|
||||
return {}
|
||||
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
self.tpu_cache)
|
||||
self.tpu_cache = kv_caches
|
||||
return output
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user