[XPU] support Triton Attention backend on Intel GPU (#24149)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji 2025-09-04 20:41:08 +08:00 committed by GitHub
parent 2b30afa442
commit 16ded21eeb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 15 deletions

View File

@ -30,10 +30,11 @@ docker run \
bash -c '
set -e
echo $ZE_AFFINITY_MASK
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
cd tests
pytest -v -s v1/core
pytest -v -s v1/engine

View File

@ -242,10 +242,9 @@ class ipex_ops:
k_scale_float: float = 1.0,
v_scale_float: float = 1.0,
) -> None:
assert kv_cache_dtype == "auto"
# TODO: support FP8 kv cache.
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slot_mapping)
key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale_float, v_scale_float)
@staticmethod
def flash_attn_varlen_func(

View File

@ -6,9 +6,14 @@ from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd

View File

@ -37,14 +37,38 @@ class XPUPlatform(Platform):
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
if selected_backend is not None and selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
use_v1 = envs.VLLM_USE_V1
if not use_v1:
raise ValueError("XPU backend only supports V1.")
TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
model_config: "ModelConfig") -> bool:
"""
Check if the kv_cache_dtype is supported.
XPU only support fp8 kv cache with triton backend.
"""
if envs.is_set("VLLM_ATTENTION_BACKEND") and \
envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
return False
@classmethod
def set_device(cls, device: torch.device) -> None:
"""

View File

@ -7,7 +7,6 @@ from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
ops.reshape_and_cache_flash(
key,
value,
key_cache,
@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl):
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale == 1.0, \
"A non 1.0 q_scale is not currently supported."
if not current_platform.is_rocm():
# Skip Q quantization on ROCm, since dequantizing back to
# f32 in the attention kernel is not supported.
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),