mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 10:26:07 +08:00
[Platforms] Refactor openvino code (#10573)
Signed-off-by: statelesshz <hzji210@gmail.com>
This commit is contained in:
parent
4cfe5d2bca
commit
86a44fb896
@ -1,19 +1,16 @@
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import openvino as ov
|
||||
import openvino.properties.hint as hints
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CacheConfig, ModelConfig
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
|
||||
get_open_port, make_async)
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -30,11 +27,6 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
current_platform.is_openvino_gpu(), \
|
||||
"OpenVINO backend supports only CPU and GPU devices"
|
||||
|
||||
self.ov_core = ov.Core()
|
||||
self.model_config = _verify_and_get_model_config(self.model_config)
|
||||
self.cache_config = _verify_and_get_cache_config(
|
||||
self.ov_core, self.cache_config)
|
||||
|
||||
# Instantiate the worker and load the model to CPU.
|
||||
self._init_worker()
|
||||
|
||||
@ -45,7 +37,7 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = wrapper.init_worker(
|
||||
ov_core=self.ov_core,
|
||||
ov_core=ov.Core(),
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
@ -130,70 +122,3 @@ class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
|
||||
# OpenVINOExecutor will always be healthy as long as
|
||||
# it's running.
|
||||
return
|
||||
|
||||
|
||||
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
||||
if config.dtype != torch.float32:
|
||||
logger.warning(
|
||||
f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501
|
||||
)
|
||||
config.dtype = torch.float32
|
||||
if not config.enforce_eager:
|
||||
logger.warning(
|
||||
"CUDA graph is not supported on OpenVINO backend, fallback to the "
|
||||
"eager mode.")
|
||||
config.enforce_eager = True
|
||||
return config
|
||||
|
||||
|
||||
def _verify_and_get_cache_config(ov_core: ov.Core,
|
||||
config: CacheConfig) -> CacheConfig:
|
||||
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
|
||||
if not current_platform.is_openvino_cpu():
|
||||
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
|
||||
"ignored for GPU, f16 data type will be used.")
|
||||
config.cache_dtype = ov.Type.f16
|
||||
else:
|
||||
logger.info("KV cache type is overridden to u8 via "
|
||||
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
|
||||
config.cache_dtype = ov.Type.u8
|
||||
else:
|
||||
if current_platform.is_openvino_cpu():
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
inference_precision = ov_core.get_property(
|
||||
ov_device, hints.inference_precision)
|
||||
if inference_precision == ov.Type.bf16:
|
||||
config.cache_dtype = ov.Type.bf16
|
||||
else:
|
||||
config.cache_dtype = ov.Type.f16
|
||||
else:
|
||||
config.cache_dtype = ov.Type.f16
|
||||
|
||||
if current_platform.is_openvino_cpu():
|
||||
if config.block_size != 32:
|
||||
logger.info(
|
||||
f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
config.block_size = 32
|
||||
else:
|
||||
if config.block_size != 16:
|
||||
logger.info(
|
||||
f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
config.block_size = 16
|
||||
|
||||
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
|
||||
if kv_cache_space >= 0:
|
||||
if kv_cache_space == 0 and current_platform.is_openvino_cpu():
|
||||
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
|
||||
logger.warning(
|
||||
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
|
||||
"for OpenVINO backend is not set, using 4 by default.")
|
||||
else:
|
||||
config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
|
||||
f" {kv_cache_space}, expect a positive integer value.")
|
||||
|
||||
return config
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import openvino as ov
|
||||
import openvino.properties.hint as hints
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -49,6 +51,8 @@ class OpenVinoPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
assert (
|
||||
parallel_config.world_size == 1
|
||||
@ -57,3 +61,68 @@ class OpenVinoPlatform(Platform):
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.openvino_worker.OpenVINOWorker"
|
||||
|
||||
# check and update model config
|
||||
model_config = vllm_config.model_config
|
||||
if model_config.dtype != torch.float32:
|
||||
logger.warning(
|
||||
f"Only float32 dtype is supported on OpenVINO, casting from {model_config.dtype}." # noqa: G004, E501
|
||||
)
|
||||
model_config.dtype = torch.float32
|
||||
if not model_config.enforce_eager:
|
||||
logger.warning(
|
||||
"CUDA graph is not supported on OpenVINO backend, fallback to "
|
||||
"the eager mode.")
|
||||
model_config.enforce_eager = True
|
||||
|
||||
# check and update cache config
|
||||
ov_core = ov.Core()
|
||||
cache_config = vllm_config.cache_config
|
||||
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
|
||||
if not OpenVinoPlatform.is_openvino_cpu():
|
||||
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
|
||||
"ignored for GPU, f16 data type will be used.")
|
||||
cache_config.cache_dtype = ov.Type.f16
|
||||
else:
|
||||
logger.info("KV cache type is overridden to u8 via "
|
||||
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
|
||||
cache_config.cache_dtype = ov.Type.u8
|
||||
else:
|
||||
if OpenVinoPlatform.is_openvino_cpu():
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
inference_precision = ov_core.get_property(
|
||||
ov_device, hints.inference_precision)
|
||||
if inference_precision == ov.Type.bf16:
|
||||
cache_config.cache_dtype = ov.Type.bf16
|
||||
else:
|
||||
cache_config.cache_dtype = ov.Type.f16
|
||||
else:
|
||||
cache_config.cache_dtype = ov.Type.f16
|
||||
|
||||
if OpenVinoPlatform.is_openvino_cpu():
|
||||
if cache_config.block_size != 32:
|
||||
logger.info(
|
||||
f"OpenVINO CPU optimal block size is 32, overriding currently set {cache_config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
cache_config.block_size = 32
|
||||
else:
|
||||
if cache_config.block_size != 16:
|
||||
logger.info(
|
||||
f"OpenVINO GPU optimal block size is 16, overriding currently set {cache_config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
cache_config.block_size = 16
|
||||
|
||||
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
|
||||
if kv_cache_space >= 0:
|
||||
if kv_cache_space == 0 and OpenVinoPlatform.is_openvino_cpu():
|
||||
cache_config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
|
||||
logger.warning(
|
||||
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
|
||||
"for OpenVINO backend is not set, using 4 by default.")
|
||||
else:
|
||||
cache_config.openvino_kvcache_space_bytes = ( # type: ignore
|
||||
kv_cache_space * GiB_bytes)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
|
||||
f" {kv_cache_space}, expect a positive integer value.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user