[Hardware][openvino] is_openvino --> current_platform.is_openvino (#9716)

Mengqing Cao 2024-10-26 18:59:06 +08:00 committed by GitHub
parent 067e77f9a8
commit 5cbdccd151
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 69 additions and 38 deletions

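In short: the ad-hoc helper vllm.utils.is_openvino() (and the executor-local is_openvino_cpu()/is_openvino_gpu()) move onto the shared platform object, so every device query goes through one interface. A minimal before/after sketch of the call sites touched below:

    # Before: free functions scattered across vllm.utils and the executor.
    from vllm.utils import is_openvino

    if is_openvino():
        ...

    # After: the single current_platform object answers all device queries.
    from vllm.platforms import current_platform

    if current_platform.is_openvino():
        ...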
View File

@@ -30,7 +30,8 @@ def test_env(name: str, device: str, monkeypatch):
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
-        with patch("vllm.attention.selector.is_openvino", return_value=True):
+        with patch("vllm.attention.selector.current_platform.is_openvino",
+                   return_value=True):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"

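Note for reviewers: unittest.mock.patch resolves its target where the object is looked up, so the test now patches the is_openvino attribute of the current_platform object as seen from the selector module; every module holding a reference to that object sees the mock. A small sketch, assuming an installed vLLM:

    from unittest.mock import patch

    from vllm.platforms import current_platform

    # Patching the attribute on the shared platform object: the selector's
    # reference and this one are the same instance, so both see the mock.
    with patch("vllm.attention.selector.current_platform.is_openvino",
               return_value=True):
        assert current_platform.is_openvino()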
View File

@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino
+from vllm.utils import STR_BACKEND_ENV_VAR, is_hip

 logger = init_logger(__name__)
@@ -193,7 +193,7 @@ def which_attn_to_use(
         logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA

-    if is_openvino():
+    if current_platform.is_openvino():
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
         return _Backend.OPENVINO

View File

@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_openvino, print_warning_once)
+                        is_hip, print_warning_once)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -1117,7 +1117,7 @@ class DeviceConfig:
                 self.device_type = "cuda"
             elif current_platform.is_neuron():
                 self.device_type = "neuron"
-            elif is_openvino():
+            elif current_platform.is_openvino():
                 self.device_type = "openvino"
             elif current_platform.is_tpu():
                 self.device_type = "tpu"

View File

@@ -10,6 +10,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
                         get_open_port, make_async)
@@ -17,14 +18,6 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,

 logger = init_logger(__name__)

-
-def is_openvino_cpu() -> bool:
-    return "CPU" in envs.VLLM_OPENVINO_DEVICE
-
-
-def is_openvino_gpu() -> bool:
-    return "GPU" in envs.VLLM_OPENVINO_DEVICE
-

 class OpenVINOExecutor(ExecutorBase):

     uses_ray: bool = False
@@ -32,7 +25,8 @@ class OpenVINOExecutor(ExecutorBase):
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "openvino"
         assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
-        assert is_openvino_cpu() or is_openvino_gpu(), \
+        assert current_platform.is_openvino_cpu() or \
+            current_platform.is_openvino_gpu(), \
             "OpenVINO backend supports only CPU and GPU devices"

         self.ov_core = ov.Core()
@@ -163,7 +157,7 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
 def _verify_and_get_cache_config(ov_core: ov.Core,
                                  config: CacheConfig) -> CacheConfig:
     if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
-        if not is_openvino_cpu():
+        if not current_platform.is_openvino_cpu():
             logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
                         "ignored for GPU, f16 data type will be used.")
             config.cache_dtype = ov.Type.f16
@@ -172,7 +166,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
                 "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
             config.cache_dtype = ov.Type.u8
     else:
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             ov_device = envs.VLLM_OPENVINO_DEVICE
             inference_precision = ov_core.get_property(
                 ov_device, hints.inference_precision)
@@ -183,7 +177,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
         else:
             config.cache_dtype = ov.Type.f16

-    if is_openvino_cpu():
+    if current_platform.is_openvino_cpu():
         if config.block_size != 32:
             logger.info(
                 f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
@@ -198,7 +192,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
     kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE

     if kv_cache_space >= 0:
-        if kv_cache_space == 0 and is_openvino_cpu():
+        if kv_cache_space == 0 and current_platform.is_openvino_cpu():
             config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
             logger.warning(
                 "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "

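Worth noting: is_openvino_cpu() and is_openvino_gpu() are plain substring tests on VLLM_OPENVINO_DEVICE, so a device string like "GPU.1" selects the GPU path. A hedged sketch, assuming an OpenVINO build of vLLM (otherwise current_platform lacks these methods):

    import os

    # vllm.envs reads the variable at access time, so set it before use.
    os.environ["VLLM_OPENVINO_DEVICE"] = "GPU.1"

    from vllm.platforms import current_platform

    assert current_platform.is_openvino_gpu()      # "GPU" in "GPU.1"
    assert not current_platform.is_openvino_cpu()  # no "CPU" substring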
View File

@@ -12,12 +12,12 @@ from torch import nn
 import vllm.envs as envs
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import DeviceConfig, ModelConfig
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -136,7 +136,7 @@ class OpenVINOCasualLM(nn.Module):
         ov_device = envs.VLLM_OPENVINO_DEVICE
         paged_attention_transformation(pt_model.model)
         _modify_cache_parameters(pt_model.model, kv_cache_dtype,
-                                 is_openvino_cpu())
+                                 current_platform.is_openvino_cpu())

         ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
         self.ov_request = ov_compiled.create_infer_request()

View File

@@ -65,6 +65,13 @@ try:
 except ImportError:
     pass

+is_openvino = False
+try:
+    from importlib.metadata import version
+    is_openvino = "openvino" in version("vllm")
+except Exception:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -85,6 +92,9 @@ elif is_cpu:
 elif is_neuron:
     from .neuron import NeuronPlatform
     current_platform = NeuronPlatform()
+elif is_openvino:
+    from .openvino import OpenVinoPlatform
+    current_platform = OpenVinoPlatform()
 else:
     current_platform = UnspecifiedPlatform()

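The registry detects an OpenVINO build from package metadata rather than by probing imports: OpenVINO wheels of vLLM carry "openvino" in their version string. A standalone sketch of the same check (the example version shown in the comment is illustrative):

    from importlib.metadata import PackageNotFoundError, version

    def looks_like_openvino_build(dist: str = "vllm") -> bool:
        # e.g. a version such as "0.6.3+openvino" makes this True.
        try:
            return "openvino" in version(dist)
        except PackageNotFoundError:
            return False

    print(looks_like_openvino_build())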
View File

@@ -11,6 +11,7 @@ class PlatformEnum(enum.Enum):
     XPU = enum.auto()
     CPU = enum.auto()
     NEURON = enum.auto()
+    OPENVINO = enum.auto()
     UNSPECIFIED = enum.auto()
@@ -52,6 +53,9 @@ class Platform:
     def is_neuron(self) -> bool:
         return self._enum == PlatformEnum.NEURON

+    def is_openvino(self) -> bool:
+        return self._enum == PlatformEnum.OPENVINO
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

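Each is_*() query reduces to comparing the _enum tag, so adding a platform costs one enum member plus one predicate. A self-contained sketch of the pattern with stripped-down names (not the full vLLM classes):

    import enum

    class PlatformEnum(enum.Enum):
        OPENVINO = enum.auto()
        UNSPECIFIED = enum.auto()

    class Platform:
        _enum: PlatformEnum

        def is_openvino(self) -> bool:
            return self._enum == PlatformEnum.OPENVINO

    class OpenVinoPlatform(Platform):
        _enum = PlatformEnum.OPENVINO

    assert OpenVinoPlatform().is_openvino()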
View File

@@ -0,0 +1,31 @@
+import torch
+
+import vllm.envs as envs
+from vllm.utils import print_warning_once
+
+from .interface import Platform, PlatformEnum
+
+
+class OpenVinoPlatform(Platform):
+    _enum = PlatformEnum.OPENVINO
+
+    @classmethod
+    def get_device_name(self, device_id: int = 0) -> str:
+        return "openvino"
+
+    @classmethod
+    def inference_mode(self):
+        return torch.inference_mode(mode=True)
+
+    @classmethod
+    def is_openvino_cpu(self) -> bool:
+        return "CPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_openvino_gpu(self) -> bool:
+        return "GPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_pin_memory_available(self) -> bool:
+        print_warning_once("Pin memory is not supported on OpenViNO.")
+        return False

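One quirk worth noting: the new methods are declared @classmethod but keep self as the first parameter name, so they are callable on both the class and the current_platform instance. A hedged usage sketch, assuming an OpenVINO build:

    from vllm.platforms import current_platform

    print(current_platform.get_device_name())          # "openvino"
    with current_platform.inference_mode():
        pass                                           # torch.inference_mode scope
    print(current_platform.is_pin_memory_available())  # warns once, returns False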
View File

@@ -318,15 +318,6 @@ def is_hip() -> bool:
     return torch.version.hip is not None


-@lru_cache(maxsize=None)
-def is_openvino() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        return "openvino" in version("vllm")
-    except PackageNotFoundError:
-        return False
-
-
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -757,7 +748,7 @@ def is_pin_memory_available() -> bool:
     elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif current_platform.is_cpu() or is_openvino():
+    elif current_platform.is_cpu() or current_platform.is_openvino():
         return False
     return True

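The removed helper memoized the metadata lookup with @lru_cache; the platform object makes that cache redundant, since detection runs once when vllm.platforms is first imported and is_openvino() is then a constant-time enum comparison:

    from vllm.platforms import current_platform

    # No per-call importlib.metadata lookup anymore; the platform was
    # resolved once at vllm.platforms import time.
    current_platform.is_openvino()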
View File

@@ -13,12 +13,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
@@ -99,7 +99,7 @@ class OpenVINOCacheEngine:
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]

         kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for _ in range(self.num_layers):
                 key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                        k_block_shape)
@@ -141,7 +141,7 @@ class OpenVINOCacheEngine:
         if num_blocks == 0:
             return swap_cache

-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
            "CPU device isn't supposed to have swap cache"

         # Update key_cache shape:
@@ -285,7 +285,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         cache_block_size = self.get_cache_block_size_bytes()
         kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             num_device_blocks = int(kvcache_space_bytes // cache_block_size)
             num_swap_blocks = 0
         else:
@@ -322,7 +322,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         num_device_blocks = num_gpu_blocks
         num_swap_blocks = num_cpu_blocks

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert (num_swap_blocks == 0
                     ), f"{type(self)} does not support swappable cache for CPU"
@@ -366,7 +366,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         assert self.kv_cache is not None

         # Populate the cache to warmup the memory
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for key_cache, value_cache in self.kv_cache:
                 key_cache.data[:] = 0
                 value_cache.data[:] = 0
@@ -414,7 +414,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         blocks_to_swap_in = data["blocks_to_swap_in"]
         blocks_to_swap_out = data["blocks_to_swap_out"]

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert len(execute_model_req.blocks_to_swap_in) == 0
             assert len(execute_model_req.blocks_to_swap_out) == 0
         else:
@@ -466,7 +466,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     def profile_run(self) -> int:
         ov_device = envs.VLLM_OPENVINO_DEVICE

-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to use profile run."

         import openvino.properties.device as device