mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 06:14:25 +08:00
[Bugfix] Fix OpenVINO model runner (#12750)
This commit is contained in:
parent
58b218d7ae
commit
fcf2e3d7fc
@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata:
|
|||||||
# `model_executable`.
|
# `model_executable`.
|
||||||
multi_modal_placeholder_index_maps: Optional[Dict[
|
multi_modal_placeholder_index_maps: Optional[Dict[
|
||||||
str, MultiModalPlaceholderMap.IndexMap]]
|
str, MultiModalPlaceholderMap.IndexMap]]
|
||||||
|
|
||||||
|
# Enable/disable KV scales calculation. This is so that we can disable the
|
||||||
|
# calculation until after prefill and cuda graph capture.
|
||||||
|
enable_kv_scales_calculation: bool
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from torch import nn
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
|
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
|
||||||
from vllm.config import DeviceConfig, ModelConfig
|
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
|
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
|
||||||
_prune_hidden_states)
|
_prune_hidden_states)
|
||||||
@ -103,7 +103,6 @@ class OpenVINOCausalLM(nn.Module):
|
|||||||
self,
|
self,
|
||||||
ov_core: ov.Core,
|
ov_core: ov.Core,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
device_config: DeviceConfig,
|
|
||||||
kv_cache_dtype: ov.Type,
|
kv_cache_dtype: ov.Type,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -187,8 +186,7 @@ class OpenVINOCausalLM(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_config: ModelConfig,
|
vllm_config: VllmConfig,
|
||||||
device_config: DeviceConfig,
|
|
||||||
kv_cache_dtype: ov.Type,
|
kv_cache_dtype: ov.Type,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
@ -201,5 +199,6 @@ def get_model(
|
|||||||
"be added in the future. If this is important to you, "
|
"be added in the future. If this is important to you, "
|
||||||
"please open an issue on github.")
|
"please open an issue on github.")
|
||||||
|
|
||||||
return OpenVINOCausalLM(ov_core, model_config, device_config,
|
with set_current_vllm_config(vllm_config):
|
||||||
kv_cache_dtype)
|
return OpenVINOCausalLM(ov_core, vllm_config.model_config,
|
||||||
|
kv_cache_dtype)
|
||||||
|
|||||||
@ -54,15 +54,13 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
|||||||
):
|
):
|
||||||
self.ov_core = ov_core
|
self.ov_core = ov_core
|
||||||
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
|
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
|
||||||
cache_config = self.cache_config
|
|
||||||
model_config = self.model_config
|
|
||||||
self.is_driver_worker = is_driver_worker
|
self.is_driver_worker = is_driver_worker
|
||||||
|
|
||||||
self.device = self.device_config.device
|
self.device = self.device_config.device
|
||||||
|
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.sliding_window = model_config.get_sliding_window()
|
self.sliding_window = self.model_config.get_sliding_window()
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = self.cache_config.block_size
|
||||||
|
|
||||||
self.attn_backend = get_attn_backend(
|
self.attn_backend = get_attn_backend(
|
||||||
self.model_config.get_head_size(),
|
self.model_config.get_head_size(),
|
||||||
@ -81,8 +79,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
|
|||||||
self.model: nn.Module # Set after init_Model
|
self.model: nn.Module # Set after init_Model
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
self.model = get_model(model_config=self.model_config,
|
self.model = get_model(vllm_config=self.vllm_config,
|
||||||
device_config=self.device_config,
|
|
||||||
kv_cache_dtype=self.kv_cache_dtype,
|
kv_cache_dtype=self.kv_cache_dtype,
|
||||||
ov_core=self.ov_core)
|
ov_core=self.ov_core)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user