[XPU] Add xpu torch.compile support (#22609)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Parent: d272415e57
Commit: fce10dbed5
@@ -31,6 +31,7 @@ docker run \
   set -e
   echo $ZE_AFFINITY_MASK
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+  VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
   cd tests
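The new CI line above exercises torch.compile at -O3 with CUDA graphs disabled. For reference, a roughly equivalent offline-inference run through the Python API might look like the sketch below; the compilation_config keyword and its dict fields are assumptions based on recent vLLM releases, not something this diff adds.

# Hedged sketch: mirror the new CI command (opt-125m, block size 64,
# torch.compile at -O3, CUDA graphs off) via the offline API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    block_size=64,
    # Assumed to correspond to the CLI flags -O3 -O.cudagraph_mode=NONE.
    compilation_config={"level": 3, "cudagraph_mode": "NONE"},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)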
@@ -190,8 +190,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu()
+        self.use_direct_call = not current_platform.opaque_attention_op()

         self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
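This hunk replaces hard-coded platform checks with a single platform hook. A minimal sketch of the before/after semantics, written out for clarity (only current_platform and its methods come from vLLM; the rest is illustrative):

# Before: attention was an opaque custom op only on CUDA-like or CPU platforms.
use_direct_call = not (current_platform.is_cuda_alike()
                       or current_platform.is_cpu())

# After: each platform declares this itself; CUDA, ROCm, CPU, and now XPU
# return True from opaque_attention_op(), so they keep the custom-op path.
use_direct_call = not current_platform.opaque_attention_op()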
@@ -9,6 +9,7 @@ import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized

 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 from .fx_utils import is_func
 from .vllm_inductor_pass import VllmInductorPass
@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
     """

     def __call__(self, graph: torch.fx.Graph):
+        # XPU does not support auto-functionalization yet.
+        # Will enable this when switching to vllm-xpu-kernels.
+        if current_platform.is_xpu():
+            logger.debug("XPU platform does not support fix functionalization "
+                         "pass currently.")
+            return
+
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")

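The guard makes the pass a no-op on XPU while other platforms keep the full rewrite. A condensed sketch of the resulting control flow (the node-rewriting body is elided; begin, dump_graph, and end_and_log are assumed to come from VllmInductorPass):

import torch

from vllm.platforms import current_platform

from .vllm_inductor_pass import VllmInductorPass


class FixFunctionalizationPass(VllmInductorPass):
    def __call__(self, graph: torch.fx.Graph):
        if current_platform.is_xpu():
            # XPU cannot handle the de-functionalized graph yet, so skip.
            return
        self.begin()
        self.dump_graph(graph, "before_fix_functionalization")
        # ... replace auto_functionalized nodes with in-place ops ...
        self.dump_graph(graph, "after_fix_functionalization")
        self.end_and_log()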
@@ -335,3 +335,7 @@ class CpuPlatform(Platform):
         return (cls.supports_v1(model_config)
                 and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC,
                              CpuArchEnum.ARM, CpuArchEnum.S390X))
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
@@ -442,6 +442,10 @@ class CudaPlatformBase(Platform):
     def use_custom_allreduce(cls) -> bool:
         return True

+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@@ -509,6 +509,14 @@ class Platform:
         """
         return False

+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        """
+        Returns True if we register attention as one giant opaque custom op
+        on the current platform
+        """
+        return False
+
     @classmethod
     def validate_request(
         cls,
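The base class defaults to False, so only platforms that opt in (CUDA, ROCm, CPU, and now XPU in this commit) take the opaque custom-op path. An out-of-tree platform plugin would opt in the same way; the subclass below is a hypothetical illustration, not part of this diff:

from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):  # hypothetical plugin platform
    @classmethod
    def opaque_attention_op(cls) -> bool:
        # Ask vLLM to register attention as one opaque custom op so that
        # torch.compile sees the whole attention call as a single node.
        return True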
@@ -411,6 +411,10 @@ class RocmPlatform(Platform):
         supported_archs = ['gfx94', 'gfx95']
         return any(gfx in gcn_arch for gfx in supported_archs)

+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_cu_count(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(
@@ -90,21 +90,14 @@ class XPUPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64

-        # FIXME: Temporarily forcing eager mode
-        # remove after t.compile support stabilizes.
-        if (envs.VLLM_USE_V1 and model_config is not None
-                and not vllm_config.model_config.enforce_eager):
-            from vllm.config import CompilationLevel
-            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
-
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
             != CUDAGraphMode.NONE:
-            logger.info("[XPU] CUDA graph is not supported on XPU, "
-                        "disabling cudagraphs.")
+            logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
+                        "cudagraphs. Fallback to cudagraph_mode=NONE")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE

         # check and update parallel config
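Net effect: XPU no longer downgrades the compilation level to NO_COMPILATION when the user did not pass --enforce-eager; it only forces cudagraph_mode to NONE, leaving torch.compile enabled. A condensed sketch of the new policy in isolation (CompilationConfig and CUDAGraphMode are real vLLM names; the helper function is illustrative):

from vllm.config import CompilationConfig, CUDAGraphMode


def apply_xpu_compile_policy(compilation_config: CompilationConfig) -> None:
    """Sketch of what XPUPlatform.check_and_update_config now enforces."""
    mode = compilation_config.cudagraph_mode
    if mode is None or mode.max_cudagraph_mode() != CUDAGraphMode.NONE:
        # CUDA graphs are unsupported on XPU; torch.compile itself stays on.
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE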
@@ -182,3 +175,7 @@ class XPUPlatform(Platform):
             "Intel Arc A770 have bfloat16 accuracy known issue. "
             "You can use float16 instead by explicitly setting the "
             "`dtype` flag in CLI, for example: --dtype=half.")
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True