mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 13:15:48 +08:00
[Bugfix] Allow CUDA_VISIBLE_DEVICES='' in Platform.device_id_to_physical_device_id (#18979)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
parent
9502c38138
commit
65397e40f5
38
tests/config/test_config_generation.py
Normal file
38
tests/config/test_config_generation.py
Normal file
@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.quantization.quark.utils import deep_compare
|
||||
|
||||
|
||||
def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that configs created with normal (untouched) CUDA_VISIBLE_DEVICES
|
||||
and CUDA_VISIBLE_DEVICES="" are equivalent. This ensures consistent
|
||||
behavior regardless of whether GPU visibility is disabled via empty string
|
||||
or left in its normal state.
|
||||
"""
|
||||
|
||||
def create_config():
|
||||
engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
|
||||
trust_remote_code=True)
|
||||
return engine_args.create_engine_config()
|
||||
|
||||
# Create config with CUDA_VISIBLE_DEVICES set normally
|
||||
normal_config = create_config()
|
||||
|
||||
# Create config with CUDA_VISIBLE_DEVICES=""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
empty_config = create_config()
|
||||
|
||||
normal_config_dict = vars(normal_config)
|
||||
empty_config_dict = vars(empty_config)
|
||||
|
||||
# Remove instance_id before comparison as it's expected to be different
|
||||
normal_config_dict.pop("instance_id", None)
|
||||
empty_config_dict.pop("instance_id", None)
|
||||
|
||||
assert deep_compare(normal_config_dict, empty_config_dict), (
|
||||
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
|
||||
" should be equivalent")
|
||||
@ -8,8 +8,10 @@ import time
|
||||
import uuid
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
@ -517,3 +519,72 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
|
||||
)
|
||||
|
||||
assert "Engine core initialization failed" in str(e_info.value)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_engine_core_proc_instantiation_cuda_empty(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
|
||||
is empty. This ensures the engine frontend does not need access to GPUs.
|
||||
"""
|
||||
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
# Create a simple mock executor instead of a complex custom class
|
||||
mock_executor_class = MagicMock(spec=Executor)
|
||||
|
||||
def create_mock_executor(vllm_config):
|
||||
mock_executor = MagicMock()
|
||||
|
||||
# Only implement the methods that are actually called during init
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
mock_spec = FullAttentionSpec(block_size=16,
|
||||
num_kv_heads=1,
|
||||
head_size=64,
|
||||
dtype=torch.float16,
|
||||
use_mla=False)
|
||||
|
||||
mock_executor.get_kv_cache_specs.return_value = [{
|
||||
"default": mock_spec
|
||||
}]
|
||||
mock_executor.determine_available_memory.return_value = [
|
||||
1024 * 1024 * 1024
|
||||
]
|
||||
mock_executor.initialize_from_config.return_value = None
|
||||
mock_executor.max_concurrent_batches = 1
|
||||
|
||||
return mock_executor
|
||||
|
||||
mock_executor_class.side_effect = create_mock_executor
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
|
||||
|
||||
from vllm.v1.utils import EngineZmqAddresses
|
||||
|
||||
def mock_startup_handshake(self, handshake_socket, on_head_node,
|
||||
parallel_config):
|
||||
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
|
||||
outputs=["tcp://127.0.0.1:5556"],
|
||||
coordinator_input=None,
|
||||
coordinator_output=None)
|
||||
|
||||
# Background processes are not important here
|
||||
m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
|
||||
|
||||
vllm_config = EngineArgs(
|
||||
model="deepseek-ai/DeepSeek-V2-Lite",
|
||||
trust_remote_code=True).create_engine_config()
|
||||
engine_core_proc = EngineCoreProc(
|
||||
vllm_config=vllm_config,
|
||||
on_head_node=True,
|
||||
handshake_address="tcp://127.0.0.1:12345",
|
||||
executor_class=mock_executor_class,
|
||||
log_stats=False,
|
||||
engine_index=0,
|
||||
)
|
||||
|
||||
engine_core_proc.shutdown()
|
||||
|
||||
@ -173,17 +173,12 @@ class Platform:
|
||||
|
||||
@classmethod
|
||||
def device_id_to_physical_device_id(cls, device_id: int):
|
||||
if cls.device_control_env_var in os.environ:
|
||||
# Treat empty device control env var as unset. This is a valid
|
||||
# configuration in Ray setups where the engine is launched in
|
||||
# a CPU-only placement group located on a GPU node.
|
||||
if cls.device_control_env_var in os.environ and os.environ[
|
||||
cls.device_control_env_var] != "":
|
||||
device_ids = os.environ[cls.device_control_env_var].split(",")
|
||||
if device_ids == [""]:
|
||||
msg = (f"{cls.device_control_env_var} is set to empty string, "
|
||||
"which means current platform support is disabled. If "
|
||||
"you are using ray, please unset the environment "
|
||||
f"variable `{cls.device_control_env_var}` inside the "
|
||||
"worker/actor. Check "
|
||||
"https://github.com/vllm-project/vllm/issues/8402 for "
|
||||
"more information.")
|
||||
raise RuntimeError(msg)
|
||||
physical_device_id = device_ids[device_id]
|
||||
return int(physical_device_id)
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user