[Hardware][Intel GPU] Add v1 Intel GPU support with Flash attention backend. (#19560)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji 2025-06-27 00:27:18 +08:00 committed by GitHub
parent 0bceac9810
commit b69781f107
10 changed files with 393 additions and 42 deletions

View File

@ -28,4 +28,5 @@ docker run \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
'
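For reference, a minimal Python sketch equivalent to the new V1 test line above, assuming vLLM's offline LLM API; the model name and flags come from the CLI invocation, while the prompt and sampling settings are illustrative.

import os
os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine before importing vLLM

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          block_size=64,       # the V1 XPU path expects 64-token blocks
          enforce_eager=True)  # graph capture is not supported on XPU
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)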

View File

@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
ENV VLLM_TARGET_DEVICE=xpu
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \

View File

@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0
wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
torch==2.7.0+xpu
torchaudio

View File

@ -228,6 +228,111 @@ class ipex_ops:
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: Optional[torch.Tensor] = None,
v_scale: Optional[torch.Tensor] = None,
k_scale_float: float = 1.0,
v_scale_float: float = 1.0,
) -> None:
assert kv_cache_dtype == "auto"
# TODO: support FP8 kv cache.
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor,  # the ipex kernel does not take this; it is converted to cu_seqlens_k below
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: Optional[torch.Tensor],
window_size: Optional[list[int]] = None,
softcap: Optional[float] = 0.0,
cu_seqlens_k: Optional[torch.Tensor] = None,
# The following parameters are not currently used by the ipex kernel;
# they are kept only for API compatibility with CUDA.
scheduler_metadata=None,
fa_version: int = 2,
q_descale=None,
k_descale=None,
v_descale=None,
):
if cu_seqlens_k is None:
# The ipex kernel needs cu_seqlens_k, so derive it from seqused_k.
cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
cu_seqlens_k = torch.cat([
torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
cu_seqlens_k
]).to(torch.int32)
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
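For reference, a standalone sketch of the seqused_k to cu_seqlens_k conversion performed above when the caller does not pass cu_seqlens_k; the tensor values are illustrative and live on CPU here rather than on the XPU device.

import torch

seqused_k = torch.tensor([3, 5, 2], dtype=torch.int32)  # cached tokens per sequence
cu_seqlens_k = torch.cumsum(seqused_k, dim=0)           # [3, 8, 10]
cu_seqlens_k = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    cu_seqlens_k,
]).to(torch.int32)                                      # [0, 3, 8, 10]
assert cu_seqlens_k.tolist() == [0, 3, 8, 10]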
@staticmethod
def get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_size: Optional[int] = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
) -> None:
logger.warning_once(
"get_scheduler_metadata is not implemented for ipex_ops, "
"returning None.")
return None
@staticmethod
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],

View File

@ -4,13 +4,27 @@ from typing import Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
if current_platform.is_cuda():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
if current_platform.is_xpu():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
@ -50,6 +64,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def flash_attn_supports_fp8() -> bool:
from vllm.platforms import current_platform
return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9
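For reference, a hedged sketch of how a caller consumes the dispatched symbols above; the import path is the one used later in this commit, while the surrounding check is illustrative.

from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)

fa_version = get_flash_attn_version()  # 2 on XPU, 2 or 3 on CUDA, None if unsupported
if fa_version is None:
    raise RuntimeError("FlashAttention is not available on this platform")
print(f"FA version: {fa_version}, FP8 KV cache: {flash_attn_supports_fp8()}")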

View File

@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
if envs.VLLM_USE_V1:
if envs.VLLM_USE_V1 and not current_platform.is_xpu():
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

View File

@ -1,18 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
@ -35,8 +38,13 @@ class XPUPlatform(Platform):
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
use_v1 = envs.VLLM_USE_V1
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
else:
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
@classmethod
def get_device_capability(
@ -67,25 +75,27 @@ class XPUPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
# In V1 (or with IPEX chunked prefill), the block size is 64.
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
if envs.VLLM_USE_V1:
cache_config.block_size = 64
else:
cache_config.block_size = 16
# check and update model config
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
# Instances created via VllmConfig() may have model_config set to None,
# so guard against None before checking and updating the model config.
if vllm_config.model_config is not None:
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"bfloat16 is only supported on Intel Data Center GPU, "
"Intel Arc GPU is not supported yet. Your device is %s,"
" which is not supported. will fallback to float16",
cls.get_device_name())
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True
if vllm_config.speculative_config is not None:
raise NotImplementedError(
@ -96,21 +106,27 @@ class XPUPlatform(Platform):
# check and update parallel config
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:
parallel_config.worker_cls =\
"vllm.v1.worker.xpu_worker.XPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
if parallel_config.distributed_executor_backend is None:
parallel_config.distributed_executor_backend = "ray"
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "ray"
else:
parallel_config.distributed_executor_backend = "uni"
elif parallel_config.distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn requires an `if __name__ == '__main__':` guard, and
# fork is not supported for starting new processes on XPU.
logger.error(
"Both start methods (spawn and fork) have issues "
"on XPU with the mp backend; setting it to ray instead.")
parallel_config.distributed_executor_backend = "ray"
elif parallel_config.distributed_executor_backend != "ray":
if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
logger.warning(
"Please use spawn as start method if you want to use mp.")
elif parallel_config.distributed_executor_backend != "ray" and \
parallel_config.distributed_executor_backend != "uni":
logger.warning(
"%s is not supported on XPU, fallback to ray distributed"
" executor backend.",
@ -142,15 +158,35 @@ class XPUPlatform(Platform):
@classmethod
def device_support_bf16(cls) -> bool:
device_name = cls.get_device_name().lower()
if device_name.count("arc") > 0:
if cls.is_client_gpu_a770():
logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
" fallback to float16")
return False
elif device_name.count("data center gpu") > 0:
return True
else:
logger.warning("Unknown device name %s, always use float16",
device_name)
return False
logger.info(
"Device name %s supports bfloat16. Please file an issue "
"if you encounter any accuracy problems with bfloat16.",
device_name)
return True
@classmethod
def is_data_center_gpu(cls) -> bool:
device_name = cls.get_device_name().lower()
return device_name.count("data center gpu") > 0
@classmethod
def is_client_gpu_a770(cls) -> bool:
device_name = cls.get_device_name().lower()
return device_name.count("a770") > 0
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
return True
@classmethod
def device_count(cls) -> int:
return torch.xpu.device_count()
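For reference, a standalone sketch of the device-name heuristics above (device_support_bf16 together with is_client_gpu_a770); the classifier function and the sample device names are illustrative.

def xpu_supports_bf16(device_name: str) -> bool:
    name = device_name.lower()
    if "a770" in name:
        # Arc A770 has a known bfloat16 accuracy issue; fall back to float16.
        return False
    # Data Center GPUs (and other devices, unless reported otherwise) keep bfloat16.
    return True

assert xpu_supports_bf16("Intel(R) Arc(TM) A770 Graphics") is False
assert xpu_supports_bf16("Intel(R) Data Center GPU Max 1550") is True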

View File

@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
flash_attn_varlen_func,
get_flash_attn_version,
get_scheduler_metadata,
reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@ -28,10 +30,6 @@ from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_cuda():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
logger = init_logger(__name__)
@ -443,7 +441,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
reshape_and_cache_flash(
key,
value,
key_cache,

View File

@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
pass
logger = init_logger(__name__)
class XPUModelRunner(GPUModelRunner):
"""A model runner for XPU devices."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(vllm_config, device)
# FIXME: To be verified.
self.cascade_attn_enabled = False
def _init_device_properties(self) -> None:
pass
def _sync_device(self) -> None:
torch.xpu.synchronize()

View File

@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
import os
import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.worker.gpu_worker import (Worker,
init_worker_distributed_environment)
from vllm.v1.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(Worker):
"""A XPU worker class."""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
super().__init__(vllm_config, local_rank, rank,
distributed_init_method, is_driver_worker)
device_config = self.device_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.XPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
# We provide this function because `torch.xpu.mem_get_info()` doesn't
# return the correct free_gpu_memory on Intel client GPUs, so we have to
# estimate it ourselves.
def xpu_get_mem_info(self):
if current_platform.is_data_center_gpu():
return torch.xpu.mem_get_info()
else:
_, total_gpu_memory = torch.xpu.mem_get_info()
# FIXME: memory_allocated() doesn't count non-torch allocations, and
# there is no API to query them, so we assume a fixed 128 MiB.
used_memory = torch.xpu.memory_allocated()
non_torch_allocations = 128 * 1024 * 1024
free_gpu_memory = total_gpu_memory - (used_memory +
non_torch_allocations)
return free_gpu_memory, total_gpu_memory
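For reference, the client-GPU free-memory estimate above worked through with made-up numbers:

total_gpu_memory = 16 * 1024**3            # e.g. a 16 GiB client GPU
used_memory = 6 * 1024**3                  # what torch.xpu.memory_allocated() reports
non_torch_allocations = 128 * 1024 * 1024  # fixed 128 MiB allowance
free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
print(f"estimated free: {free_gpu_memory / 1024**2:.0f} MiB")  # 10112 MiB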
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine first profiles the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
torch.xpu.reset_peak_memory_stats()
free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
current_allocated_bytes = torch.xpu.memory_allocated()
msg = ("Before memory profiling run, "
f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
logger.info(msg)
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
free_gpu_memory, _ = self.xpu_get_mem_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
assert self.init_gpu_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
# Get the peak memory allocation recorded by torch
peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
torch.xpu.empty_cache()
torch_allocated_bytes = torch.xpu.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = self.xpu_get_mem_info(
)[1] - self.xpu_get_mem_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
msg = ("After memory profiling run, "
f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
logger.info(msg)
return int(available_kv_cache_memory)
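For reference, the KV-cache budget computed above worked through with made-up numbers:

total_gpu_memory = 16 * 1024**3   # bytes reported by mem_get_info()
gpu_memory_utilization = 0.9      # cache_config.gpu_memory_utilization
peak_memory = 7 * 1024**3         # peak torch allocation plus non-torch overhead
available_kv_cache_memory = (
    total_gpu_memory * gpu_memory_utilization - peak_memory)
print(f"KV cache budget: {available_kv_cache_memory / 1024**3:.2f} GiB")  # 7.40 GiB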
def init_device(self):
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
):
self.device = torch.device(f"xpu:{self.local_rank}")
torch.xpu.set_device(self.device)
torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd")
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
str(self.parallel_config.world_size))
os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
dist_backend = "ccl"
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank, dist_backend)
# Global all_reduce needed to warm up oneCCL.
torch.distributed.all_reduce(torch.zeros(1).xpu())
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner = XPUModelRunner( # type: ignore
self.vllm_config, self.device)
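For reference, a minimal sketch of the oneCCL environment defaults applied in init_device above; the helper name is illustrative, while the variable names and default values are taken from the code (setdefault mirrors the getenv-then-set pattern).

import os

def apply_ccl_defaults(local_rank: int, world_size: int) -> None:
    os.environ.setdefault("CCL_ZE_IPC_EXCHANGE", "drmfd")
    os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
    os.environ.setdefault("LOCAL_WORLD_SIZE", str(world_size))
    os.environ["LOCAL_RANK"] = str(local_rank)

apply_ccl_defaults(local_rank=0, world_size=1)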