[XPU] support Triton Attention backend on Intel GPU (#24149)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Parent: 2b30afa442
Commit: 16ded21eeb
@@ -30,10 +30,11 @@ docker run \
 bash -c '
 set -e
 echo $ZE_AFFINITY_MASK
-VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
-VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
-VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
-VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
+python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
+VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
 cd tests
 pytest -v -s v1/core
 pytest -v -s v1/engine
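The CI hunk above drops the now-redundant VLLM_USE_V1=1 prefix (the XPU backend only supports V1) and adds a run that exercises the new Triton attention backend via VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1. A minimal offline sketch of that smoke test in Python (the bundled generate.py script does roughly this; the prompt and sampling settings here are illustrative):

    # Sketch only: select the Triton attention backend before vLLM is imported.
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", block_size=64, enforce_eager=True)
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)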
@@ -242,10 +242,9 @@ class ipex_ops:
         k_scale_float: float = 1.0,
         v_scale_float: float = 1.0,
     ) -> None:
-        assert kv_cache_dtype == "auto"
-        # TODO: support FP8 kv cache.
         ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
-            key, value, key_cache, value_cache, slot_mapping)
+            key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+            k_scale_float, v_scale_float)

     @staticmethod
     def flash_attn_varlen_func(
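This ipex_ops change removes the kv_cache_dtype == "auto" assertion and forwards kv_cache_dtype plus the k/v scales to IPEX's reshape_and_cache_flash, which is what allows fp8 KV caches on XPU. For orientation, a rough pure-PyTorch reference of the cache write this kernel performs (the layouts and the fp8 scaling here are assumptions for illustration, not the IPEX implementation):

    import torch

    def reshape_and_cache_flash_ref(key, value, key_cache, value_cache,
                                    slot_mapping, kv_cache_dtype="auto",
                                    k_scale=1.0, v_scale=1.0):
        # key/value: [num_tokens, num_heads, head_size]
        # key_cache/value_cache: [num_blocks, block_size, num_heads, head_size]
        block_size = key_cache.shape[1]
        block_idx = slot_mapping // block_size
        block_off = slot_mapping % block_size
        if kv_cache_dtype.startswith("fp8"):
            # fp8 caches store scaled values; the scales recover the original range.
            key = (key / k_scale).to(key_cache.dtype)
            value = (value / v_scale).to(value_cache.dtype)
        key_cache[block_idx, block_off] = key
        value_cache[block_idx, block_off] = value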
@@ -6,9 +6,14 @@ from typing import List, Optional, Tuple

 import torch

-from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON

+if current_platform.is_cuda_alike():
+    from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+
 if HAS_TRITON:
     from vllm.attention.ops.prefix_prefill import context_attention_fwd

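The prefix-prefill test now resolves its ops module per platform, so the same test body runs against the CUDA custom ops on CUDA-like devices and the IPEX ops on XPU. A sketch of the alias pattern in isolation (the final else branch is an addition for illustration):

    from vllm.platforms import current_platform

    if current_platform.is_cuda_alike():
        from vllm import _custom_ops as ops
    elif current_platform.is_xpu():
        from vllm._ipex_ops import ipex_ops as ops
    else:
        raise RuntimeError("No paged-attention ops available on this platform")

    # Call sites stay platform-agnostic, e.g. ops.reshape_and_cache_flash(...)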
@@ -37,14 +37,38 @@ class XPUPlatform(Platform):
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
-        if selected_backend is not None and selected_backend != _Backend.IPEX:
-            logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1
         if not use_v1:
             raise ValueError("XPU backend only supports V1.")
+        TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
+        FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            logger.info_once("Using Triton backend on V1 engine.")
+            return TRITON_ATTN_VLLM_V1
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info_once("Using Flash Attention backend on V1 engine.")
+            return FLASH_ATTN_V1
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}, "
+                f"with use_v1: {use_v1} use_mla: {use_mla}")
+
         logger.info("Using Flash Attention backend on V1 engine.")
         return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"

+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
+                                    model_config: "ModelConfig") -> bool:
+        """
+        Check if the kv_cache_dtype is supported.
+        XPU only support fp8 kv cache with triton backend.
+        """
+        if envs.is_set("VLLM_ATTENTION_BACKEND") and \
+            envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1":
+            return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
+
+        return False
+
     @classmethod
     def set_device(cls, device: torch.device) -> None:
         """
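XPUPlatform now honors an explicitly selected backend, mapping TRITON_ATTN_VLLM_V1 and FLASH_ATTN to their V1 class paths and rejecting anything else, and it gains is_kv_cache_dtype_supported so fp8 KV caches are accepted only when the Triton backend is chosen. A hedged usage sketch of that gate (assumes an XPU build of vLLM; names follow the diff above):

    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

    from vllm.platforms.xpu import XPUPlatform

    # Accepted only because the Triton backend is selected above.
    print(XPUPlatform.is_kv_cache_dtype_supported("fp8", None))       # True
    print(XPUPlatform.is_kv_cache_dtype_supported("fp8_e5m2", None))  # True
    print(XPUPlatform.is_kv_cache_dtype_supported("auto", None))      # False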
@@ -7,7 +7,6 @@ from typing import ClassVar, Optional

 import torch

-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -23,6 +22,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec

+if current_platform.is_cuda_alike():
+    from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+
 logger = init_logger(__name__)

@@ -337,7 +341,7 @@ class TritonAttentionImpl(AttentionImpl):
                     layer._v_scale,
                 )
             else:
-                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                ops.reshape_and_cache_flash(
                     key,
                     value,
                     key_cache,
@@ -354,9 +358,10 @@ class TritonAttentionImpl(AttentionImpl):
             num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale == 1.0, \
                 "A non 1.0 q_scale is not currently supported."
-            if not current_platform.is_rocm():
-                # Skip Q quantization on ROCm, since dequantizing back to
-                # f32 in the attention kernel is not supported.
+            if current_platform.is_cuda():
+                # Skip Q quantization on ROCm and XPU, enable this on cuda
+                # only, since dequantizing back to f32 in the attention kernel
+                # is not supported.
                 query, _ = ops.scaled_fp8_quant(
                     query.reshape(
                         (num_tokens, num_heads * head_size)).contiguous(),
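In the Triton backend itself, the KV-cache write now goes through the platform-selected ops alias (the IPEX kernel on XPU instead of torch.ops._C_cache_ops), and fp8 query quantization is restricted to CUDA because the ROCm and XPU kernels cannot dequantize the query back to f32. A self-contained illustration of that gate, using torch's float8 cast as a stand-in for ops.scaled_fp8_quant:

    import torch
    from vllm.platforms import current_platform

    def maybe_quantize_query(query: torch.Tensor) -> torch.Tensor:
        # Quantize the query to fp8 on CUDA only; ROCm/XPU keep the original dtype.
        if current_platform.is_cuda():
            num_tokens, num_heads, head_size = query.shape
            flat = query.reshape(num_tokens, num_heads * head_size).contiguous()
            flat = flat.to(torch.float8_e4m3fn)  # real code uses ops.scaled_fp8_quant
            return flat.view(num_tokens, num_heads, head_size)
        return query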