diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index f54010c4231f9..827649bfcf548 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -28,4 +28,5 @@ docker run \ sh -c ' VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager ' diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 681102b9d18be..466ba98333635 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ENV VLLM_TARGET_DEVICE=xpu +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ diff --git a/requirements/xpu.txt b/requirements/xpu.txt index 3cb6a4a8addac..0d95dc57152de 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0 wheel jinja2>=3.1.6 datasets # for benchmark scripts +numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding torch==2.7.0+xpu torchaudio diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index ae63e06030dd1..2be02411ec05e 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -228,6 +228,111 @@ class ipex_ops: ipex.llm.modules.PagedAttention.reshape_and_cache( key, value, key_cache, value_cache, slot_mapping) + @staticmethod + def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, + k_scale_float: float = 1.0, + v_scale_float: float = 1.0, + ) -> None: + assert kv_cache_dtype == "auto" + # TODO: support FP8 kv cache. + ipex.llm.modules.PagedAttention.reshape_and_cache_flash( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def flash_attn_varlen_func( + out: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seqused_k: torch.Tensor, # we don't support this in ipex kernel + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + causal: bool, + block_table: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + window_size: Optional[list[int]] = None, + softcap: Optional[float] = 0.0, + cu_seqlens_k: Optional[torch.Tensor] = None, + # The following parameters are not used in ipex kernel currently, + # we keep API compatible to CUDA's. + scheduler_metadata=None, + fa_version: int = 2, + q_descale=None, + k_descale=None, + v_descale=None, + ): + if cu_seqlens_k is None: + # cu_seqlens_k is not used in ipex kernel. + cu_seqlens_k = torch.cumsum(seqused_k, dim=0) + cu_seqlens_k = torch.cat([ + torch.tensor([0], device=seqused_k.device, dtype=torch.int32), + cu_seqlens_k + ]).to(torch.int32) + + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + q.contiguous(), + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + block_table, + alibi_slopes, + softcap=softcap, + window_size_left=real_window_size[0], + window_size_right=real_window_size[1], + k_scale=1.0, + v_scale=1.0, + ) + + @staticmethod + def get_scheduler_metadata( + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + ) -> None: + logger.warning_once( + "get_scheduler_metadata is not implemented for ipex_ops, " + "returning None.") + return None + @staticmethod def copy_blocks(key_caches: list[torch.Tensor], value_caches: list[torch.Tensor], diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 69cde06fd72e9..36fd2d231bc5f 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -4,13 +4,27 @@ from typing import Optional from vllm import envs from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) +if current_platform.is_cuda(): + from vllm import _custom_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash + flash_attn_varlen_func = ops.flash_attn_varlen_func + get_scheduler_metadata = ops.get_scheduler_metadata + def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + if current_platform.is_xpu(): + return 2 try: from vllm.vllm_flash_attn.flash_attn_interface import ( fa_version_unsupported_reason, is_fa_version_supported) @@ -50,6 +64,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def flash_attn_supports_fp8() -> bool: - from vllm.platforms import current_platform return get_flash_attn_version() == 3 and \ current_platform.get_device_capability().major == 9 diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index a3f05ec5ea3f2..84e8ddd8e274d 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1: + if envs.VLLM_USE_V1 and not current_platform.is_xpu(): # V1 uses SPMD worker and compiled DAG os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 73f6f3d417671..f361f5e2616ef 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,18 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import TYPE_CHECKING, Optional import torch +import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum, _Backend if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig else: + ModelConfig = None VllmConfig = None logger = init_logger(__name__) @@ -35,8 +38,13 @@ class XPUPlatform(Platform): use_mla: bool) -> str: if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) - logger.info("Using IPEX attention backend.") - return "vllm.attention.backends.ipex_attn.IpexAttnBackend" + use_v1 = envs.VLLM_USE_V1 + if use_v1: + logger.info("Using Flash Attention backend on V1 engine.") + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + else: + logger.info("Using IPEX attention backend.") + return "vllm.attention.backends.ipex_attn.IpexAttnBackend" @classmethod def get_device_capability( @@ -67,25 +75,27 @@ class XPUPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config + # in V1(or with ipex chunked prefill) block_size is 64 if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + if envs.VLLM_USE_V1: + cache_config.block_size = 64 + else: + cache_config.block_size = 16 - # check and update model config - model_config = vllm_config.model_config - if model_config.dtype == torch.bfloat16: - bf16_supported = cls.device_support_bf16() - if not bf16_supported: + # Instances created using VllmConfig() typically have model_config as + # None by default. The modification involves adding a check to prevent + # potential null exceptions check and update model config. + if vllm_config.model_config is not None: + model_config = vllm_config.model_config + if model_config.dtype == torch.bfloat16: + bf16_supported = cls.device_support_bf16() + if not bf16_supported: + model_config.dtype = torch.float16 + if not model_config.enforce_eager: logger.warning( - "bfloat16 is only supported on Intel Data Center GPU, " - "Intel Arc GPU is not supported yet. Your device is %s," - " which is not supported. will fallback to float16", - cls.get_device_name()) - model_config.dtype = torch.float16 - if not model_config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - model_config.enforce_eager = True + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True if vllm_config.speculative_config is not None: raise NotImplementedError( @@ -96,21 +106,27 @@ class XPUPlatform(Platform): # check and update parallel config parallel_config = vllm_config.parallel_config - if parallel_config.worker_cls == "auto": + if envs.VLLM_USE_V1: + parallel_config.worker_cls =\ + "vllm.v1.worker.xpu_worker.XPUWorker" + else: parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" if parallel_config.distributed_executor_backend is None: - parallel_config.distributed_executor_backend = "ray" + if parallel_config.world_size > 1: + parallel_config.distributed_executor_backend = "ray" + else: + parallel_config.distributed_executor_backend = "uni" elif parallel_config.distributed_executor_backend == "mp": # FIXME(kunshang): # spawn needs calling `if __name__ == '__main__':`` # fork is not supported for xpu start new process. - logger.error( - "Both start methods (spawn and fork) have issue " - "on XPU if you use mp backend, setting it to ray instead.") - parallel_config.distributed_executor_backend = "ray" - - elif parallel_config.distributed_executor_backend != "ray": + if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + logger.warning( + "Please use spawn as start method if you want to use mp.") + elif parallel_config.distributed_executor_backend != "ray" and \ + parallel_config.distributed_executor_backend != "uni": logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", @@ -142,15 +158,35 @@ class XPUPlatform(Platform): @classmethod def device_support_bf16(cls) -> bool: device_name = cls.get_device_name().lower() - if device_name.count("arc") > 0: + if cls.is_client_gpu_a770(): + logger.warning("Intel Arc A770 have bfloat16 accuracy known issue," + " fallback to float16") return False - elif device_name.count("data center gpu") > 0: - return True else: - logger.warning("Unknown device name %s, always use float16", - device_name) - return False + logger.info( + "Device name %s supports bfloat16. Please file an issue " + "if you encounter any accuracy problems with bfloat16.", + device_name) + return True + + @classmethod + def is_data_center_gpu(cls) -> bool: + device_name = cls.get_device_name().lower() + return device_name.count("data center gpu") > 0 + + @classmethod + def is_client_gpu_a770(cls) -> bool: + device_name = cls.get_device_name().lower() + return device_name.count("a770") > 0 @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + return True + + @classmethod + def device_count(cls) -> int: + return torch.xpu.device_count() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ef65d2ea36e4f..42b5997f085b1 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) + flash_attn_varlen_func, + get_flash_attn_version, + get_scheduler_metadata, + reshape_and_cache_flash) from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, @@ -28,10 +30,6 @@ from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner -if current_platform.is_cuda(): - from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) - logger = init_logger(__name__) @@ -443,7 +441,7 @@ class FlashAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - torch.ops._C_cache_ops.reshape_and_cache_flash( + reshape_and_cache_flash( key, value, key_cache, diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py new file mode 100644 index 0000000000000..55d116dcd4968 --- /dev/null +++ b/vllm/v1/worker/xpu_model_runner.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class XPUModelRunner(GPUModelRunner): + """A model runner for XPU devices.""" + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(vllm_config, device) + # FIXME: To be verified. + self.cascade_attn_enabled = False + + def _init_device_properties(self) -> None: + pass + + def _sync_device(self) -> None: + torch.xpu.synchronize() diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py new file mode 100644 index 0000000000000..d9ea03986566b --- /dev/null +++ b/vllm/v1/worker/xpu_worker.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.v1.worker.gpu_worker import (Worker, + init_worker_distributed_environment) +from vllm.v1.worker.xpu_model_runner import XPUModelRunner + +logger = init_logger(__name__) + + +class XPUWorker(Worker): + """A XPU worker class.""" + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__(vllm_config, local_rank, rank, + distributed_init_method, is_driver_worker) + device_config = self.device_config + assert device_config.device_type == "xpu" + assert current_platform.is_xpu() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.XPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + # we provide this function due to `torch.xpu.mem_get_info()` doesn't + # return correct free_gpu_memory on intel client GPU. We need to + # calculate/estiamte it. + def xpu_get_mem_info(self): + if current_platform.is_data_center_gpu(): + return torch.xpu.mem_get_info() + else: + _, total_gpu_memory = torch.xpu.mem_get_info() + # FIXME: memory_allocated() doesn't count non-torch allocations, + # and we don't have any API to get it. so we mark it as 128MB. + used_memory = torch.xpu.memory_allocated() + non_torch_allocations = 128 * 1024 * 1024 + free_gpu_memory = total_gpu_memory - (used_memory + + non_torch_allocations) + return free_gpu_memory, total_gpu_memory + + @torch.inference_mode() + def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() + + free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() + current_allocated_bytes = torch.xpu.memory_allocated() + msg = ("Before memory profiling run, " + f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " + f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + logger.info(msg) + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + free_gpu_memory, _ = self.xpu_get_mem_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_gpu_memory > free_gpu_memory, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + # Get the peak memory allocation recorded by torch + peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] + + torch.xpu.empty_cache() + torch_allocated_bytes = torch.xpu.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = self.xpu_get_mem_info( + )[1] - self.xpu_get_mem_info()[0] + + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + msg = ("After memory profiling run, " + f"peak memory usage is {peak_memory / 1024**2:.2f} MB," + f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " + f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + logger.info(msg) + + return int(available_kv_cache_memory) + + def init_device(self): + if self.device_config.device.type == "xpu" and current_platform.is_xpu( + ): + self.device = torch.device(f"xpu:{self.local_rank}") + torch.xpu.set_device(self.device) + torch.xpu.empty_cache() + self.init_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + + ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd") + ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") + ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", + str(self.parallel_config.world_size)) + os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE + os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT + os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE + os.environ["LOCAL_RANK"] = str(self.local_rank) + dist_backend = "ccl" + + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank, dist_backend) + + # global all_reduce needed for overall oneccl warm up + torch.distributed.all_reduce(torch.zeros(1).xpu()) + + # Set random seed. + set_random_seed(self.model_config.seed) + + # Construct the model runner + self.model_runner = XPUModelRunner( # type: ignore + self.vllm_config, self.device)