Fix missing kv_caches and attn_metadata in OpenVINOCausalLM (#14271)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-03-07 13:30:42 +01:00 committed by GitHub
parent 0ca3b8e01c
commit f7a6bd0fa1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 10 deletions

View File

@@ -2,7 +2,7 @@
# ruff: noqa: SIM117 # ruff: noqa: SIM117
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import Optional
import openvino as ov import openvino as ov
import torch import torch
@@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM
from torch import nn from torch import nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor, from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states) _prune_hidden_states)
@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
def _flattenize_inputs(inputs): def _flatten_inputs(inputs):
""" """
Helper function for making nested inputs flattens Helper function for making nested inputs flattens
""" """
@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
if input_data is None: if input_data is None:
continue continue
if isinstance(input_data, (list, tuple)): if isinstance(input_data, (list, tuple)):
flatten_inputs.extend(_flattenize_inputs(input_data)) flatten_inputs.extend(_flatten_inputs(input_data))
elif isinstance(input_data, dict): elif isinstance(input_data, dict):
flatten_inputs.extend(_flattenize_inputs(list( flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
input_data.values())))
else: else:
flatten_inputs.append(input_data) flatten_inputs.append(input_data)
return flatten_inputs return flatten_inputs
@@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
attn_metadata: OpenVINOAttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
flatten_kv_cache = _flattenize_inputs(kv_caches) flat_kv_caches = _flatten_inputs(kv_caches)
attn_metadata = get_forward_context().attn_metadata
inputs = [ inputs = [
input_ids, input_ids,
positions, positions,
*flatten_kv_cache, *flat_kv_caches,
attn_metadata.past_lens, attn_metadata.past_lens,
attn_metadata.subsequence_begins, attn_metadata.subsequence_begins,
attn_metadata.block_indices, attn_metadata.block_indices,

View File

@@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
input_tokens, input_tokens,
"positions": "positions":
input_positions, input_positions,
"kv_caches":
kv_caches,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
device=self.device), device=self.device),
} }