[Hardware][Intel GPU] Add v1 Intel GPU support with Flash attention backend. (#19560)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 0bceac9810
commit b69781f107
@@ -28,4 +28,5 @@ docker run \
  sh -c '
    VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
    VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
+   VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
  '
@@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi

ENV VLLM_TARGET_DEVICE=xpu
+ENV VLLM_WORKER_MULTIPROC_METHOD=spawn

RUN --mount=type=cache,target=/root/.cache/pip \
    --mount=type=bind,source=.git,target=.git \
@@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0
wheel
jinja2>=3.1.6
datasets # for benchmark scripts
+numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding

torch==2.7.0+xpu
torchaudio
@@ -228,6 +228,111 @@ class ipex_ops:
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        # TODO: support FP8 kv cache.
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in ipex kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
    ):
        if cu_seqlens_k is None:
            # cu_seqlens_k is not used in ipex kernel.
            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
            cu_seqlens_k = torch.cat([
                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
                cu_seqlens_k
            ]).to(torch.int32)

        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=real_window_size[0],
            window_size_right=real_window_size[1],
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k_new: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, "
            "returning None.")
        return None

    @staticmethod
    def copy_blocks(key_caches: list[torch.Tensor],
                    value_caches: list[torch.Tensor],
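For reference, a minimal, self-contained sketch of the cu_seqlens_k fallback computed above, using made-up sequence lengths; it is plain PyTorch and needs no IPEX or XPU device.

import torch

# Hypothetical per-sequence KV lengths for a batch of three requests.
seqused_k = torch.tensor([5, 3, 7], dtype=torch.int32)

# Same fallback as in the wrapper above: prefix-sum the lengths and
# prepend a leading zero so the result indexes sequence boundaries.
cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
cu_seqlens_k = torch.cat([
    torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
    cu_seqlens_k
]).to(torch.int32)

print(cu_seqlens_k)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)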
@@ -4,13 +4,27 @@ from typing import Optional

from vllm import envs
from vllm.logger import init_logger
+from vllm.platforms import current_platform

logger = init_logger(__name__)

+if current_platform.is_cuda():
+    from vllm import _custom_ops as ops
+    reshape_and_cache_flash = ops.reshape_and_cache_flash
+    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                      get_scheduler_metadata)
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops as ops
+    reshape_and_cache_flash = ops.reshape_and_cache_flash
+    flash_attn_varlen_func = ops.flash_attn_varlen_func
+    get_scheduler_metadata = ops.get_scheduler_metadata


def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
    # import here to avoid circular dependencies
    from vllm.platforms import current_platform
+    if current_platform.is_xpu():
+        return 2
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)
@@ -50,6 +64,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:


def flash_attn_supports_fp8() -> bool:
-    from vllm.platforms import current_platform
    return get_flash_attn_version() == 3 and \
        current_platform.get_device_capability().major == 9
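As a standalone illustration of the import-time dispatch used above, binding a module-level name once lets callers stay platform-agnostic. The class and function names here are placeholders for the example, not vLLM APIs.

class _CudaOps:
    @staticmethod
    def reshape_and_cache_flash(*args, **kwargs):
        return "cuda kernel"


class _XpuOps:
    @staticmethod
    def reshape_and_cache_flash(*args, **kwargs):
        return "ipex kernel"


def _is_xpu() -> bool:
    # Stand-in for current_platform.is_xpu().
    return False


# Module-level binding: callers import one name and transparently get the
# platform-appropriate implementation, as fa_utils does above.
ops = _XpuOps if _is_xpu() else _CudaOps
reshape_and_cache_flash = ops.reshape_and_cache_flash

print(reshape_and_cache_flash())  # "cuda kernel"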
@@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase):

    def _init_executor(self) -> None:
        self.forward_dag: Optional[ray.dag.CompiledDAG] = None
-        if envs.VLLM_USE_V1:
+        if envs.VLLM_USE_V1 and not current_platform.is_xpu():
            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
@@ -1,18 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import TYPE_CHECKING, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
else:
+    ModelConfig = None
    VllmConfig = None

logger = init_logger(__name__)
@@ -35,8 +38,13 @@ class XPUPlatform(Platform):
                             use_mla: bool) -> str:
        if selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
-        logger.info("Using IPEX attention backend.")
-        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
+        use_v1 = envs.VLLM_USE_V1
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        else:
+            logger.info("Using IPEX attention backend.")
+            return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

    @classmethod
    def get_device_capability(
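A simplified stand-in for the selection above (the real code goes through envs.VLLM_USE_V1 and the platform layer rather than a bare environment read): with VLLM_USE_V1=1 an XPU device is routed to the V1 FlashAttention backend, otherwise to the legacy IPEX backend.

import os

use_v1 = os.environ.get("VLLM_USE_V1", "0") == "1"
backend_cls = (
    "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    if use_v1 else "vllm.attention.backends.ipex_attn.IpexAttnBackend")
print(backend_cls)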
@@ -67,25 +75,27 @@ class XPUPlatform(Platform):
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        cache_config = vllm_config.cache_config
        # in V1(or with ipex chunked prefill) block_size is 64
        if cache_config and cache_config.block_size is None:
-            cache_config.block_size = 16
+            if envs.VLLM_USE_V1:
+                cache_config.block_size = 64
+            else:
+                cache_config.block_size = 16

-        # check and update model config
-        model_config = vllm_config.model_config
-        if model_config.dtype == torch.bfloat16:
-            bf16_supported = cls.device_support_bf16()
-            if not bf16_supported:
-                logger.warning(
-                    "bfloat16 is only supported on Intel Data Center GPU, "
-                    "Intel Arc GPU is not supported yet. Your device is %s,"
-                    " which is not supported. will fallback to float16",
-                    cls.get_device_name())
-                model_config.dtype = torch.float16
-        if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on XPU, fallback to the eager "
-                "mode.")
-            model_config.enforce_eager = True
+        # Instances created using VllmConfig() typically have model_config as
+        # None by default. The modification involves adding a check to prevent
+        # potential null exceptions check and update model config.
+        if vllm_config.model_config is not None:
+            model_config = vllm_config.model_config
+            if model_config.dtype == torch.bfloat16:
+                bf16_supported = cls.device_support_bf16()
+                if not bf16_supported:
+                    model_config.dtype = torch.float16
+            if not model_config.enforce_eager:
+                logger.warning(
+                    "CUDA graph is not supported on XPU, fallback to the eager "
+                    "mode.")
+                model_config.enforce_eager = True

        if vllm_config.speculative_config is not None:
            raise NotImplementedError(
@@ -96,21 +106,27 @@ class XPUPlatform(Platform):

        # check and update parallel config
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
+            if envs.VLLM_USE_V1:
+                parallel_config.worker_cls =\
+                    "vllm.v1.worker.xpu_worker.XPUWorker"
+            else:
                parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

        if parallel_config.distributed_executor_backend is None:
-            parallel_config.distributed_executor_backend = "ray"
+            if parallel_config.world_size > 1:
+                parallel_config.distributed_executor_backend = "ray"
+            else:
+                parallel_config.distributed_executor_backend = "uni"
        elif parallel_config.distributed_executor_backend == "mp":
            # FIXME(kunshang):
            # spawn needs calling `if __name__ == '__main__':``
            # fork is not supported for xpu start new process.
            logger.error(
                "Both start methods (spawn and fork) have issue "
                "on XPU if you use mp backend, setting it to ray instead.")
            parallel_config.distributed_executor_backend = "ray"

-        elif parallel_config.distributed_executor_backend != "ray":
-            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
-                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-                logger.warning(
-                    "Please use spawn as start method if you want to use mp.")
+        elif parallel_config.distributed_executor_backend != "ray" and \
+                parallel_config.distributed_executor_backend != "uni":
            logger.warning(
                "%s is not supported on XPU, fallback to ray distributed"
                " executor backend.",
@@ -142,15 +158,35 @@ class XPUPlatform(Platform):
    @classmethod
    def device_support_bf16(cls) -> bool:
        device_name = cls.get_device_name().lower()
-        if device_name.count("arc") > 0:
-            return False
-        elif device_name.count("data center gpu") > 0:
-            return True
-        else:
-            logger.warning("Unknown device name %s, always use float16",
-                           device_name)
-            return False
+        if cls.is_client_gpu_a770():
+            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
+                           " fallback to float16")
+            return False
+        logger.info(
+            "Device name %s supports bfloat16. Please file an issue "
+            "if you encounter any accuracy problems with bfloat16.",
+            device_name)
+        return True
+
+    @classmethod
+    def is_data_center_gpu(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("data center gpu") > 0
+
+    @classmethod
+    def is_client_gpu_a770(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("a770") > 0

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
+    @classmethod
+    def device_count(cls) -> int:
+        return torch.xpu.device_count()
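To make the new device-name checks concrete, a tiny standalone example with made-up name strings (not values read from real hardware); it mirrors how device_support_bf16 keeps bfloat16 on data center parts and falls back to float16 on the A770.

names = [
    "Intel(R) Data Center GPU Max 1550",  # hypothetical example string
    "Intel(R) Arc(TM) A770 Graphics",     # hypothetical example string
]
for name in names:
    lowered = name.lower()
    is_data_center = lowered.count("data center gpu") > 0
    is_a770 = lowered.count("a770") > 0
    dtype = "float16" if is_a770 else "bfloat16"
    print(f"{name}: data_center={is_data_center}, a770={is_a770} -> {dtype}")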
@@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
+                                           flash_attn_varlen_func,
+                                           get_flash_attn_version,
+                                           get_scheduler_metadata,
+                                           reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
-from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -28,10 +30,6 @@ from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

-if current_platform.is_cuda():
-    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                      get_scheduler_metadata)
-
logger = init_logger(__name__)

@@ -443,7 +441,7 @@ class FlashAttentionImpl(AttentionImpl):
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
+        reshape_and_cache_flash(
            key,
            value,
            key_cache,
vllm/v1/worker/xpu_model_runner.py (new file, 32 lines)
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if TYPE_CHECKING:
    pass

logger = init_logger(__name__)


class XPUModelRunner(GPUModelRunner):
    """A model runner for XPU devices."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(vllm_config, device)
        # FIXME: To be verified.
        self.cascade_attn_enabled = False

    def _init_device_properties(self) -> None:
        pass

    def _sync_device(self) -> None:
        torch.xpu.synchronize()
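The runner above works by overriding a few device hooks of GPUModelRunner and inheriting everything else; here is a minimal sketch of that pattern with illustrative stand-in classes (not vLLM classes).

class ToyRunner:
    """Stand-in for a CUDA-oriented base runner."""

    def __init__(self):
        self._init_device_properties()

    def _init_device_properties(self):
        print("querying CUDA device properties")

    def _sync_device(self):
        print("cuda synchronize")


class ToyXPURunner(ToyRunner):
    """Stand-in for the XPU subclass: swap only the device-specific hooks."""

    def _init_device_properties(self):
        pass  # nothing to query on this path

    def _sync_device(self):
        print("xpu synchronize")


ToyXPURunner()._sync_device()  # prints "xpu synchronize"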
vllm/v1/worker/xpu_worker.py (new file, 164 lines)
@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
import os

import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
from vllm.v1.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(Worker):
    """An XPU worker class."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config, local_rank, rank,
                         distributed_init_method, is_driver_worker)
        device_config = self.device_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.XPU,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    # We provide this function because `torch.xpu.mem_get_info()` doesn't
    # return correct free_gpu_memory on Intel client GPUs. We need to
    # calculate/estimate it.
    def xpu_get_mem_info(self):
        if current_platform.is_data_center_gpu():
            return torch.xpu.mem_get_info()
        else:
            _, total_gpu_memory = torch.xpu.mem_get_info()
            # FIXME: memory_allocated() doesn't count non-torch allocations,
            # and we don't have any API to get it. So we mark it as 128MB.
            used_memory = torch.xpu.memory_allocated()
            non_torch_allocations = 128 * 1024 * 1024
            free_gpu_memory = total_gpu_memory - (used_memory +
                                                  non_torch_allocations)
            return free_gpu_memory, total_gpu_memory

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.
        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.
        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()
        torch.xpu.reset_peak_memory_stats()

        free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
        current_allocated_bytes = torch.xpu.memory_allocated()
        msg = ("Before memory profiling run, "
               f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
               f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
               f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
        logger.info(msg)
        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        free_gpu_memory, _ = self.xpu_get_mem_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]

        torch.xpu.empty_cache()
        torch_allocated_bytes = torch.xpu.memory_stats(
        )["allocated_bytes.all.current"]
        total_allocated_bytes = self.xpu_get_mem_info(
        )[1] - self.xpu_get_mem_info()[0]

        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        msg = ("After memory profiling run, "
               f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
               f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
               f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
               f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
        logger.info(msg)

        return int(available_kv_cache_memory)

    def init_device(self):
        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
        ):
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd")
        ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                         str(self.parallel_config.world_size))
        os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
        os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
        os.environ["LOCAL_RANK"] = str(self.local_rank)
        dist_backend = "ccl"

        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank, dist_backend)

        # global all_reduce needed for overall oneccl warm up
        torch.distributed.all_reduce(torch.zeros(1).xpu())

        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner = XPUModelRunner(  # type: ignore
            self.vllm_config, self.device)
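As a worked example of the client-GPU free-memory estimate and the KV-cache budget computed above, with made-up numbers rather than measured values:

GiB = 1024**3
total_gpu_memory = 16 * GiB            # hypothetical value from torch.xpu.mem_get_info()
used_memory = 6 * GiB                  # hypothetical value from torch.xpu.memory_allocated()
non_torch_allocations = 128 * 1024 * 1024  # fixed 128 MB estimate, as above

# Client-GPU estimate from xpu_get_mem_info():
free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)

# KV-cache budget from determine_available_memory():
gpu_memory_utilization = 0.9
peak_memory = 7 * GiB                  # hypothetical profiling peak
available_kv_cache_memory = (total_gpu_memory * gpu_memory_utilization -
                             peak_memory)

print(f"estimated free: {free_gpu_memory / GiB:.2f} GiB, "
      f"kv cache budget: {available_kv_cache_memory / GiB:.2f} GiB")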