# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.logger import init_logger

from .parallel_state import get_cpu_world_group, get_local_rank

logger = init_logger(__name__)


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
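
# Illustrative sketch (not part of the upstream file): divide() is the usual
# way to compute per-rank shard sizes for tensor parallelism. The numbers
# below are made-up assumptions, purely for illustration.
#
#     hidden_size = 4096
#     tensor_parallel_size = 8
#     shard_size = divide(hidden_size, tensor_parallel_size)  # -> 512
#     # divide(4096, 3) would raise an AssertionError instead of rounding.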


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor into.
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A sequence of tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
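
# Illustrative sketch (not part of the upstream file): splitting the last
# dimension of a tensor across tensor-parallel ranks. The shapes below are
# made-up assumptions, purely for illustration.
#
#     x = torch.randn(4, 8)
#     chunks = split_tensor_along_last_dim(x, num_partitions=4,
#                                          contiguous_split_chunks=True)
#     # len(chunks) == 4 and each chunk has shape (4, 2); with
#     # contiguous_split_chunks=True every chunk is copied into its own
#     # contiguous storage instead of remaining a strided view of x.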


# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
    # Functional check: copy a tensor from device idx_a to idx_b and back,
    # then verify it round-trips unchanged. This catches platforms where the
    # driver reports P2P support but transfers silently corrupt data.
    dev_i = f"cuda:{idx_a}"
    dev_j = f"cuda:{idx_b}"
    a = torch.randn(5, device=dev_i) + 123.0
    b = a.to(dev_j)
    c = b.to(dev_i)
    return torch.all(a == c).cpu().item()


# Why do we need this cache?
# 1. We can run the P2P check at runtime, with every process checking P2P
#    access to all other GPUs. Unfortunately, that can create up to
#    world_size * world_size CUDA contexts and reduce the memory available
#    for the model. See https://github.com/vllm-project/vllm/issues/3821
# 2. Alternatively, the master process can generate a P2P map and broadcast
#    it to all other processes. This still requires world_size CUDA contexts,
#    belonging to the master process, one on each GPU.
# 3. We can keep a cache file that records the P2P access status. The first
#    time the master process checks P2P access, it generates the cache file
#    at the cost of world_size CUDA contexts. Later on, all processes can
#    read the cache file to check P2P access status without creating any
#    additional CUDA context.
# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that
# different CUDA_VISIBLE_DEVICES settings (e.g. used by different vllm
# engines) get different cache files. The device ids in the cache file are
# **local** device ids, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
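
# Illustrative sketch (assumed values, not real generated output): with two
# visible devices, the JSON cache file written below would look roughly like
#
#     {
#         "0->0": true,
#         "0->1": true,
#         "1->0": true,
#         "1->1": true
#     }
#
# i.e. one boolean per ordered (source, destination) local-device pair.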


def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j."""

    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{i}->{j}"]

    is_distributed = dist.is_initialized()

    num_dev = torch.cuda.device_count()
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
    path = os.path.expanduser(
        f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if (not is_distributed or get_local_rank() == 0) \
            and (not os.path.exists(path)):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache = {}
        for _i in range(num_dev):
            for _j in range(num_dev):
                # on some platforms, P2P support might be buggy and we need
                # additional checks. See also:
                # https://github.com/vllm-project/vllm/issues/2728
                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
                    _i, _j) and _can_actually_p2p(_i, _j)
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        cpu_world_group = get_cpu_world_group()
        dist.barrier(cpu_world_group)
    logger.info("reading GPU P2P access cache from %s", path)
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{i}->{j}"]
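
# Illustrative sketch (not part of the upstream file): a typical caller gates
# a direct peer-to-peer fast path on this check and falls back otherwise.
# `rank` and `peer_rank` are made-up placeholder variables.
#
#     if gpu_p2p_access_check(rank, peer_rank):
#         pass  # take the direct GPU-to-GPU path (e.g. a custom all-reduce)
#     else:
#         pass  # fall back to communication that does not rely on P2P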