[Platforms] Refactor openvino code (#10573)

Signed-off-by: statelesshz <hzji210@gmail.com>
This commit is contained in:
JiHuazhong 2024-11-23 14:23:12 +08:00 committed by GitHub
parent 4cfe5d2bca
commit 86a44fb896
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 78 deletions

View File

@ -1,19 +1,16 @@
from typing import List, Set, Tuple
import openvino as ov
import openvino.properties.hint as hints
import torch
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
get_open_port, make_async)
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -30,11 +27,6 @@ class OpenVINOExecutor(ExecutorBase):
current_platform.is_openvino_gpu(), \
"OpenVINO backend supports only CPU and GPU devices"
self.ov_core = ov.Core()
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(
self.ov_core, self.cache_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
@ -45,7 +37,7 @@ class OpenVINOExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = wrapper.init_worker(
ov_core=self.ov_core,
ov_core=ov.Core(),
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
@ -130,70 +122,3 @@ class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
# OpenVINOExecutor will always be healthy as long as
# it's running.
return
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype != torch.float32:
logger.warning(
f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501
)
config.dtype = torch.float32
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on OpenVINO backend, fallback to the "
"eager mode.")
config.enforce_eager = True
return config
def _verify_and_get_cache_config(ov_core: ov.Core,
config: CacheConfig) -> CacheConfig:
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
if not current_platform.is_openvino_cpu():
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
"ignored for GPU, f16 data type will be used.")
config.cache_dtype = ov.Type.f16
else:
logger.info("KV cache type is overridden to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
else:
if current_platform.is_openvino_cpu():
ov_device = envs.VLLM_OPENVINO_DEVICE
inference_precision = ov_core.get_property(
ov_device, hints.inference_precision)
if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
else:
config.cache_dtype = ov.Type.f16
else:
config.cache_dtype = ov.Type.f16
if current_platform.is_openvino_cpu():
if config.block_size != 32:
logger.info(
f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 32
else:
if config.block_size != 16:
logger.info(
f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 16
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0:
if kv_cache_space == 0 and current_platform.is_openvino_cpu():
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default.")
else:
config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
return config

View File

@ -1,5 +1,7 @@
from typing import TYPE_CHECKING
import openvino as ov
import openvino.properties.hint as hints
import torch
import vllm.envs as envs
@ -49,6 +51,8 @@ class OpenVinoPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.utils import GiB_bytes
parallel_config = vllm_config.parallel_config
assert (
parallel_config.world_size == 1
@ -57,3 +61,68 @@ class OpenVinoPlatform(Platform):
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
"vllm.worker.openvino_worker.OpenVINOWorker"
# check and update model config
model_config = vllm_config.model_config
if model_config.dtype != torch.float32:
logger.warning(
f"Only float32 dtype is supported on OpenVINO, casting from {model_config.dtype}." # noqa: G004, E501
)
model_config.dtype = torch.float32
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on OpenVINO backend, fallback to "
"the eager mode.")
model_config.enforce_eager = True
# check and update cache config
ov_core = ov.Core()
cache_config = vllm_config.cache_config
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
if not OpenVinoPlatform.is_openvino_cpu():
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
"ignored for GPU, f16 data type will be used.")
cache_config.cache_dtype = ov.Type.f16
else:
logger.info("KV cache type is overridden to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
cache_config.cache_dtype = ov.Type.u8
else:
if OpenVinoPlatform.is_openvino_cpu():
ov_device = envs.VLLM_OPENVINO_DEVICE
inference_precision = ov_core.get_property(
ov_device, hints.inference_precision)
if inference_precision == ov.Type.bf16:
cache_config.cache_dtype = ov.Type.bf16
else:
cache_config.cache_dtype = ov.Type.f16
else:
cache_config.cache_dtype = ov.Type.f16
if OpenVinoPlatform.is_openvino_cpu():
if cache_config.block_size != 32:
logger.info(
f"OpenVINO CPU optimal block size is 32, overriding currently set {cache_config.block_size}" # noqa: G004, E501
)
cache_config.block_size = 32
else:
if cache_config.block_size != 16:
logger.info(
f"OpenVINO GPU optimal block size is 16, overriding currently set {cache_config.block_size}" # noqa: G004, E501
)
cache_config.block_size = 16
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0:
if kv_cache_space == 0 and OpenVinoPlatform.is_openvino_cpu():
cache_config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default.")
else:
cache_config.openvino_kvcache_space_bytes = ( # type: ignore
kv_cache_space * GiB_bytes)
else:
raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")