diff --git a/vllm/model_executor/models/jax/gemma.py b/vllm/model_executor/models/jax/gemma.py index da3d94a4c1dd7..4d98d23c4d2cb 100644 --- a/vllm/model_executor/models/jax/gemma.py +++ b/vllm/model_executor/models/jax/gemma.py @@ -154,10 +154,8 @@ class Attention(nn.Module): ) # Write the incoming keys and values to KV cache. - key_cache = cache[0] - value_cache = cache[1] - key_cache = write_to_cache(key_proj, key_cache, slot_mapping) - value_cache = write_to_cache(value_proj, value_cache, slot_mapping) + key_cache = write_to_cache(key_proj, cache[0], slot_mapping) + value_cache = write_to_cache(value_proj, cache[1], slot_mapping) cache = jnp.stack([key_cache, value_cache]) if block_tables is None: @@ -308,19 +306,20 @@ class Transformer(nn.Module): logits_indices: jax.Array, ) -> tuple[jax.Array, list[jax.Array]]: x = self.embedder.encode(token_ids) + new_caches = [] for i, block in enumerate(self.blocks): - layer_cache = kv_caches[i] - x, layer_cache = block( + x, new_cache = block( x, positions, slot_mapping, block_tables, context_lens, - layer_cache, + kv_caches[i], ) - kv_caches[i] = layer_cache + new_caches.append(new_cache) + x = self.final_norm(x) x = x.reshape(-1, x.shape[-1]) hidden_states = x[logits_indices] logits = self.embedder.decode(hidden_states) - return logits, kv_caches + return logits, new_caches diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 5683003178058..afb0a3b2ff317 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -168,7 +168,7 @@ class TPUModelRunner: base_indicies = jnp.arange(batch_size, dtype=jnp.int32) * seq_len logits_indices = base_indicies + input_lens - 1 - logits, kv_caches = self.model.apply( + logits, new_kv_caches = self.model.apply( {"params": self.params["transformer"]}, token_ids, position_ids, @@ -178,17 +178,17 @@ class TPUModelRunner: kv_caches, logits_indices, ) - return logits, kv_caches + return logits, new_kv_caches def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[jax.Array], - ) -> Optional[SamplerOutput]: + ) -> Tuple[Optional[SamplerOutput], List[jax.Array]]: from vllm.sequence import SequenceOutput, SequenceGroupOutput, Logprob inputs = self.prepare_input_arrays(seq_group_metadata_list) - logits, _ = self.compiled_fn(*inputs, kv_caches) + logits, new_kv_caches = self.compiled_fn(*inputs, kv_caches) next_token_ids = jnp.argmax(logits, axis=-1) next_token_ids = jax.device_put(next_token_ids, jax.devices("cpu")[0]) next_token_ids = next_token_ids.tolist() @@ -205,7 +205,7 @@ class TPUModelRunner: i += 1 sampler_outputs.append(SequenceGroupOutput(seq_outputs, None)) - return SamplerOutput(sampler_outputs) + return SamplerOutput(sampler_outputs), new_kv_caches def _make_array_with_pad( diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 2626c9d13a966..0bd1b8437da24 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -121,8 +121,9 @@ class TPUWorker(LoraNotSupportedWorkerBase): if num_seq_groups == 0: return {} - output = self.model_runner.execute_model(seq_group_metadata_list, + output, kv_caches = self.model_runner.execute_model(seq_group_metadata_list, self.tpu_cache) + self.tpu_cache = kv_caches return output