[Bugfix] Allow CUDA_VISIBLE_DEVICES='' in Platform.device_id_to_physical_device_id (#18979)

Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
Seiji Eicher 2025-06-26 00:01:57 -07:00 committed by GitHub
parent 9502c38138
commit 65397e40f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 114 additions and 10 deletions

View File

@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.quantization.quark.utils import deep_compare
def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
"""Test that configs created with normal (untouched) CUDA_VISIBLE_DEVICES
and CUDA_VISIBLE_DEVICES="" are equivalent. This ensures consistent
behavior regardless of whether GPU visibility is disabled via empty string
or left in its normal state.
"""
def create_config():
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
trust_remote_code=True)
return engine_args.create_engine_config()
# Create config with CUDA_VISIBLE_DEVICES set normally
normal_config = create_config()
# Create config with CUDA_VISIBLE_DEVICES=""
with monkeypatch.context() as m:
m.setenv("CUDA_VISIBLE_DEVICES", "")
empty_config = create_config()
normal_config_dict = vars(normal_config)
empty_config_dict = vars(empty_config)
# Remove instance_id before comparison as it's expected to be different
normal_config_dict.pop("instance_id", None)
empty_config_dict.pop("instance_id", None)
assert deep_compare(normal_config_dict, empty_config_dict), (
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
" should be equivalent")

View File

@ -8,8 +8,10 @@ import time
import uuid
from threading import Thread
from typing import Optional
from unittest.mock import MagicMock
import pytest
import torch
from transformers import AutoTokenizer
from tests.utils import multi_gpu_test
@ -517,3 +519,72 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
)
assert "Engine core initialization failed" in str(e_info.value)
@create_new_process_for_each_test()
def test_engine_core_proc_instantiation_cuda_empty(
monkeypatch: pytest.MonkeyPatch):
"""
Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
is empty. This ensures the engine frontend does not need access to GPUs.
"""
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.executor.abstract import Executor
# Create a simple mock executor instead of a complex custom class
mock_executor_class = MagicMock(spec=Executor)
def create_mock_executor(vllm_config):
mock_executor = MagicMock()
# Only implement the methods that are actually called during init
from vllm.v1.kv_cache_interface import FullAttentionSpec
mock_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
mock_executor.get_kv_cache_specs.return_value = [{
"default": mock_spec
}]
mock_executor.determine_available_memory.return_value = [
1024 * 1024 * 1024
]
mock_executor.initialize_from_config.return_value = None
mock_executor.max_concurrent_batches = 1
return mock_executor
mock_executor_class.side_effect = create_mock_executor
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
from vllm.v1.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, on_head_node,
parallel_config):
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
outputs=["tcp://127.0.0.1:5556"],
coordinator_input=None,
coordinator_output=None)
# Background processes are not important here
m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
vllm_config = EngineArgs(
model="deepseek-ai/DeepSeek-V2-Lite",
trust_remote_code=True).create_engine_config()
engine_core_proc = EngineCoreProc(
vllm_config=vllm_config,
on_head_node=True,
handshake_address="tcp://127.0.0.1:12345",
executor_class=mock_executor_class,
log_stats=False,
engine_index=0,
)
engine_core_proc.shutdown()

View File

@ -173,17 +173,12 @@ class Platform:
@classmethod
def device_id_to_physical_device_id(cls, device_id: int):
if cls.device_control_env_var in os.environ:
# Treat empty device control env var as unset. This is a valid
# configuration in Ray setups where the engine is launched in
# a CPU-only placement group located on a GPU node.
if cls.device_control_env_var in os.environ and os.environ[
cls.device_control_env_var] != "":
device_ids = os.environ[cls.device_control_env_var].split(",")
if device_ids == [""]:
msg = (f"{cls.device_control_env_var} is set to empty string, "
"which means current platform support is disabled. If "
"you are using ray, please unset the environment "
f"variable `{cls.device_control_env_var}` inside the "
"worker/actor. Check "
"https://github.com/vllm-project/vllm/issues/8402 for "
"more information.")
raise RuntimeError(msg)
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else: