commit f7a6bd0fa1
parent 0ca3b8e01c

Fix missing kv_caches and attn_metadata in OpenVINOCausalLM (#14271)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -2,7 +2,7 @@
 # ruff: noqa: SIM117
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import openvino as ov
 import torch
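
The only change in this first hunk is dropping List and Tuple from the typing import: the annotations further down switch to the builtin generics standardized by PEP 585, which are valid in annotations on Python 3.9 and later. A minimal illustration of the same move (a hypothetical function, not from this diff):

    # Before: container generics imported from typing
    from typing import List, Tuple

    def pair_sums(pairs: List[Tuple[int, int]]) -> List[int]:
        return [a + b for a, b in pairs]

    # After: builtin list/tuple used directly as generics (PEP 585)
    def pair_sums(pairs: list[tuple[int, int]]) -> list[int]:
        return [a + b for a, b in pairs]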
@@ -12,8 +12,8 @@ from optimum.intel import OVModelForCausalLM
 from torch import nn
 
 import vllm.envs as envs
-from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
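
The import swap follows from the change to forward() further down: OpenVINOAttentionMetadata no longer appears in the signature, so its import goes away, and get_forward_context comes in to fetch the metadata at call time. A minimal sketch of the new access pattern, assuming the runner has already set a forward context for the current step:

    from vllm.forward_context import get_forward_context

    # Inside the model's forward(): read the attention metadata that the
    # runner published in the ambient forward context for this step.
    attn_metadata = get_forward_context().attn_metadata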
@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
 logger = init_logger(__name__)
 
 
-def _flattenize_inputs(inputs):
+def _flatten_inputs(inputs):
     """
     Helper function for making nested inputs flattens
     """
@@ -33,10 +33,9 @@ def _flattenize_inputs(inputs):
         if input_data is None:
             continue
         if isinstance(input_data, (list, tuple)):
-            flatten_inputs.extend(_flattenize_inputs(input_data))
+            flatten_inputs.extend(_flatten_inputs(input_data))
         elif isinstance(input_data, dict):
-            flatten_inputs.extend(_flattenize_inputs(list(
-                input_data.values())))
+            flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
         else:
             flatten_inputs.append(input_data)
     return flatten_inputs
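
Taken together, these two hunks only rename _flattenize_inputs to _flatten_inputs and join the split dict branch onto one line. For reference, the whole helper as it plausibly reads after the commit; the list initializer and the loop header fall outside the hunks, so those two lines are an assumption:

    def _flatten_inputs(inputs):
        """
        Helper function for making nested inputs flattens
        """
        flatten_inputs = []          # assumed: not shown in the hunks
        for input_data in inputs:    # assumed: not shown in the hunks
            if input_data is None:
                continue
            if isinstance(input_data, (list, tuple)):
                flatten_inputs.extend(_flatten_inputs(input_data))
            elif isinstance(input_data, dict):
                flatten_inputs.extend(_flatten_inputs(list(input_data.values())))
            else:
                flatten_inputs.append(input_data)
        return flatten_inputs

    # Example: nested lists and dict values collapse into one flat list.
    # _flatten_inputs([1, [2, {"k": 3}], None]) -> [1, 2, 3]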
@@ -147,15 +146,15 @@ class OpenVINOCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
-        attn_metadata: OpenVINOAttentionMetadata,
+        kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
     ) -> torch.Tensor:
-        flatten_kv_cache = _flattenize_inputs(kv_caches)
+        flat_kv_caches = _flatten_inputs(kv_caches)
+        attn_metadata = get_forward_context().attn_metadata
 
         inputs = [
             input_ids,
             positions,
-            *flatten_kv_cache,
+            *flat_kv_caches,
             attn_metadata.past_lens,
             attn_metadata.subsequence_begins,
             attn_metadata.block_indices,
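This hunk is the heart of the fix: attn_metadata leaves the parameter list and is fetched from the forward context instead, while kv_caches stays (with its type hint modernized to builtin generics). An abbreviated sketch of forward() after the change; everything past the shown inputs entries is elided:

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[tuple[ov.Tensor, ov.Tensor]],
    ) -> torch.Tensor:
        flat_kv_caches = _flatten_inputs(kv_caches)
        # No longer an explicit parameter: the runner publishes the
        # metadata via the forward context, and it is read back here.
        attn_metadata = get_forward_context().attn_metadata

        inputs = [
            input_ids,
            positions,
            *flat_kv_caches,
            attn_metadata.past_lens,
            attn_metadata.subsequence_begins,
            attn_metadata.block_indices,
            # ...remaining metadata fields elided...
        ]
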
@@ -346,6 +346,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
             input_tokens,
             "positions":
             input_positions,
+            "kv_caches":
+            kv_caches,
             **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
                                          device=self.device),
         }
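
On the runner side, kv_caches is now threaded into the model kwargs so the updated forward() signature receives it again, while attn_metadata travels out of band. The surrounding call site is not part of this diff, but vLLM's usual pattern looks roughly like the sketch below; set_forward_context's exact signature at this revision is an assumption, and execute_model_kwargs is an illustrative name for the dict built in the hunk:

    from vllm.forward_context import set_forward_context

    # Publish attn_metadata for the step, then invoke the model with only
    # the kwargs built above (now including "kv_caches").
    with set_forward_context(attn_metadata, self.vllm_config):
        hidden_states = self.model(**execute_model_kwargs)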