[XPU] Delay BF16 check to worker init for spawn compatibility (#22979)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
parent 9188ae7cb5
commit 8a044754bd
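This change replaces the worker-local `_check_if_gpu_supports_dtype` helper with a per-platform `Platform.check_if_supports_dtype` hook (implemented for CUDA, ROCm, and XPU) and removes the XPU-side silent bfloat16-to-float16 downgrade that previously ran in `check_and_update_config`. Because that config hook runs in the engine/parent process, querying the device there is incompatible with workers started via the `spawn` method; the dtype check now runs inside each worker right after it binds its device. The sketch below illustrates only the call pattern; `DemoPlatform`, `DemoXPUPlatform`, and `worker_init` are simplified stand-ins, not actual vLLM classes.

# A minimal sketch of the new call pattern, not the actual vLLM classes.
# With the "spawn" start method the parent process must not touch the
# accelerator, so the dtype capability check runs inside the worker
# process, after the worker has selected its device.
import torch


class DemoPlatform:
    """Simplified stand-in for vllm.platforms.interface.Platform."""

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype) -> None:
        """Raise if the current platform cannot run the requested dtype."""
        raise NotImplementedError


class DemoXPUPlatform(DemoPlatform):
    """Hypothetical XPU-like platform; real code reads torch.xpu device info."""

    @classmethod
    def get_device_name(cls) -> str:
        return "Intel Arc A770"  # hard-coded for the demo

    @classmethod
    def check_if_supports_dtype(cls, torch_dtype: torch.dtype) -> None:
        device_name = cls.get_device_name().lower()
        if torch_dtype == torch.bfloat16 and device_name.count("a770") > 0:
            raise ValueError("bfloat16 has known accuracy issues on this "
                             "device; set --dtype=half instead.")


def worker_init(platform: type[DemoPlatform], dtype: torch.dtype) -> None:
    # This is where the commit moves the check: inside the (possibly
    # spawned) worker, right after set_device(), instead of in the
    # engine/parent process during config validation.
    platform.check_if_supports_dtype(dtype)


if __name__ == "__main__":
    try:
        worker_init(DemoXPUPlatform, torch.bfloat16)
    except ValueError as err:
        print(f"dtype check failed in worker init: {err}")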
@@ -518,6 +518,26 @@ class CudaPlatformBase(Platform):
             supported = True
         return supported
+
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        if torch_dtype == torch.bfloat16:  # noqa: SIM102
+            if not cls.has_device_capability(80):
+                capability = cls.get_device_capability()
+                gpu_name = cls.get_device_name()
+
+                if capability is None:
+                    compute_str = "does not have a compute capability"
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f"has compute capability {version_str}"
+
+                raise ValueError(
+                    "Bfloat16 is only supported on GPUs "
+                    "with compute capability of at least 8.0. "
+                    f"Your {gpu_name} GPU {compute_str}. "
+                    "You can use float16 instead by explicitly setting the "
+                    "`dtype` flag in CLI, for example: --dtype=half.")
 
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -572,6 +572,13 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        """
+        Check if the dtype is supported by the current platform.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
@@ -462,3 +462,23 @@ class RocmPlatform(Platform):
     def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                     model_config: "ModelConfig") -> bool:
         return True
+
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        if torch_dtype == torch.bfloat16:  # noqa: SIM102
+            if not cls.has_device_capability(80):
+                capability = cls.get_device_capability()
+                gpu_name = cls.get_device_name()
+
+                if capability is None:
+                    compute_str = "does not have a compute capability"
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f"has compute capability {version_str}"
+
+                raise ValueError(
+                    "Bfloat16 is only supported on GPUs "
+                    "with compute capability of at least 8.0. "
+                    f"Your {gpu_name} GPU {compute_str}. "
+                    "You can use float16 instead by explicitly setting the "
+                    "`dtype` flag in CLI, for example: --dtype=half.")
@@ -97,13 +97,6 @@ class XPUPlatform(Platform):
             from vllm.config import CompilationLevel
             vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
 
-        # Instances created using VllmConfig() typically have model_config as
-        # None by default. The modification involves adding a check to prevent
-        # potential null exceptions check and update model config.
-        if model_config is not None and model_config.dtype == torch.bfloat16 \
-                and not cls.device_support_bf16():
-            model_config.dtype = torch.float16
-
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
         compilation_config = vllm_config.compilation_config
@@ -162,30 +155,11 @@ class XPUPlatform(Platform):
         torch.xpu.reset_peak_memory_stats(device)
         return torch.xpu.max_memory_allocated(device)
 
-    @classmethod
-    def device_support_bf16(cls) -> bool:
-        device_name = cls.get_device_name().lower()
-        if cls.is_client_gpu_a770():
-            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
-                           " fallback to float16")
-            return False
-        else:
-            logger.info(
-                "Device name %s supports bfloat16. Please file an issue "
-                "if you encounter any accuracy problems with bfloat16.",
-                device_name)
-            return True
-
     @classmethod
     def is_data_center_gpu(cls) -> bool:
         device_name = cls.get_device_name().lower()
         return device_name.count("data center gpu") > 0
 
-    @classmethod
-    def is_client_gpu_a770(cls) -> bool:
-        device_name = cls.get_device_name().lower()
-        return device_name.count("a770") > 0
-
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
@@ -197,3 +171,14 @@ class XPUPlatform(Platform):
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()
+
+    @classmethod
+    def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
+        if torch_dtype == torch.bfloat16:  # noqa: SIM102
+            device_name = cls.get_device_name().lower()
+            # client gpu a770
+            if device_name.count("a770") > 0:
+                raise ValueError(
+                    "Intel Arc A770 have bfloat16 accuracy known issue. "
+                    "You can use float16 instead by explicitly setting the "
+                    "`dtype` flag in CLI, for example: --dtype=half.")
@@ -167,7 +167,7 @@ class Worker(WorkerBase):
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
 
-            _check_if_gpu_supports_dtype(self.model_config.dtype)
+            current_platform.check_if_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
@@ -612,23 +612,3 @@ def init_worker_distributed_environment(
                                      parallel_config.pipeline_parallel_size)
 
     ensure_kv_transfer_initialized(vllm_config)
-
-
-def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:  # noqa: SIM102
-        if not current_platform.has_device_capability(80):
-            capability = current_platform.get_device_capability()
-            gpu_name = current_platform.get_device_name()
-
-            if capability is None:
-                compute_str = "does not have a compute capability"
-            else:
-                version_str = capability.as_version_str()
-                compute_str = f"has compute capability {version_str}"
-
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
-                "You can use float16 instead by explicitly setting the "
-                "`dtype` flag in CLI, for example: --dtype=half.")
@@ -145,6 +145,7 @@ class XPUWorker(Worker):
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
             current_platform.set_device(self.device)
+            current_platform.check_if_supports_dtype(self.model_config.dtype)
            torch.xpu.empty_cache()
             self.init_gpu_memory = torch.xpu.get_device_properties(
                 self.local_rank).total_memory
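As a usage note, the new check surfaces at worker startup as a ValueError rather than a silent dtype downgrade. If it fires on your hardware, request float16 explicitly, as the error message suggests. The snippet below is a hedged example of that workaround; the model name is only a placeholder.

from vllm import LLM

# Placeholder model name; substitute the model you are actually serving.
# dtype="half" is the Python-API equivalent of the CLI flag --dtype=half.
llm = LLM(model="your-org/your-model", dtype="half")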