mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 01:45:01 +08:00
Update num_tokens_across_dp to use nccl instead of gloo (#24105)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
a0b26701c9
commit
49bfc538e4
@ -92,6 +92,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||
VLLM_SKIP_P2P_CHECK: bool = False
|
||||
VLLM_DISABLED_KERNELS: list[str] = []
|
||||
VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
|
||||
VLLM_USE_V1: bool = True
|
||||
VLLM_ROCM_USE_AITER: bool = False
|
||||
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
|
||||
@ -745,6 +746,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
|
||||
"VLLM_DISABLED_KERNELS"].split(","),
|
||||
|
||||
# Swaps the all reduce backend that we use to coordinate the DP padding
|
||||
# information from NCCL to gloo.
|
||||
"VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION":
|
||||
lambda:
|
||||
(os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# If set, use the V1 code path.
|
||||
"VLLM_USE_V1":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
|
||||
|
||||
@ -13,6 +13,7 @@ import torch.distributed as dist
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@ -75,14 +76,26 @@ class DPMetadata:
|
||||
Gather the num_tokens across all DP ranks and return results in a
|
||||
CPU tensor of size dp_size.
|
||||
"""
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
device = current_platform.device_type
|
||||
group = get_dp_group().device_group
|
||||
|
||||
# Transfering this tensor from GPU to CPU will introduce a GPU sync
|
||||
# point that could adversely affect performance of vllm with asynch
|
||||
# scheduling. This environment variable exists to quickly disable
|
||||
# this optimization if we run into this case.
|
||||
if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
|
||||
logger.info_once(
|
||||
"Using CPU all reduce to syncronize DP padding between ranks.")
|
||||
device = "cpu"
|
||||
group = get_dp_group().cpu_group
|
||||
num_tokens_across_dp = [0] * dp_size
|
||||
num_tokens_across_dp[dp_rank] = num_tokens
|
||||
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
|
||||
device="cpu",
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
||||
return num_tokens_tensor
|
||||
dist.all_reduce(num_tokens_tensor, group=group)
|
||||
return num_tokens_tensor.cpu()
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user