diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
index 445cd2735c19..73f3e63fbf5f 100644
--- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -31,6 +31,7 @@ docker run \
   set -e
   echo $ZE_AFFINITY_MASK
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+  VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
   cd tests
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 2d288bcbe0c9..237802afccde 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -190,8 +190,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu()
+        self.use_direct_call = not current_platform.opaque_attention_op()
 
         self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index 286221d32c1e..60ae14331879 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -9,6 +9,7 @@ import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 from .fx_utils import is_func
 from .vllm_inductor_pass import VllmInductorPass
@@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass):
     """
 
     def __call__(self, graph: torch.fx.Graph):
+        # XPU does not support auto-functionalization yet.
+        # Will enable this when we switch to vllm-xpu-kernels.
+        if current_platform.is_xpu():
+            logger.debug("XPU platform does not support the fix "
+                         "functionalization pass currently.")
+            return
+
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")
 
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index c748595a7153..5686fae5cd7d 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -335,3 +335,7 @@ class CpuPlatform(Platform):
         return (cls.supports_v1(model_config) and arch in
                 (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM,
                  CpuArchEnum.S390X))
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index c0e0fe35e402..5cbb7346436e 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -442,6 +442,10 @@ class CudaPlatformBase(Platform):
     def use_custom_allreduce(cls) -> bool:
         return True
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index f6c17de86d05..01f3e2d977bc 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -509,6 +509,14 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        """
+        Returns True if we register attention as one giant opaque custom op
+        on the current platform.
+        """
+        return False
+
     @classmethod
     def validate_request(
         cls,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 85b2fe2e480c..c6d14aa87c7f 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -411,6 +411,10 @@ class RocmPlatform(Platform):
         supported_archs = ['gfx94', 'gfx95']
         return any(gfx in gcn_arch for gfx in supported_archs)
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_cu_count(cls, device_id: int = 0) -> int:
         return torch.cuda.get_device_properties(
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 235e5d8294e5..84f4cd725646 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -90,21 +90,14 @@ class XPUPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64
 
-        # FIXME: Temporarily forcing eager mode
-        # remove after t.compile support stabilizes.
-        if (envs.VLLM_USE_V1 and model_config is not None
-                and not vllm_config.model_config.enforce_eager):
-            from vllm.config import CompilationLevel
-            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
-
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
         compilation_config = vllm_config.compilation_config
         if compilation_config.cudagraph_mode is None or \
             compilation_config.cudagraph_mode.max_cudagraph_mode() \
                 != CUDAGraphMode.NONE:
-            logger.info("[XPU] CUDA graph is not supported on XPU, "
-                        "disabling cudagraphs.")
+            logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
+                        "cudagraphs. Falling back to cudagraph_mode=NONE.")
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
         # check and update parallel config
@@ -182,3 +175,7 @@ class XPUPlatform(Platform):
                 "Intel Arc A770 have bfloat16 accuracy known issue. "
                 "You can use float16 instead by explicitly setting the "
                 "`dtype` flag in CLI, for example: --dtype=half.")
+
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
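
For context, here is a minimal sketch of how the new Platform.opaque_attention_op() hook composes with the Attention change above. The OOTPlatform class below is hypothetical and only illustrates the override; it is not part of this diff.

    # Hypothetical sketch; `OOTPlatform` is illustrative and not part of this diff.
    from vllm.platforms.interface import Platform


    class OOTPlatform(Platform):
        """Example platform that keeps attention as a single opaque custom op."""

        @classmethod
        def opaque_attention_op(cls) -> bool:
            # Returning True mirrors the CUDA/ROCm/CPU/XPU overrides above, so
            # Attention.__init__ computes
            #     use_direct_call = not current_platform.opaque_attention_op()
            # and routes attention through the registered custom op instead of
            # calling the backend implementation directly under torch.compile.
            return True

Platforms that keep the base-class default (False, per the addition to vllm/platforms/interface.py) stay on the direct-call path and let torch.compile see the attention call as-is.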