[Misc] Replace CUDA_VISIBLE_DEVICES in DP with torch.cuda.set_device for device selection on cuda-like devices (#27564)
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
parent e5e076cad7
commit 60f76baa66
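Before the hunks, a minimal sketch of the device-selection change the title describes, assuming one GPU per DP rank; the helper names below are illustrative, not vLLM APIs:

import os

import torch


def pick_device_with_env_var(local_dp_rank: int) -> torch.device:
    # Old-style approach (illustrative): set the variable before the child
    # process initializes CUDA so it only sees its own GPU, which then always
    # shows up inside the process as cuda:0.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(local_dp_rank)
    return torch.device("cuda:0")


def pick_device_with_set_device(adjusted_local_rank: int) -> torch.device:
    # New-style approach: leave all GPUs visible and bind the process to the
    # physical index it computed (see the Worker hunk at the end of this diff).
    device = torch.device(f"cuda:{adjusted_local_rank}")
    torch.cuda.set_device(device)
    return device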
@@ -1008,11 +1008,14 @@ class NixlConnectorWorker:
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
+        self.device_id = self.tp_rank
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]

             for cache in cache_list:
                 base_addr = cache.data_ptr()
+                if not self.use_host_buffer and current_platform.is_cuda_alike():
+                    self.device_id = cache.device.index
                 if base_addr in seen_base_addresses:
                     continue

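With CUDA_VISIBLE_DEVICES no longer narrowed per DP rank, a worker's KV-cache tensors can sit on a GPU whose physical index differs from tp_rank, so the connector now records the index straight from the tensor. A small illustration of what cache.device.index returns, guarded so it only runs on a multi-GPU host:

import torch

# Illustration only: with several GPUs visible, a DP-adjusted worker might
# allocate its KV cache on cuda:1, so device.index is 1, not the TP rank.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    cache = torch.empty(16, device="cuda:1")
    assert cache.device.index == 1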
@@ -1040,7 +1043,7 @@ class NixlConnectorWorker:
                     "All kv cache tensors must have the same size"
                 )
                 caches_data.append(
-                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
+                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
                 )

         logger.debug(
@@ -1087,7 +1090,7 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.device_id))

         if self._use_flashinfer:
             # Separate and interleave K/V regions to maintain the same
@@ -1098,12 +1101,13 @@ class NixlConnectorWorker:
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
+                    blocks_data.append((v_addr, kv_block_len, self.device_id))
         logger.debug(
-            "Created %s blocks for src engine %s and rank %s",
+            "Created %s blocks for src engine %s and rank %s on device id %s",
             len(blocks_data),
             self.engine_id,
             self.tp_rank,
+            self.device_id,
         )

         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
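All three registration paths now tag regions with self.device_id rather than self.tp_rank, so the (addr, len, device id) descriptors point at the GPU that actually holds the memory. A condensed sketch of that tuple construction, mirroring the loop structure in the hunks above (standalone helper, not the connector's real method):

def build_block_descs(
    base_addr: int, num_blocks: int, block_len: int, device_id: int
) -> list[tuple[int, int, int]]:
    # Each entry is (address, length, device id); device_id comes from
    # cache.device.index on CUDA-alike platforms and stays the TP rank otherwise.
    return [
        (base_addr + block_id * block_len, block_len, device_id)
        for block_id in range(num_blocks)
    ]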
@@ -134,9 +134,18 @@ class CoreEngineProcManager:
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
+                # Adjust device control in DP for non-CUDA platforms
+                # as well as external and ray launchers
+                # For CUDA platforms, we use torch.cuda.set_device()
                 with (
                     set_device_control_env_var(vllm_config, local_dp_rank)
-                    if (data_parallel)
+                    if (
+                        data_parallel
+                        and (
+                            not current_platform.is_cuda_alike()
+                            or vllm_config.parallel_config.use_ray
+                        )
+                    )
                     else contextlib.nullcontext()
                 ):
                     proc.start()
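After this hunk, the env-var based device control is only applied where it is still needed (non-CUDA-alike platforms, or Ray), while CUDA-alike platforms rely on torch.cuda.set_device inside the worker. A minimal sketch of the conditional context-manager pattern, with a hypothetical stand-in for set_device_control_env_var and a placeholder variable name:

import contextlib
import multiprocessing
import os
from collections.abc import Iterator


@contextlib.contextmanager
def device_control_env_var(var: str, value: str) -> Iterator[None]:
    # Hypothetical stand-in for vLLM's set_device_control_env_var: temporarily
    # set a device-control variable around proc.start(), then restore it.
    old = os.environ.get(var)
    os.environ[var] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(var, None)
        else:
            os.environ[var] = old


def start_proc(proc: multiprocessing.Process, needs_env_control: bool, value: str) -> None:
    # Same shape as the hunk above: use the env-var context manager only when
    # the platform still needs it; otherwise a no-op nullcontext.
    with (
        device_control_env_var("VISIBLE_DEVICES_PLACEHOLDER", value)
        if needs_env_control
        else contextlib.nullcontext()
    ):
        proc.start()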
@@ -8,7 +8,6 @@ import torch.distributed as dist
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
@@ -20,7 +19,8 @@ logger = init_logger(__name__)


 def _get_device_and_group(parallel_config: ParallelConfig):
-    device = current_platform.device_type
+    # Use the actual device assigned to the DP group, not just the device type
+    device = get_dp_group().device
     group = get_dp_group().device_group

     # Transfering this tensor from GPU to CPU will introduce a GPU sync
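Dropping current_platform here and taking the device from get_dp_group() matters once every process sees all GPUs: the DP coordination tensor must be allocated on the specific cuda:N the group was initialized with, not merely on "the CUDA device type". A guarded illustration, using a stand-in value for get_dp_group().device:

import torch

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    # Stand-in for get_dp_group().device: suppose the DP group was built on cuda:1.
    dp_group_device = torch.device("cuda:1")

    # Old behaviour: only the device *type* was used, so the tensor landed on
    # the process's current default CUDA device, whichever that happened to be.
    t_old = torch.zeros(4, device="cuda")

    # New behaviour: allocate on the exact device the DP group owns.
    t_new = torch.zeros(4, device=dp_group_device)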
@@ -172,6 +172,29 @@ class Worker(WorkerBase):
         if self.device_config.device.type == "cuda":
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            if (
+                self.parallel_config.data_parallel_size > 1
+                and self.parallel_config.data_parallel_size_local > 0
+                and self.parallel_config.distributed_executor_backend
+                not in ["ray", "external_launcher"]
+                and self.vllm_config.parallel_config.data_parallel_backend != "ray"
+            ):
+                # Use local DP rank if available, otherwise use global DP rank.
+                dp_local_rank = self.parallel_config.data_parallel_rank_local
+                if dp_local_rank is None:
+                    dp_local_rank = self.parallel_config.data_parallel_rank
+
+                tp_pp_world_size = (
+                    self.parallel_config.pipeline_parallel_size
+                    * self.parallel_config.tensor_parallel_size
+                )
+
+                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+                self.local_rank += dp_local_rank * tp_pp_world_size
+                assert self.local_rank < torch.cuda.device_count(), (
+                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
+                )
+
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)

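The local-rank adjustment above implements DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK: each local DP replica owns a contiguous block of TP*PP GPUs, and the TP-local rank indexes into that block. A small worked sketch of the mapping (the helper name is mine, not vLLM's):

def dp_adjusted_local_rank(
    tp_local_rank: int,
    dp_local_rank: int,
    tensor_parallel_size: int,
    pipeline_parallel_size: int,
) -> int:
    # Mirrors the formula in the hunk above.
    tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size
    return dp_local_rank * tp_pp_world_size + tp_local_rank


# With TP=2, PP=1: DP rank 0 uses cuda:0 and cuda:1, DP rank 1 uses cuda:2 and cuda:3.
assert dp_adjusted_local_rank(0, 1, 2, 1) == 2
assert dp_adjusted_local_rank(1, 1, 2, 1) == 3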