[Hardware][openvino] is_openvino --> current_platform.is_openvino (#9716)

This commit is contained in:
  parent 067e77f9a8
  commit 5cbdccd151
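In short: the module-level helper vllm.utils.is_openvino() is deleted, and the check becomes a method on the platform singleton exposed by vllm.platforms, alongside new current_platform.is_openvino_cpu() / is_openvino_gpu() helpers that replace the free functions previously defined in the OpenVINO executor. A minimal before/after sketch of the call-site migration (both spellings are taken from the hunks below):

    # before: free function in vllm.utils
    from vllm.utils import is_openvino

    if is_openvino():
        ...

    # after: method on the platform singleton
    from vllm.platforms import current_platform

    if current_platform.is_openvino():
        ...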
@@ -30,7 +30,8 @@ def test_env(name: str, device: str, monkeypatch):
                                         False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
-        with patch("vllm.attention.selector.is_openvino", return_value=True):
+        with patch("vllm.attention.selector.current_platform.is_openvino",
+                   return_value=True):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
         assert backend.name == "OPENVINO"
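The patch target changes because the predicate is no longer a name defined in the selector module but an attribute of the current_platform object imported there; unittest.mock.patch resolves the extra dotted segment with getattr. A self-contained sketch of the same pattern, using toy names rather than vLLM's:

    from unittest.mock import patch


    class FakePlatform:

        def is_openvino(self) -> bool:
            return False


    current_platform = FakePlatform()

    # Resolves "<module>.current_platform", then patches its is_openvino
    # attribute for the duration of the with-block and restores it after.
    with patch(f"{__name__}.current_platform.is_openvino", return_value=True):
        assert current_platform.is_openvino()
    assert not current_platform.is_openvino()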
@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino
+from vllm.utils import STR_BACKEND_ENV_VAR, is_hip

 logger = init_logger(__name__)

@@ -193,7 +193,7 @@ def which_attn_to_use(
         logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA

-    if is_openvino():
+    if current_platform.is_openvino():
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
         return _Backend.OPENVINO
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_openvino, print_warning_once)
+                        is_hip, print_warning_once)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -1117,7 +1117,7 @@ class DeviceConfig:
                 self.device_type = "cuda"
             elif current_platform.is_neuron():
                 self.device_type = "neuron"
-            elif is_openvino():
+            elif current_platform.is_openvino():
                 self.device_type = "openvino"
             elif current_platform.is_tpu():
                 self.device_type = "tpu"
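After this hunk, every auto-detection branch in DeviceConfig consults current_platform. A condensed, hypothetical view of the resolution order visible in this diff (the real method sets self.device_type and handles more devices than shown here):

    from vllm.platforms import current_platform

    def detect_device_type() -> str:
        # order as in the hunk above; the "cuda" branch is checked first
        if current_platform.is_cuda_alike():
            return "cuda"
        if current_platform.is_neuron():
            return "neuron"
        if current_platform.is_openvino():
            return "openvino"
        if current_platform.is_tpu():
            return "tpu"
        return "unspecified"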
@@ -10,6 +10,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
                         get_open_port, make_async)
@@ -17,14 +18,6 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
 logger = init_logger(__name__)


-def is_openvino_cpu() -> bool:
-    return "CPU" in envs.VLLM_OPENVINO_DEVICE
-
-
-def is_openvino_gpu() -> bool:
-    return "GPU" in envs.VLLM_OPENVINO_DEVICE
-
-
 class OpenVINOExecutor(ExecutorBase):

     uses_ray: bool = False
@@ -32,7 +25,8 @@ class OpenVINOExecutor(ExecutorBase):
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "openvino"
         assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
-        assert is_openvino_cpu() or is_openvino_gpu(), \
+        assert current_platform.is_openvino_cpu() or \
+            current_platform.is_openvino_gpu(), \
             "OpenVINO backend supports only CPU and GPU devices"

         self.ov_core = ov.Core()
@@ -163,7 +157,7 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
 def _verify_and_get_cache_config(ov_core: ov.Core,
                                  config: CacheConfig) -> CacheConfig:
     if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
-        if not is_openvino_cpu():
+        if not current_platform.is_openvino_cpu():
             logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is "
                         "ignored for GPU, f16 data type will be used.")
             config.cache_dtype = ov.Type.f16
@@ -172,7 +166,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
                         "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
             config.cache_dtype = ov.Type.u8
     else:
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             ov_device = envs.VLLM_OPENVINO_DEVICE
             inference_precision = ov_core.get_property(
                 ov_device, hints.inference_precision)
@@ -183,7 +177,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,
         else:
             config.cache_dtype = ov.Type.f16

-    if is_openvino_cpu():
+    if current_platform.is_openvino_cpu():
         if config.block_size != 32:
             logger.info(
                 f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
@@ -198,7 +192,7 @@ def _verify_and_get_cache_config(ov_core: ov.Core,

     kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
     if kv_cache_space >= 0:
-        if kv_cache_space == 0 and is_openvino_cpu():
+        if kv_cache_space == 0 and current_platform.is_openvino_cpu():
             config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
             logger.warning(
                 "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
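The sizing rule visible in the last hunk: VLLM_OPENVINO_KVCACHE_SPACE is given in GB, and a value of 0 on the CPU device falls back to a 4 GiB default. A self-contained sketch of that rule (function name hypothetical; GiB_bytes matches vllm.utils):

    GiB_bytes = 1 << 30

    def openvino_kvcache_space_bytes(kv_cache_space: int,
                                     is_cpu: bool) -> int:
        # 0 means "unset": default to 4 GiB on CPU, as in the hunk above
        if kv_cache_space == 0 and is_cpu:
            return 4 * GiB_bytes
        if kv_cache_space < 0:
            raise ValueError("VLLM_OPENVINO_KVCACHE_SPACE must be >= 0")
        return kv_cache_space * GiB_bytes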
@@ -12,12 +12,12 @@ from torch import nn
 import vllm.envs as envs
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import DeviceConfig, ModelConfig
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -136,7 +136,7 @@ class OpenVINOCasualLM(nn.Module):
         ov_device = envs.VLLM_OPENVINO_DEVICE
         paged_attention_transformation(pt_model.model)
         _modify_cache_parameters(pt_model.model, kv_cache_dtype,
-                                 is_openvino_cpu())
+                                 current_platform.is_openvino_cpu())

         ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
         self.ov_request = ov_compiled.create_infer_request()
@@ -65,6 +65,13 @@ try:
 except ImportError:
     pass

+is_openvino = False
+try:
+    from importlib.metadata import version
+    is_openvino = "openvino" in version("vllm")
+except Exception:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -85,6 +92,9 @@ elif is_cpu:
 elif is_neuron:
     from .neuron import NeuronPlatform
     current_platform = NeuronPlatform()
+elif is_openvino:
+    from .openvino import OpenVinoPlatform
+    current_platform = OpenVinoPlatform()
 else:
     current_platform = UnspecifiedPlatform()
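OpenVINO detection here is install-based rather than hardware-based: the OpenVINO build of vLLM carries "openvino" in its distribution version string, so the probe inspects package metadata instead of the runtime environment. A standalone sketch of the same check (distribution name as in the hunk; guarded because the package may be absent):

    from importlib.metadata import PackageNotFoundError, version

    def looks_like_openvino_build(dist: str = "vllm") -> bool:
        # e.g. a version string containing "openvino" marks that build
        try:
            return "openvino" in version(dist)
        except PackageNotFoundError:
            return False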
@@ -11,6 +11,7 @@ class PlatformEnum(enum.Enum):
     XPU = enum.auto()
     CPU = enum.auto()
     NEURON = enum.auto()
+    OPENVINO = enum.auto()
     UNSPECIFIED = enum.auto()


@@ -52,6 +53,9 @@ class Platform:
     def is_neuron(self) -> bool:
         return self._enum == PlatformEnum.NEURON

+    def is_openvino(self) -> bool:
+        return self._enum == PlatformEnum.OPENVINO
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
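Each is_* predicate is a plain equality test against the subclass's _enum tag, so adding a backend costs one enum member, one predicate, and one subclass. A minimal self-contained sketch of the pattern (toy version, trimmed to two members):

    import enum


    class PlatformEnum(enum.Enum):
        OPENVINO = enum.auto()
        UNSPECIFIED = enum.auto()


    class Platform:
        _enum = PlatformEnum.UNSPECIFIED

        def is_openvino(self) -> bool:
            return self._enum == PlatformEnum.OPENVINO


    class OpenVinoPlatform(Platform):
        _enum = PlatformEnum.OPENVINO


    assert OpenVinoPlatform().is_openvino()
    assert not Platform().is_openvino()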
vllm/platforms/openvino.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import torch
+
+import vllm.envs as envs
+from vllm.utils import print_warning_once
+
+from .interface import Platform, PlatformEnum
+
+
+class OpenVinoPlatform(Platform):
+    _enum = PlatformEnum.OPENVINO
+
+    @classmethod
+    def get_device_name(self, device_id: int = 0) -> str:
+        return "openvino"
+
+    @classmethod
+    def inference_mode(self):
+        return torch.inference_mode(mode=True)
+
+    @classmethod
+    def is_openvino_cpu(self) -> bool:
+        return "CPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_openvino_gpu(self) -> bool:
+        return "GPU" in envs.VLLM_OPENVINO_DEVICE
+
+    @classmethod
+    def is_pin_memory_available(self) -> bool:
+        print_warning_once("Pin memory is not supported on OpenVINO.")
+        return False
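Note that is_openvino_cpu() and is_openvino_gpu() key off the VLLM_OPENVINO_DEVICE environment variable with a substring match rather than probing hardware, so they can be exercised without an OpenVINO device. A hedged usage sketch (vLLM reads the variable through vllm.envs, which evaluates it lazily on access):

    import os

    os.environ["VLLM_OPENVINO_DEVICE"] = "GPU.1"  # OpenVINO-style device id

    # On an OpenVINO build of vLLM this would then report:
    #   from vllm.platforms import current_platform
    #   current_platform.is_openvino_gpu()  -> True  ("GPU" in "GPU.1")
    #   current_platform.is_openvino_cpu()  -> False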
@@ -318,15 +318,6 @@ def is_hip() -> bool:
     return torch.version.hip is not None


-@lru_cache(maxsize=None)
-def is_openvino() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        return "openvino" in version("vllm")
-    except PackageNotFoundError:
-        return False
-
-
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -757,7 +748,7 @@ def is_pin_memory_available() -> bool:
     elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif current_platform.is_cpu() or is_openvino():
+    elif current_platform.is_cpu() or current_platform.is_openvino():
         return False
     return True
@@ -13,12 +13,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.worker.openvino_model_runner import OpenVINOModelRunner
@@ -99,7 +99,7 @@ class OpenVINOCacheEngine:
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
         kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for _ in range(self.num_layers):
                 key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                        k_block_shape)
@@ -141,7 +141,7 @@ class OpenVINOCacheEngine:
         if num_blocks == 0:
             return swap_cache

-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to have swap cache"

         # Update key_cache shape:
@@ -285,7 +285,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         cache_block_size = self.get_cache_block_size_bytes()
         kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             num_device_blocks = int(kvcache_space_bytes // cache_block_size)
             num_swap_blocks = 0
         else:
@@ -322,7 +322,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         num_device_blocks = num_gpu_blocks
         num_swap_blocks = num_cpu_blocks

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert (num_swap_blocks == 0
                     ), f"{type(self)} does not support swappable cache for CPU"

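These two hunks encode the worker's cache budgeting: on CPU every block lives on the device and the swap pool must stay empty, while on GPU the device/swap split comes from the gpu/cpu block counts. A condensed sketch of that split (helper name hypothetical):

    def split_cache_blocks(kvcache_space_bytes: int, cache_block_size: int,
                           is_cpu: bool, num_gpu_blocks: int,
                           num_cpu_blocks: int) -> tuple[int, int]:
        """Return (num_device_blocks, num_swap_blocks)."""
        if is_cpu:
            # whole budget on the device; no swap cache on CPU
            return kvcache_space_bytes // cache_block_size, 0
        return num_gpu_blocks, num_cpu_blocks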
@@ -366,7 +366,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         assert self.kv_cache is not None

         # Populate the cache to warmup the memory
-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             for key_cache, value_cache in self.kv_cache:
                 key_cache.data[:] = 0
                 value_cache.data[:] = 0
@@ -414,7 +414,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
         blocks_to_swap_in = data["blocks_to_swap_in"]
         blocks_to_swap_out = data["blocks_to_swap_out"]

-        if is_openvino_cpu():
+        if current_platform.is_openvino_cpu():
             assert len(execute_model_req.blocks_to_swap_in) == 0
             assert len(execute_model_req.blocks_to_swap_out) == 0
         else:
@@ -466,7 +466,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     def profile_run(self) -> int:
         ov_device = envs.VLLM_OPENVINO_DEVICE

-        assert not is_openvino_cpu(), \
+        assert not current_platform.is_openvino_cpu(), \
             "CPU device isn't supposed to use profile run."

         import openvino.properties.device as device