Fix missing kv_caches and attn_metadata in OpenVINOCausalLM (#14271)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-03-07 13:30:42 +01:00 committed by GitHub
parent 0ca3b8e01c
commit f7a6bd0fa1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 10 deletions

View File

@@ -2,7 +2,7 @@
# ruff: noqa: SIM117
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Optional
import openvino as ov
import torch
@@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM
from torch import nn
import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states)
@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
def _flattenize_inputs(inputs):
def _flatten_inputs(inputs):
"""
Helper function that flattens nested inputs into a single flat list.
"""
@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
if input_data is None:
continue
if isinstance(input_data, (list, tuple)):
flatten_inputs.extend(_flattenize_inputs(input_data))
flatten_inputs.extend(_flatten_inputs(input_data))
elif isinstance(input_data, dict):
flatten_inputs.extend(_flattenize_inputs(list(
input_data.values())))
flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
else:
flatten_inputs.append(input_data)
return flatten_inputs
@@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
attn_metadata: OpenVINOAttentionMetadata,
kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
) -> torch.Tensor:
flatten_kv_cache = _flattenize_inputs(kv_caches)
flat_kv_caches = _flatten_inputs(kv_caches)
attn_metadata = get_forward_context().attn_metadata
inputs = [
input_ids,
positions,
*flatten_kv_cache,
*flat_kv_caches,
attn_metadata.past_lens,
attn_metadata.subsequence_begins,
attn_metadata.block_indices,

View File

@@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
input_tokens,
"positions":
input_positions,
"kv_caches":
kv_caches,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
device=self.device),
}