[Platforms] Refactor xpu code (#10468)
Signed-off-by: MengqingCao <cmq0113@163.com>
commit d5b28447e0
parent 09dbf9ff16
--- a/vllm/executor/xpu_executor.py
+++ b/vllm/executor/xpu_executor.py
@@ -1,8 +1,5 @@
 from typing import Callable, List, Optional, Tuple, Type, Union
 
-import torch
-
-from vllm.config import ModelConfig, ParallelConfig
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
@@ -23,7 +20,6 @@ class XPUExecutor(GPUExecutor):
         assert self.speculative_config is None, (
             "Speculative decoding not yet supported for XPU backend")
 
-        self.model_config = _verify_and_get_model_config(self.model_config)
         GPUExecutor._init_executor(self)
 
     def _get_worker_module_and_class(
@@ -53,26 +49,3 @@ class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
         output = await make_async(self.driver_worker.execute_model
                                   )(execute_model_req=execute_model_req)
         return output
-
-
-def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
-    if config.dtype == torch.bfloat16:
-        logger.warning(
-            "bfloat16 is not fully supported on XPU, casting to float16.")
-        config.dtype = torch.float16
-    if not config.enforce_eager:
-        logger.warning(
-            "CUDA graph is not supported on XPU, fallback to the eager "
-            "mode.")
-        config.enforce_eager = True
-    return config
-
-
-def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
-    if (config.distributed_executor_backend is not None
-            and config.distributed_executor_backend != "ray"):
-        logger.warning(
-            "%s is not supported on XPU, fallback to ray distributed executor "
-            "backend.", config.distributed_executor_backend)
-        config.distributed_executor_backend = "ray"
-    return config
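Note: the helpers deleted here are not lost. The same dtype and eager-mode fixup reappears as XPUPlatform.check_and_update_config in vllm/platforms/xpu.py (second diff below), so the executor no longer carries device-specific knowledge. A self-contained sketch of the pattern this refactor follows, using simplified stand-in types rather than vLLM's real classes:

# Sketch of the refactor's direction (stand-in types, not vLLM's real
# classes): device-specific config fixup moves out of the executor and
# behind a Platform hook that each backend can override.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    dtype: str = "bfloat16"
    enforce_eager: bool = False


@dataclass
class VllmConfig:
    model_config: ModelConfig


class Platform:
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Default: no platform-specific overrides."""


class XPUPlatform(Platform):
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        mc = vllm_config.model_config
        if mc.dtype == "bfloat16":
            mc.dtype = "float16"     # bfloat16 not fully supported on XPU
        if not mc.enforce_eager:
            mc.enforce_eager = True  # CUDA graphs unavailable on XPU


config = VllmConfig(ModelConfig())
XPUPlatform.check_and_update_config(config)
print(config.model_config)  # ModelConfig(dtype='float16', enforce_eager=True)

Centralizing the override in the Platform subclass means any entry point that holds the config can apply it, not just GPUExecutor subclasses.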
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -1,9 +1,16 @@
+from typing import TYPE_CHECKING
+
 import torch
 
 from vllm.logger import init_logger
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+
 logger = init_logger(__name__)
 
 
@@ -34,3 +41,17 @@ class XPUPlatform(Platform):
     @staticmethod
     def inference_mode():
         return torch.no_grad()
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        # check and update model config
+        model_config = vllm_config.model_config
+        if model_config.dtype == torch.bfloat16:
+            logger.warning(
+                "bfloat16 is not fully supported on XPU, casting to float16.")
+            model_config.dtype = torch.float16
+        if not model_config.enforce_eager:
+            logger.warning(
+                "CUDA graph is not supported on XPU, fallback to the eager "
+                "mode.")
+            model_config.enforce_eager = True
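Note: the TYPE_CHECKING block added at the top of this file is the standard idiom for keeping vllm.config an annotation-only dependency: the import is seen by static type checkers but never executed at runtime, which avoids import cost and circular-import risk. A minimal self-contained illustration of the same idiom (heavy_module and HeavyConfig are hypothetical stand-ins):

# TYPE_CHECKING idiom, as used in the diff above.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by mypy/pyright only; never executed at runtime.
    from heavy_module import HeavyConfig
else:
    HeavyConfig = None  # placeholder so the name exists at runtime


def configure(cfg: HeavyConfig) -> None:
    # The annotation resolves to None at definition time, which Python
    # accepts; type checkers still see the real HeavyConfig class.
    print(f"configuring {cfg!r}")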