diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 8bcee9840377..df3e770e260e 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -30,7 +30,8 @@ def test_env(name: str, device: str, monkeypatch):
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
-        with patch("vllm.attention.selector.is_openvino", return_value=True):
+        with patch("vllm.attention.selector.current_platform.is_openvino",
+                   return_value=True):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index cd3c642b8c8a..10d4509b3827 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino
+from vllm.utils import STR_BACKEND_ENV_VAR, is_hip
 
 logger = init_logger(__name__)
 
@@ -193,7 +193,7 @@ def which_attn_to_use(
         logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA
 
-    if is_openvino():
+    if current_platform.is_openvino():
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
         return _Backend.OPENVINO
diff --git a/vllm/config.py b/vllm/config.py
index 25f841231ded..a1fba98233b8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_openvino, print_warning_once)
+                        is_hip, print_warning_once)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -1117,7 +1117,7 @@ class DeviceConfig:
                 self.device_type = "cuda"
             elif current_platform.is_neuron():
                 self.device_type = "neuron"
-            elif is_openvino():
+            elif current_platform.is_openvino():
                 self.device_type = "openvino"
             elif current_platform.is_tpu():
                 self.device_type = "tpu"
diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py
index 4a39839a0319..d0c0333854da 100644
--- a/vllm/executor/openvino_executor.py
+++ b/vllm/executor/openvino_executor.py
@@ -10,6 +10,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
                         get_open_port, make_async)
@@ -17,14 +18,6 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
 logger = init_logger(__name__)
 
 
-def is_openvino_cpu() -> bool:
-    return "CPU" in envs.VLLM_OPENVINO_DEVICE
-
-
-def is_openvino_gpu() -> bool:
-    return "GPU" in envs.VLLM_OPENVINO_DEVICE
-
-
 class OpenVINOExecutor(ExecutorBase):
 
     uses_ray: bool = False
@@ -32,7 +25,8 @@ class OpenVINOExecutor(ExecutorBase):
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "openvino"
         assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
-        assert is_openvino_cpu() or is_openvino_gpu(), \
+        assert current_platform.is_openvino_cpu() or \
+            current_platform.is_openvino_gpu(), \
             "OpenVINO backend supports only CPU and GPU devices"
 
         self.ov_core = ov.Core()
@@ -163,7 +157,7 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
 def _verify_and_get_cache_config(ov_core: ov.Core,
                                  config: CacheConfig) -> CacheConfig:
     if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
-        if not is_openvino_cpu():
+        if not current_platform.is_openvino_cpu():
             logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
                         "ignored for GPU, f16 data type will be used.")
             config.cache_dtype = ov.Type.f16
@@ -172,7 +166,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
                         "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
             config.cache_dtype = ov.Type.u8
     else:
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             ov_device = envs.VLLM_OPENVINO_DEVICE
             inference_precision = ov_core.get_property(
                 ov_device, hints.inference_precision)
@@ -183,7 +177,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
         else:
             config.cache_dtype = ov.Type.f16
 
-    if is_openvino_cpu():
+    if current_platform.is_openvino_cpu():
         if config.block_size != 32:
             logger.info(
                 f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
@@ -198,7 +192,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
 
     kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
     if kv_cache_space >= 0:
-        if kv_cache_space == 0 and is_openvino_cpu():
+        if kv_cache_space == 0 and current_platform.is_openvino_cpu():
             config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
             logger.warning(
                 "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py
index 88b7ac46e554..8ada2210d0d5 100644
--- a/vllm/model_executor/model_loader/openvino.py
+++ b/vllm/model_executor/model_loader/openvino.py
@@ -12,12 +12,12 @@ from torch import nn
 import vllm.envs as envs
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import DeviceConfig, ModelConfig
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -136,7 +136,7 @@ class OpenVINOCasualLM(nn.Module):
         ov_device = envs.VLLM_OPENVINO_DEVICE
         paged_attention_transformation(pt_model.model)
         _modify_cache_parameters(pt_model.model, kv_cache_dtype,
-                                 is_openvino_cpu())
+                                 current_platform.is_openvino_cpu())
 
         ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
         self.ov_request = ov_compiled.create_infer_request()
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 58912158139b..7e9f8b1297b8 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -65,6 +65,13 @@ try:
 except ImportError:
     pass
 
+is_openvino = False
+try:
+    from importlib.metadata import version
+    is_openvino = "openvino" in version("vllm")
+except Exception:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -85,6 +92,9 @@ elif is_cpu:
 elif is_neuron:
     from .neuron import NeuronPlatform
     current_platform = NeuronPlatform()
+elif is_openvino:
+    from .openvino import OpenVinoPlatform
+    current_platform = OpenVinoPlatform()
 else:
     current_platform = UnspecifiedPlatform()
 
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index d36367f2bc9c..7c933385d6ff 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -11,6 +11,7 @@ class PlatformEnum(enum.Enum):
     XPU = enum.auto()
     CPU = enum.auto()
     NEURON = enum.auto()
+    OPENVINO = enum.auto()
     UNSPECIFIED = enum.auto()
 
 
@@ -52,6 +53,9 @@ class Platform:
     def is_neuron(self) -> bool:
         return self._enum == PlatformEnum.NEURON
 
+    def is_openvino(self) -> bool:
+        return self._enum == PlatformEnum.OPENVINO
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
new file mode 100644
index 000000000000..35dbe22abf7f
--- /dev/null
+++ b/vllm/platforms/openvino.py
@@ -0,0 +1,31 @@
+import torch
+
+import vllm.envs as envs
+from vllm.utils import print_warning_once
+
+from .interface import Platform, PlatformEnum
+
+
+class OpenVinoPlatform(Platform):
+    _enum = PlatformEnum.OPENVINO
+
+    @classmethod
+    def get_device_name(self, device_id: int = 0) -> str:
+        return "openvino"
+
+    @classmethod
+    def inference_mode(self):
+        return torch.inference_mode(mode=True)
+
+    @classmethod
+    def is_openvino_cpu(self) -> bool:
+        return "CPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_openvino_gpu(self) -> bool:
+        return "GPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_pin_memory_available(self) -> bool:
+        print_warning_once("Pin memory is not supported on OpenViNO.")
+        return False
diff --git a/vllm/utils.py b/vllm/utils.py
index 0e9b241b6f9f..fba9804289b9 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -318,15 +318,6 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
-@lru_cache(maxsize=None)
-def is_openvino() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        return "openvino" in version("vllm")
-    except PackageNotFoundError:
-        return False
-
-
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -757,7 +748,7 @@ def is_pin_memory_available() -> bool:
     elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif current_platform.is_cpu() or is_openvino():
+    elif current_platform.is_cpu() or current_platform.is_openvino():
         return False
 
     return True
diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py
index bc245d19663d..a420d390c1ae 100644
--- a/vllm/worker/openvino_worker.py
+++ b/vllm/worker/openvino_worker.py
@@ -13,12 +13,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
@@ -99,7 +99,7 @@ class OpenVINOCacheEngine:
             num_blocks, self.block_size, self.num_kv_heads,
             self.head_size)[1:]
         kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for _ in range(self.num_layers):
                 key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                        k_block_shape)
@@ -141,7 +141,7 @@ class OpenVINOCacheEngine:
         if num_blocks == 0:
             return swap_cache
 
-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to have swap cache"
 
         # Update key_cache shape:
@@ -285,7 +285,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         cache_block_size = self.get_cache_block_size_bytes()
         kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes
 
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
            num_device_blocks = int(kvcache_space_bytes // cache_block_size)
            num_swap_blocks = 0
         else:
@@ -322,7 +322,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         num_device_blocks = num_gpu_blocks
         num_swap_blocks = num_cpu_blocks
 
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert (num_swap_blocks == 0
                     ), f"{type(self)} does not support swappable cache for CPU"
 
@@ -366,7 +366,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         assert self.kv_cache is not None
 
         # Populate the cache to warmup the memory
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for key_cache, value_cache in self.kv_cache:
                 key_cache.data[:] = 0
                 value_cache.data[:] = 0
@@ -414,7 +414,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
             blocks_to_swap_in = data["blocks_to_swap_in"]
             blocks_to_swap_out = data["blocks_to_swap_out"]
 
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert len(execute_model_req.blocks_to_swap_in) == 0
             assert len(execute_model_req.blocks_to_swap_out) == 0
         else:
@@ -466,7 +466,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     def profile_run(self) -> int:
         ov_device = envs.VLLM_OPENVINO_DEVICE
 
-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to use profile run."
 
         import openvino.properties.device as device