[Hardware][XPU] using current_platform.is_xpu (#9605)
parent 51c24c9736
commit 2394962d70
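Summary: this change drops the ad-hoc is_xpu() helper from vllm.utils and switches its callers (attention backend selection, DeviceConfig, Ray cluster initialization, CustomOp dispatch, the seeding/pin-memory/memory-profiling utilities, and XPUWorker) to the platform abstraction's current_platform.is_xpu(). A minimal sketch of the call-site migration, assuming nothing beyond what the hunks below show:

# Before: per-call helper imported from vllm.utils
from vllm.utils import is_xpu

if is_xpu():
    device_type = "xpu"

# After: ask the platform abstraction instead
from vllm.platforms import current_platform

if current_platform.is_xpu():
    device_type = "xpu"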
@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino, is_xpu
+from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino
 
 logger = init_logger(__name__)
 
@@ -136,7 +136,7 @@ def get_attn_backend(
         from vllm.attention.backends.openvino import OpenVINOAttentionBackend
         return OpenVINOAttentionBackend
     elif backend == _Backend.IPEX:
-        assert is_xpu(), RuntimeError(
+        assert current_platform.is_xpu(), RuntimeError(
             "IPEX attention backend is only used for the XPU device.")
         logger.info("Using IPEX attention backend.")
         from vllm.attention.backends.ipex_attn import IpexAttnBackend
@@ -198,7 +198,7 @@ def which_attn_to_use(
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
         return _Backend.OPENVINO
 
-    if is_xpu():
+    if current_platform.is_xpu():
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         return _Backend.IPEX
@@ -17,7 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
                                             get_hf_text_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        is_hip, is_openvino, is_xpu, print_warning_once)
+                        is_hip, is_openvino, print_warning_once)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -1121,7 +1121,7 @@ class DeviceConfig:
                 self.device_type = "tpu"
             elif current_platform.is_cpu():
                 self.device_type = "cpu"
-            elif is_xpu():
+            elif current_platform.is_xpu():
                 self.device_type = "xpu"
             else:
                 raise RuntimeError("Failed to infer device type")
@@ -10,7 +10,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.utils import get_ip, is_hip, is_xpu
+from vllm.utils import get_ip, is_hip
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -231,7 +231,7 @@ def initialize_ray_cluster(
     assert_ray_available()
 
     # Connect to a ray cluster.
-    if is_hip() or is_xpu():
+    if is_hip() or current_platform.is_xpu():
         ray.init(address=ray_address,
                  ignore_reinit_error=True,
                  num_gpus=parallel_config.world_size)
@@ -7,7 +7,7 @@ import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_hip, is_xpu, print_warning_once
+from vllm.utils import is_hip, print_warning_once
 
 logger = init_logger(__name__)
 
@@ -78,7 +78,7 @@ class CustomOp(nn.Module):
             return self.forward_cpu
         elif current_platform.is_tpu():
             return self.forward_tpu
-        elif is_xpu():
+        elif current_platform.is_xpu():
             return self.forward_xpu
         else:
             return self.forward_cuda
@@ -327,29 +327,6 @@ def is_openvino() -> bool:
     return False
 
 
-@lru_cache(maxsize=None)
-def is_xpu() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        is_xpu_flag = "xpu" in version("vllm")
-    except PackageNotFoundError:
-        return False
-    # vllm is not build with xpu
-    if not is_xpu_flag:
-        return False
-    try:
-        import intel_extension_for_pytorch as ipex  # noqa: F401
-        _import_ipex = True
-    except ImportError as e:
-        logger.warning("Import Error for IPEX: %s", e.msg)
-        _import_ipex = False
-    # ipex dependency is not ready
-    if not _import_ipex:
-        logger.warning("not found ipex lib")
-        return False
-    return hasattr(torch, "xpu") and torch.xpu.is_available()
-
-
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
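The helper removed above probed the installed wheel's version string for "xpu" and tried to import intel_extension_for_pytorch before falling back to torch.xpu.is_available(). The diff does not show how current_platform.is_xpu() is implemented; the following is only a plausible sketch of a platform-style check (the names PlatformEnum, XPUPlatform, and _detect_platform are illustrative assumptions, not vLLM's actual code):

# Illustrative sketch only -- not vLLM's actual platform code.
# Assumption: the platform is detected once at import time and exposed as a
# module-level singleton with cheap is_*() predicates.
from enum import Enum, auto


class PlatformEnum(Enum):
    CPU = auto()
    XPU = auto()


class Platform:
    _enum = PlatformEnum.CPU  # concrete platforms override this

    def is_xpu(self) -> bool:
        return self._enum == PlatformEnum.XPU


class XPUPlatform(Platform):
    _enum = PlatformEnum.XPU


def _detect_platform() -> Platform:
    # Assumption: detection keys off torch.xpu availability, echoing the
    # removed helper's final hasattr/is_available check.
    try:
        import torch
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return XPUPlatform()
    except ImportError:
        pass
    return Platform()


current_platform = _detect_platform()

Centralizing the check this way keeps XPU detection consistent with the other current_platform.is_*() predicates already used in these files, and call sites no longer pay the per-call import/probe cost.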
@@ -379,7 +356,7 @@ def seed_everything(seed: int) -> None:
     if current_platform.is_cuda_alike():
         torch.cuda.manual_seed_all(seed)
 
-    if is_xpu():
+    if current_platform.is_xpu():
         torch.xpu.manual_seed_all(seed)
 
 
@@ -774,7 +751,7 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                            "This may slow down the performance.")
         return False
-    elif is_xpu():
+    elif current_platform.is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
     elif current_platform.is_neuron():
@@ -795,7 +772,7 @@ class DeviceMemoryProfiler:
         if current_platform.is_cuda_alike():
             torch.cuda.reset_peak_memory_stats(self.device)
             mem = torch.cuda.max_memory_allocated(self.device)
-        elif is_xpu():
+        elif current_platform.is_xpu():
             torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
             mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
         return mem
@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.utils import is_xpu
+from vllm.platforms import current_platform
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -53,7 +53,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         assert device_config.device_type == "xpu"
-        assert is_xpu()
+        assert current_platform.is_xpu()
 
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -91,7 +91,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         self.gpu_cache: Optional[List[List[torch.Tensor]]]
 
     def init_device(self) -> None:
-        if self.device_config.device.type == "xpu" and is_xpu():
+        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
+        ):
             self.device = torch.device(f"xpu:{self.local_rank}")
             torch.xpu.set_device(self.device)
             torch.xpu.empty_cache()