mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 07:05:01 +08:00
[Bugfix][Hardware][POWERPC] Fix auto dtype failure in case of POWER10 (#11331)
Signed-off-by: Akash Kaothalkar <0052v2@linux.vnet.ibm.com>
This commit is contained in:
parent
a985f7af9f
commit
48edab8041
@ -22,7 +22,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms import current_platform, interface
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
@ -2199,6 +2199,17 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
if (current_platform.is_cpu()
|
||||
and current_platform.get_cpu_architecture()
|
||||
== interface.CpuArchEnum.POWERPC
|
||||
and (config_dtype == torch.float16
|
||||
or config_dtype == torch.float32)):
|
||||
logger.info(
|
||||
"For POWERPC, we cast models to bfloat16 instead of "
|
||||
"using float16 by default. Float16 is not currently "
|
||||
"supported for POWERPC.")
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
if current_platform.is_hpu() and config_dtype == torch.float16:
|
||||
logger.info(
|
||||
"For HPU, we cast models to bfloat16 instead of"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user