mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 16:35:02 +08:00
Fix missing kv_caches and attn_metadata in OpenVINOCausalLM (#14271)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
0ca3b8e01c
commit
f7a6bd0fa1
@ -2,7 +2,7 @@
|
||||
|
||||
# ruff: noqa: SIM117
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import openvino as ov
|
||||
import torch
|
||||
@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM
|
||||
from torch import nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
|
||||
_prune_hidden_states)
|
||||
@ -24,7 +24,7 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _flattenize_inputs(inputs):
|
||||
def _flatten_inputs(inputs):
|
||||
"""
|
||||
Helper function for making nested inputs flattens
|
||||
"""
|
||||
@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
|
||||
if input_data is None:
|
||||
continue
|
||||
if isinstance(input_data, (list, tuple)):
|
||||
flatten_inputs.extend(_flattenize_inputs(input_data))
|
||||
flatten_inputs.extend(_flatten_inputs(input_data))
|
||||
elif isinstance(input_data, dict):
|
||||
flatten_inputs.extend(_flattenize_inputs(list(
|
||||
input_data.values())))
|
||||
flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
|
||||
else:
|
||||
flatten_inputs.append(input_data)
|
||||
return flatten_inputs
|
||||
@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
|
||||
attn_metadata: OpenVINOAttentionMetadata,
|
||||
kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
flatten_kv_cache = _flattenize_inputs(kv_caches)
|
||||
flat_kv_caches = _flatten_inputs(kv_caches)
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
inputs = [
|
||||
input_ids,
|
||||
positions,
|
||||
*flatten_kv_cache,
|
||||
*flat_kv_caches,
|
||||
attn_metadata.past_lens,
|
||||
attn_metadata.subsequence_begins,
|
||||
attn_metadata.block_indices,
|
||||
|
||||
@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
||||
input_tokens,
|
||||
"positions":
|
||||
input_positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user