[ROCm][AMD] unify CUDA_VISIBLE_DEVICES usage in cuda/rocm (#6352)

commit b6c16cf8ff
parent d26a8b3f1f
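
This change advances the ROCm PyTorch nightlies in Dockerfile.rocm (torch 2.4.0.dev20240612 to 2.5.0.dev20240710, torchvision 0.19 to 0.20) and removes the is_hip()-guarded code that mirrored CUDA_VISIBLE_DEVICES into HIP_VISIBLE_DEVICES in vllm/utils.py, vllm/config.py, vllm/worker/worker_base.py, and the corresponding test. Together with the nightly bump, ROCm no longer needs HIP_VISIBLE_DEVICES special-casing: device selection goes through CUDA_VISIBLE_DEVICES on both backends.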
Dockerfile.rocm
@@ -52,25 +52,25 @@ RUN pip install --upgrade pip

 # Remove sccache so it doesn't interfere with ccache
 # TODO: implement sccache support across components
 RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"

-# Install torch == 2.4.0 on ROCm
+# Install torch == 2.5.0 on ROCm
 RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
     *"rocm-5.7"*) \
         pip uninstall -y torch torchaudio torchvision \
         && pip install --no-cache-dir --pre \
-            torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
-            torchvision==0.19.0.dev20240612 \
+            torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
+            torchvision==0.20.0.dev20240710 \
             --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
     *"rocm-6.0"*) \
         pip uninstall -y torch torchaudio torchvision \
         && pip install --no-cache-dir --pre \
-            torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
-            torchvision==0.19.0.dev20240612 \
+            torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
+            torchvision==0.20.0.dev20240710 \
             --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
     *"rocm-6.1"*) \
         pip uninstall -y torch torchaudio torchvision \
         && pip install --no-cache-dir --pre \
-            torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
-            torchvision==0.19.0.dev20240612 \
+            torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
+            torchvision==0.20.0.dev20240710 \
             --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
     *) ;; esac
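
The hunk above only advances the pinned nightlies (torch 2.4.0.dev20240612 to 2.5.0.dev20240710, torchvision 0.19 to 0.20, torchaudio re-pinned to the matching date). A quick interpreter check, sketched below, confirms that a built image actually picked up a ROCm build at the expected version; the printed values are illustrative, not guaranteed.

# Sanity-check sketch: verify the installed torch is the expected ROCm
# nightly. torch.version.hip is non-None only on ROCm builds, and ROCm
# exposes GPUs through the CUDA API surface, so cuda.is_available() holds.
import torch

print(torch.__version__)          # expected to start with "2.5.0.dev20240710"
print(torch.version.hip)          # e.g. a "6.1..." string on ROCm, None on CUDA
print(torch.cuda.is_available())  # True when ROCm GPUs are visible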
tests/distributed/test_utils.py
@@ -1,7 +1,7 @@
 import ray

 import vllm.envs as envs
-from vllm.utils import (cuda_device_count_stateless, is_hip,
+from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)

@@ -22,11 +22,6 @@ class _CUDADeviceCountStatelessTestActor:
 def test_cuda_device_count_stateless():
     """Test that cuda_device_count_stateless changes return value if
     CUDA_VISIBLE_DEVICES is changed."""
-    if is_hip():
-        # Set HIP_VISIBLE_DEVICES == CUDA_VISIBLE_DEVICES. Conversion
-        # is handled by `update_environment_variables`
-        update_environment_variables(
-            {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
     actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
         num_gpus=2).remote()
     assert sorted(ray.get(
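
The deleted is_hip() branch is the test-side half of the unification: the test no longer mirrors CUDA_VISIBLE_DEVICES into HIP_VISIBLE_DEVICES by hand before spawning the Ray actor. The invariant the test relies on can be sketched without Ray; this sketch assumes a host with at least two visible GPUs.

# Minimal sketch of the property under test: unlike
# torch.cuda.device_count(), which caches its result after CUDA/HIP
# initialization, cuda_device_count_stateless() re-reads
# CUDA_VISIBLE_DEVICES on every call.
import os

from vllm.utils import cuda_device_count_stateless

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
assert cuda_device_count_stateless() == 2

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
assert cuda_device_count_stateless() == 1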
vllm/config.py
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
 import torch
 from transformers import PretrainedConfig

-import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
@@ -14,7 +13,7 @@ from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
                         is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
-                        print_warning_once, update_environment_variables)
+                        print_warning_once)

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -695,12 +694,6 @@ class ParallelConfig:
                 self.distributed_executor_backend = backend
                 logger.info("Defaulting to use %s for distributed inference",
                             backend)
-        # If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init,
-        # propagate changes to HIP_VISIBLE_DEVICES (conversion handled by
-        # the update_environment_variables function)
-        if is_hip() and envs.CUDA_VISIBLE_DEVICES:
-            update_environment_variables(
-                {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})

         self._verify_args()
         self.rank = 0
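
With the propagation block gone from ParallelConfig, GPU selection on ROCm is expressed exactly as on CUDA. A usage sketch follows; the model name and GPU count are placeholders, and a ROCm host with two visible GPUs is assumed.

# After this change, ROCm and CUDA users select devices identically:
# set CUDA_VISIBLE_DEVICES before initializing vLLM; no HIP_VISIBLE_DEVICES
# mirroring is required on ROCm.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

from vllm import LLM

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)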
vllm/utils.py
@@ -386,10 +386,6 @@ def get_open_port() -> int:


 def update_environment_variables(envs: Dict[str, str]):
-    if is_hip() and "CUDA_VISIBLE_DEVICES" in envs:
-        # Propagate changes to CUDA_VISIBLE_DEVICES to
-        # ROCm's HIP_VISIBLE_DEVICES as well
-        envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"]
     for k, v in envs.items():
         if k in os.environ and os.environ[k] != v:
             logger.warning(
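
After this hunk, update_environment_variables is a plain overwrite-with-warning loop with no backend branch. Below is a self-contained sketch of its post-change shape; the warning text is an approximation of vLLM's message, not a verbatim quote.

import logging
import os
from typing import Dict

logger = logging.getLogger(__name__)


def update_environment_variables(envs: Dict[str, str]):
    # Post-change shape (sketch): no is_hip()/HIP_VISIBLE_DEVICES branch.
    # Every entry is written to os.environ, with a warning whenever an
    # existing, different value is overwritten.
    for k, v in envs.items():
        if k in os.environ and os.environ[k] != v:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                k, os.environ[k], v)
        os.environ[k] = v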
vllm/worker/worker_base.py
@@ -11,7 +11,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SamplerOutput)
-from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
+from vllm.utils import (enable_trace_function_call_for_thread,
                         update_environment_variables)
 from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

@@ -309,14 +309,6 @@ class WorkerWrapperBase:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
             # suppress the warning in `update_environment_variables`
             del os.environ[key]
-        if is_hip():
-            hip_env_var = "HIP_VISIBLE_DEVICES"
-            if hip_env_var in os.environ:
-                logger.warning(
-                    "Ignoring pre-set environment variable `%s=%s` as "
-                    "%s has also been set, which takes precedence.",
-                    hip_env_var, os.environ[hip_env_var], key)
-            os.environ.pop(hip_env_var, None)
         update_environment_variables(envs)

     def init_worker(self, *args, **kwargs):
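
The surviving context lines show everything WorkerWrapperBase still does before applying worker env vars: delete a stale CUDA_VISIBLE_DEVICES so update_environment_variables does not warn, then write the new values. A simplified standalone sketch follows, with the logic pulled out of its class; the function name apply_worker_envs is illustrative, not vLLM's.

import os
from typing import Dict

from vllm.utils import update_environment_variables


def apply_worker_envs(envs: Dict[str, str]) -> None:
    # Sketch of the surviving logic around the context lines above.
    # Overwriting CUDA_VISIBLE_DEVICES is desired behavior, so any pre-set
    # value is dropped first to suppress the warning in
    # update_environment_variables.
    key = "CUDA_VISIBLE_DEVICES"
    if key in envs and key in os.environ:
        del os.environ[key]
    update_environment_variables(envs)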