diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
index 73f3e63fbf5f6..efcd10acf0b93 100644
--- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -30,10 +30,11 @@ docker run \
   bash -c '
     set -e
     echo $ZE_AFFINITY_MASK
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
-    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
+    VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
     cd tests
     pytest -v -s v1/core
     pytest -v -s v1/engine
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 19f6c4e3060ce..c2868c040aa16 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -242,10 +242,9 @@ class ipex_ops:
         k_scale_float: float = 1.0,
         v_scale_float: float = 1.0,
     ) -> None:
-        assert kv_cache_dtype == "auto"
-        # TODO: support FP8 kv cache.
         ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-            key, value, key_cache, value_cache, slot_mapping)
+            key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+            k_scale_float, v_scale_float)
 
     @staticmethod
     def flash_attn_varlen_func(
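The added CI line exercises the Triton attention backend on XPU by exporting `VLLM_ATTENTION_BACKEND`. A rough offline-inference equivalent of that smoke test is sketched below; the model and flags mirror the CI command, while the script itself and an available XPU build of vLLM are assumptions.

```python
# Sketch of the new CI smoke test as a standalone script (assumes an XPU
# build of vLLM; model and flags mirror the CI command above).
import os

# Force the Triton attention backend before vLLM selects one.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", block_size=64, enforce_eager=True)
for out in llm.generate(["The capital of France is"],
                        SamplingParams(max_tokens=16)):
    print(out.outputs[0].text)
```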
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index c6d1501e27578..4d870a45e5800 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -6,9 +6,14 @@ from typing import List, Optional, Tuple
 
 import torch
 
-from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 
+if current_platform.is_cuda_alike():
+    from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+
 if HAS_TRITON:
     from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 645a9e63a4e5a..9f89334e9a8a8 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -37,14 +37,38 @@ class XPUPlatform(Platform):
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
                              has_sink: bool) -> str:
-        if selected_backend is not None and selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1
         if not use_v1:
             raise ValueError("XPU backend only supports V1.")
+        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            logger.info_once("Using Triton backend on V1 engine.")
+            return TRITON_ATTN_VLLM_V1
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend on V1 engine.")
+            return FLASH_ATTN_V1
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}, "
+                f"with use_v1: {use_v1} use_mla: {use_mla}")
+
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
+        """
+        Check if the kv_cache_dtype is supported.
+        XPU only supports fp8 KV cache with the Triton backend.
+        """
+        if envs.is_set("VLLM_ATTENTION_BACKEND") and \
+            envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
+
+        return False
+
     @classmethod
     def set_device(cls, device: torch.device) -> None:
         """
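With the backend selection above, fp8 KV cache on XPU is accepted only when the Triton backend is forced through `VLLM_ATTENTION_BACKEND`, as checked by `is_kv_cache_dtype_supported`. A minimal usage sketch follows; the model choice is illustrative and an XPU build of vLLM is assumed.

```python
# Minimal sketch: fp8 KV cache on XPU, gated on the Triton backend.
import os

# Without this, XPUPlatform.is_kv_cache_dtype_supported() rejects fp8.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    block_size=64,
    enforce_eager=True,
    kv_cache_dtype="fp8",  # "fp8", "fp8_e4m3" or "fp8_e5m2" per the check above
)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)
```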
+ """ + if envs.is_set("VLLM_ATTENTION_BACKEND") and \ + envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1": + return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] + + return False + @classmethod def set_device(cls, device: torch.device) -> None: """ diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index a37a7f6811ef9..104cebb45d740 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -7,7 +7,6 @@ from typing import ClassVar, Optional import torch -from vllm import _custom_ops as ops from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) @@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + logger = init_logger(__name__) @@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl): layer._v_scale, ) else: - torch.ops._C_cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, @@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl): num_tokens, num_heads, head_size = query.shape assert layer._q_scale == 1.0, \ "A non 1.0 q_scale is not currently supported." - if not current_platform.is_rocm(): - # Skip Q quantization on ROCm, since dequantizing back to - # f32 in the attention kernel is not supported. + if current_platform.is_cuda(): + # Skip Q quantization on ROCm and XPU, enable this on cuda + # only, since dequantizing back to f32 in the attention kernel + # is not supported. query, _ = ops.scaled_fp8_quant( query.reshape( (num_tokens, num_heads * head_size)).contiguous(),