[Misc] Remove use of CUDA_VISIBLE_DEVICES for device selection (fix DP slow startup time &c) (#26709)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>

parent: faee3ccdc2
commit: 237cf6d32a
@@ -991,11 +991,14 @@ class NixlConnectorWorker:
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
+        self.device_id = self.tp_rank
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
 
             for cache in cache_list:
                 base_addr = cache.data_ptr()
+                if not self.use_host_buffer and current_platform.is_cuda_alike():
+                    self.device_id = cache.device.index
                 if base_addr in seen_base_addresses:
                     continue
 
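Note: the device id picked up here is the CUDA index of the KV-cache tensor itself, not the TP rank. A minimal sketch of that distinction, with hypothetical tensors and ranks (illustrative only, not vLLM code):

    import torch

    tp_rank = 0  # hypothetical TP rank of this worker
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        # Without a per-process CUDA_VISIBLE_DEVICES mask, the tensor's device
        # index is the real, global device id that must be registered.
        cache = torch.empty(16, device="cuda:1")  # hypothetical KV-cache buffer
        device_id = cache.device.index            # -> 1, even though tp_rank == 0
    else:
        device_id = tp_rank  # fallback, mirroring the host-buffer path above
    print(device_id)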
@@ -1023,7 +1026,7 @@ class NixlConnectorWorker:
                     "All kv cache tensors must have the same size"
                 )
                 caches_data.append(
-                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
+                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
                 )
 
         logger.debug(
@@ -1070,7 +1073,7 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.device_id))
 
             if self._use_flashinfer:
                 # Separate and interleave K/V regions to maintain the same
@@ -1081,12 +1084,13 @@ class NixlConnectorWorker:
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
+                    blocks_data.append((v_addr, kv_block_len, self.device_id))
         logger.debug(
-            "Created %s blocks for src engine %s and rank %s",
+            "Created %s blocks for src engine %s and rank %s on device id %s",
             len(blocks_data),
             self.engine_id,
             self.tp_rank,
+            self.device_id,
         )
 
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
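Note: a minimal sketch of the descriptor layout built above, with hypothetical sizes and addresses (not vLLM code). Each entry is (address, length, device id), so after this change the transfer descriptors carry the actual CUDA device index rather than the TP rank:

    base_addr = 0x7F0000000000   # hypothetical data_ptr() of one KV-cache tensor
    block_len = 2 * 1024 * 1024  # hypothetical bytes per block for this layer
    kv_block_len = block_len     # no K/V split in this sketch
    device_id = 3                # CUDA index of the tensor's device

    blocks_data = [
        (base_addr + block_id * block_len, kv_block_len, device_id)
        for block_id in range(4)  # four blocks, for illustration only
    ]
    print(blocks_data[0])         # first descriptor: (base_addr, 2097152, 3)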
@@ -134,9 +134,12 @@ class CoreEngineProcManager:
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
+                # Adjust device control in DP for non-CUDA platforms
+                # For CUDA platforms, setting same device id for different DP
+                # processes affects NCCL init performance.
                 with (
                     set_device_control_env_var(vllm_config, local_dp_rank)
-                    if (data_parallel)
+                    if (data_parallel and not current_platform.is_cuda_alike())
                     else contextlib.nullcontext()
                 ):
                     proc.start()
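Note: a minimal sketch of the conditional env-var pattern used above. The helper below is simplified (it takes a variable name and value, unlike vLLM's actual set_device_control_env_var), and "ACCEL_VISIBLE_DEVICES" is a made-up placeholder for a non-CUDA platform's device-control variable; on CUDA-like platforms the child process is started with no mask so every DP rank can see all GPUs:

    import contextlib
    import os

    @contextlib.contextmanager
    def set_device_control_env_var(name: str, value: str):
        """Temporarily set an env var so a spawned child process inherits it."""
        old = os.environ.get(name)
        os.environ[name] = value
        try:
            yield
        finally:
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old

    def start_engine_proc(proc, local_dp_rank, data_parallel, is_cuda_alike):
        with (
            # Placeholder variable name; only applied on non-CUDA-like platforms.
            set_device_control_env_var("ACCEL_VISIBLE_DEVICES", str(local_dp_rank))
            if (data_parallel and not is_cuda_alike)
            else contextlib.nullcontext()
        ):
            proc.start()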
@@ -8,7 +8,6 @@ import torch.distributed as dist
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
@@ -20,7 +19,8 @@ logger = init_logger(__name__)
 
 
 def _get_device_and_group(parallel_config: ParallelConfig):
-    device = current_platform.device_type
+    # Use the actual device assigned to the DP group, not just the device type
+    device = get_dp_group().device
     group = get_dp_group().device_group
 
     # Transfering this tensor from GPU to CPU will introduce a GPU sync
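Note: the point of returning get_dp_group().device rather than a bare device type is that the tensor used for the DP collective should be allocated on the device actually assigned to this DP rank. A minimal sketch under assumed names (the helper below is illustrative, not vLLM's API):

    import torch
    import torch.distributed as dist

    def dp_all_reduce_flag(device: torch.device, group) -> int:
        # Allocate on the DP group's own device (e.g. cuda:3), not merely on
        # "whatever the current cuda device is", now that CUDA_VISIBLE_DEVICES
        # no longer pins each process to a single visible GPU.
        flag = torch.ones(1, dtype=torch.int32, device=device)
        dist.all_reduce(flag, group=group)
        return int(flag.item())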
@@ -169,6 +169,27 @@ class Worker(WorkerBase):
         if self.device_config.device.type == "cuda":
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            if (
+                self.parallel_config.data_parallel_size > 1
+                and self.parallel_config.data_parallel_size_local > 0
+                and self.parallel_config.data_parallel_backend != "ray"
+            ):
+                # Use local DP rank if available, otherwise use global DP rank.
+                dp_local_rank = self.parallel_config.data_parallel_rank_local
+                if dp_local_rank is None:
+                    dp_local_rank = self.parallel_config.data_parallel_rank
+
+                tp_pp_world_size = (
+                    self.parallel_config.pipeline_parallel_size
+                    * self.parallel_config.tensor_parallel_size
+                )
+
+                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+                self.local_rank += dp_local_rank * tp_pp_world_size
+                assert self.local_rank <= torch.cuda.device_count(), (
+                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
+                )
+
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
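Note: a worked example of the DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK offset above, with hypothetical sizes (TP=2, PP=1, two local DP ranks on a 4-GPU node). Each worker now selects its global CUDA index directly instead of relying on a CUDA_VISIBLE_DEVICES remap:

    tensor_parallel_size = 2
    pipeline_parallel_size = 1
    tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size  # 2

    for dp_local_rank in range(2):              # two DP engines on this node
        for tp_local_rank in range(tp_pp_world_size):
            local_rank = tp_local_rank + dp_local_rank * tp_pp_world_size
            print(f"DP {dp_local_rank}, TP {tp_local_rank} -> cuda:{local_rank}")
    # DP 0 uses cuda:0 and cuda:1; DP 1 uses cuda:2 and cuda:3.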