[Hardware][openvino] is_openvino --> current_platform.is_openvino (#9716)

Mengqing Cao 2024-10-26 18:59:06 +08:00 committed by GitHub
parent 067e77f9a8
commit 5cbdccd151
10 changed files with 69 additions and 38 deletions
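The same pattern repeats across all ten files: the standalone helpers (`is_openvino()` from `vllm.utils`, `is_openvino_cpu()`/`is_openvino_gpu()` from `vllm.executor.openvino_executor`) become methods on the shared `current_platform` object, matching the existing `is_cpu()`/`is_neuron()`/`is_tpu()` checks. Schematically:

    # Before: platform checks scattered as free functions.
    from vllm.utils import is_openvino
    from vllm.executor.openvino_executor import is_openvino_cpu

    if is_openvino() and is_openvino_cpu():
        ...

    # After: one shared object answers every platform question.
    from vllm.platforms import current_platform

    if current_platform.is_openvino() and current_platform.is_openvino_cpu():
        ...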

tests/kernels/test_attention_selector.py (View File)

@@ -30,7 +30,8 @@ def test_env(name: str, device: str, monkeypatch):
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
-        with patch("vllm.attention.selector.is_openvino", return_value=True):
+        with patch("vllm.attention.selector.current_platform.is_openvino",
+                   return_value=True):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
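Note on the patch target: `unittest.mock.patch` walks the dotted path and replaces the final attribute on whatever object it reaches. Because the selector module now does `from vllm.platforms import current_platform`, the test must patch `is_openvino` on that imported object rather than a module-level function. A self-contained sketch of the mechanism (the `SimpleNamespace` is a stand-in for the real platform object):

    from types import SimpleNamespace
    from unittest.mock import patch

    # Stand-in for the platform object a module would import.
    current_platform = SimpleNamespace(is_openvino=lambda: False)

    # patch() resolves "__main__.current_platform", then swaps its
    # is_openvino attribute for the duration of the with-block.
    with patch("__main__.current_platform.is_openvino", return_value=True):
        assert current_platform.is_openvino()   # patched
    assert not current_platform.is_openvino()   # restored on exit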

vllm/attention/selector.py (View File)

@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino
+from vllm.utils import STR_BACKEND_ENV_VAR, is_hip

 logger = init_logger(__name__)
@@ -193,7 +193,7 @@ def which_attn_to_use(
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA

-    if is_openvino():
+    if current_platform.is_openvino():
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
         return _Backend.OPENVINO

vllm/config.py (View File)

@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_openvino, print_warning_once)
+                        is_hip, print_warning_once)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -1117,7 +1117,7 @@ class DeviceConfig:
                 self.device_type = "cuda"
             elif current_platform.is_neuron():
                 self.device_type = "neuron"
-            elif is_openvino():
+            elif current_platform.is_openvino():
                 self.device_type = "openvino"
             elif current_platform.is_tpu():
                 self.device_type = "tpu"

vllm/executor/openvino_executor.py (View File)

@@ -10,6 +10,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
                         get_open_port, make_async)
@@ -17,14 +18,6 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,

 logger = init_logger(__name__)

-
-def is_openvino_cpu() -> bool:
-    return "CPU" in envs.VLLM_OPENVINO_DEVICE
-
-
-def is_openvino_gpu() -> bool:
-    return "GPU" in envs.VLLM_OPENVINO_DEVICE
-
 class OpenVINOExecutor(ExecutorBase):

     uses_ray: bool = False
@@ -32,7 +25,8 @@ class OpenVINOExecutor(ExecutorBase):
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "openvino"
         assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
-        assert is_openvino_cpu() or is_openvino_gpu(), \
+        assert current_platform.is_openvino_cpu() or \
+            current_platform.is_openvino_gpu(), \
             "OpenVINO backend supports only CPU and GPU devices"

         self.ov_core = ov.Core()
@@ -163,7 +157,7 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
 def _verify_and_get_cache_config(ov_core: ov.Core,
                                  config: CacheConfig) -> CacheConfig:
     if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
-        if not is_openvino_cpu():
+        if not current_platform.is_openvino_cpu():
             logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
                         "ignored for GPU, f16 data type will be used.")
             config.cache_dtype = ov.Type.f16
@@ -172,7 +166,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
                 "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
             config.cache_dtype = ov.Type.u8
     else:
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             ov_device = envs.VLLM_OPENVINO_DEVICE
             inference_precision = ov_core.get_property(
                 ov_device, hints.inference_precision)
@@ -183,7 +177,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
         else:
             config.cache_dtype = ov.Type.f16

-    if is_openvino_cpu():
+    if current_platform.is_openvino_cpu():
         if config.block_size != 32:
             logger.info(
                 f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
@@ -198,7 +192,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
     kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE

     if kv_cache_space >= 0:
-        if kv_cache_space == 0 and is_openvino_cpu():
+        if kv_cache_space == 0 and current_platform.is_openvino_cpu():
             config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
             logger.warning(
                 "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "

vllm/model_executor/model_loader/openvino.py (View File)

@@ -12,12 +12,12 @@ from torch import nn
 import vllm.envs as envs
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import DeviceConfig, ModelConfig
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -136,7 +136,7 @@ class OpenVINOCasualLM(nn.Module):
         ov_device = envs.VLLM_OPENVINO_DEVICE
         paged_attention_transformation(pt_model.model)
         _modify_cache_parameters(pt_model.model, kv_cache_dtype,
-                                 is_openvino_cpu())
+                                 current_platform.is_openvino_cpu())

         ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
         self.ov_request = ov_compiled.create_infer_request()

vllm/platforms/__init__.py (View File)

@@ -65,6 +65,13 @@ try:
 except ImportError:
     pass

+is_openvino = False
+try:
+    from importlib.metadata import version
+    is_openvino = "openvino" in version("vllm")
+except Exception:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -85,6 +92,9 @@ elif is_cpu:
 elif is_neuron:
     from .neuron import NeuronPlatform
     current_platform = NeuronPlatform()
+elif is_openvino:
+    from .openvino import OpenVinoPlatform
+    current_platform = OpenVinoPlatform()
 else:
     current_platform = UnspecifiedPlatform()
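Platform detection keys off the installed package's version string: OpenVINO builds of vLLM are expected to carry an "openvino" marker in their version metadata, so `importlib.metadata.version("vllm")` identifies the platform without importing OpenVINO itself. The added block mirrors the helper it replaces in vllm/utils.py; a minimal sketch of the same check (version strings are illustrative):

    from importlib.metadata import PackageNotFoundError, version

    def detect_openvino_build(dist: str = "vllm") -> bool:
        # e.g. "0.6.3" -> False, "0.6.3+openvino" -> True
        try:
            return "openvino" in version(dist)
        except PackageNotFoundError:
            # Distribution not installed, e.g. running from a source tree.
            return False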

vllm/platforms/interface.py (View File)

@@ -11,6 +11,7 @@ class PlatformEnum(enum.Enum):
     XPU = enum.auto()
     CPU = enum.auto()
     NEURON = enum.auto()
+    OPENVINO = enum.auto()
     UNSPECIFIED = enum.auto()
@@ -52,6 +53,9 @@ class Platform:
     def is_neuron(self) -> bool:
         return self._enum == PlatformEnum.NEURON

+    def is_openvino(self) -> bool:
+        return self._enum == PlatformEnum.OPENVINO
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
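Because the base class implements each `is_*` predicate as a comparison against the subclass's `_enum`, every platform that is not OpenVINO answers `False` with no per-platform override. A self-contained sketch of the pattern (simplified stand-ins for the real classes):

    import enum

    class PlatformEnum(enum.Enum):
        OPENVINO = enum.auto()
        UNSPECIFIED = enum.auto()

    class Platform:
        _enum: PlatformEnum

        def is_openvino(self) -> bool:
            # One predicate on the base class covers every subclass.
            return self._enum == PlatformEnum.OPENVINO

    class OpenVinoPlatform(Platform):
        _enum = PlatformEnum.OPENVINO

    class UnspecifiedPlatform(Platform):
        _enum = PlatformEnum.UNSPECIFIED

    assert OpenVinoPlatform().is_openvino()
    assert not UnspecifiedPlatform().is_openvino()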

vllm/platforms/openvino.py (View File)

@@ -0,0 +1,31 @@
+import torch
+
+import vllm.envs as envs
+from vllm.utils import print_warning_once
+
+from .interface import Platform, PlatformEnum
+
+
+class OpenVinoPlatform(Platform):
+    _enum = PlatformEnum.OPENVINO
+
+    @classmethod
+    def get_device_name(self, device_id: int = 0) -> str:
+        return "openvino"
+
+    @classmethod
+    def inference_mode(self):
+        return torch.inference_mode(mode=True)
+
+    @classmethod
+    def is_openvino_cpu(self) -> bool:
+        return "CPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_openvino_gpu(self) -> bool:
+        return "GPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_pin_memory_available(self) -> bool:
+        print_warning_once("Pin memory is not supported on OpenVINO.")
+        return False
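With the new platform class in place, call sites everywhere go through the shared `current_platform` object. A usage sketch, assuming an OpenVINO build of vLLM and `VLLM_OPENVINO_DEVICE=CPU`:

    import os

    # Must be set before vllm reads it; "CPU" is one example value.
    os.environ["VLLM_OPENVINO_DEVICE"] = "CPU"

    from vllm.platforms import current_platform

    assert current_platform.is_openvino()
    assert current_platform.is_openvino_cpu()
    assert not current_platform.is_openvino_gpu()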

vllm/utils.py (View File)

@@ -318,15 +318,6 @@ def is_hip() -> bool:
     return torch.version.hip is not None


-@lru_cache(maxsize=None)
-def is_openvino() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        return "openvino" in version("vllm")
-    except PackageNotFoundError:
-        return False
-
-
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -757,7 +748,7 @@ def is_pin_memory_available() -> bool:
     elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif current_platform.is_cpu() or is_openvino():
+    elif current_platform.is_cpu() or current_platform.is_openvino():
         return False
     return True
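The `lru_cache` on the removed helper existed to memoize a package-metadata probe on every call; its replacement needs no cache, since `current_platform` is constructed once at import time and `is_openvino()` reduces to a constant enum comparison.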

vllm/worker/openvino_worker.py (View File)

@@ -13,12 +13,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
@@ -99,7 +99,7 @@ class OpenVINOCacheEngine:
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]

         kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for _ in range(self.num_layers):
                 key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                        k_block_shape)
@@ -141,7 +141,7 @@ class OpenVINOCacheEngine:
         if num_blocks == 0:
             return swap_cache

-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to have swap cache"

         # Update key_cache shape:
@@ -285,7 +285,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         cache_block_size = self.get_cache_block_size_bytes()
         kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             num_device_blocks = int(kvcache_space_bytes // cache_block_size)
             num_swap_blocks = 0
         else:
@@ -322,7 +322,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         num_device_blocks = num_gpu_blocks
         num_swap_blocks = num_cpu_blocks

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert (num_swap_blocks == 0
                     ), f"{type(self)} does not support swappable cache for CPU"
@@ -366,7 +366,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         assert self.kv_cache is not None

         # Populate the cache to warmup the memory
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for key_cache, value_cache in self.kv_cache:
                 key_cache.data[:] = 0
                 value_cache.data[:] = 0
@@ -414,7 +414,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         blocks_to_swap_in = data["blocks_to_swap_in"]
         blocks_to_swap_out = data["blocks_to_swap_out"]

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert len(execute_model_req.blocks_to_swap_in) == 0
             assert len(execute_model_req.blocks_to_swap_out) == 0
         else:
@@ -466,7 +466,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     def profile_run(self) -> int:
         ov_device = envs.VLLM_OPENVINO_DEVICE

-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to use profile run."

         import openvino.properties.device as device