From f7a6bd0fa1d9ede66c935c60017c494d1a852fc6 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:30:42 +0100 Subject: [PATCH] Fix missing `kv_caches` and `attn_metadata` in `OpenVINOCausalLM` (#14271) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/model_loader/openvino.py | 19 +++++++++---------- vllm/worker/openvino_model_runner.py | 2 ++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index 805f0cfc585e3..cd2d427edbbd1 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -2,7 +2,7 @@ # ruff: noqa: SIM117 from pathlib import Path -from typing import List, Optional, Tuple +from typing import Optional import openvino as ov import torch @@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM from torch import nn import vllm.envs as envs -from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import (LogitsProcessor, _prune_hidden_states) @@ -24,7 +24,7 @@ from vllm.platforms import current_platform logger = init_logger(__name__) -def _flattenize_inputs(inputs): +def _flatten_inputs(inputs): """ Helper function for making nested inputs flattens """ @@ -33,10 +33,9 @@ def _flattenize_inputs(inputs): if input_data is None: continue if isinstance(input_data, (list, tuple)): - flatten_inputs.extend(_flattenize_inputs(input_data)) + flatten_inputs.extend(_flatten_inputs(input_data)) elif isinstance(input_data, dict): - flatten_inputs.extend(_flattenize_inputs(list( - input_data.values()))) + flatten_inputs.extend(_flatten_inputs(list(input_data.values()))) else: flatten_inputs.append(input_data) return flatten_inputs @@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module): self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], - attn_metadata: OpenVINOAttentionMetadata, + kv_caches: list[tuple[ov.Tensor, ov.Tensor]], ) -> torch.Tensor: - flatten_kv_cache = _flattenize_inputs(kv_caches) + flat_kv_caches = _flatten_inputs(kv_caches) + attn_metadata = get_forward_context().attn_metadata inputs = [ input_ids, positions, - *flatten_kv_cache, + *flat_kv_caches, attn_metadata.past_lens, attn_metadata.subsequence_begins, attn_metadata.block_indices, diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 5035ea20294c4..aa1d2cbb2df29 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase): input_tokens, "positions": input_positions, + "kv_caches": + kv_caches, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, device=self.device), }