[Core] Replace the narrow-usage RayWorkerVllm with a general WorkerWrapper to reduce code duplication (#4024)
This commit is contained in:
parent 11d652bd4f
commit 8438e0569e
@@ -1,18 +1,18 @@
 import multiprocessing
-import os
 
 import pytest
 import torch
 
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                           ncclGetUniqueId)
+from vllm.utils import update_environment_variables
 
 
 def distributed_run(fn, world_size):
     number_of_processes = world_size
     processes = []
     for i in range(number_of_processes):
-        env = os.environ.copy()
+        env = {}
         env['RANK'] = str(i)
         env['LOCAL_RANK'] = str(i)
         env['WORLD_SIZE'] = str(number_of_processes)
@@ -32,8 +32,7 @@ def update_env(fn):
     # so we need to pass the environment variables as arguments
     # and update the environment variables in the function
     def wrapper(env):
-        import os
-        os.environ.update(env)
+        update_environment_variables(env)
        fn()
 
     return wrapper
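The comment in the second hunk above explains the reason for this test change: `multiprocessing.Process` cannot accept an environment mapping directly, so the environment is passed as an ordinary dict argument and applied inside the child process. A minimal, self-contained sketch of that pattern follows; the function names `my_env_worker` and `run_all` are illustrative and not taken from the diff.

import multiprocessing
import os


def my_env_worker(env):
    # Apply the per-process environment before doing any real work.
    os.environ.update(env)
    print(os.environ['RANK'], 'of', os.environ['WORLD_SIZE'])


def run_all(world_size):
    processes = []
    for rank in range(world_size):
        env = {'RANK': str(rank), 'WORLD_SIZE': str(world_size)}
        p = multiprocessing.Process(target=my_env_worker, args=(env, ))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == '__main__':
    run_all(2)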
@@ -1,55 +1,28 @@
 import pickle
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
-from vllm.worker.worker import Worker
+from vllm.utils import get_ip, is_hip
+from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
 
 try:
     import ray
 
-    class RayWorkerVllm:
+    class RayWorkerWrapper(WorkerWrapperBase):
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
-        def __init__(self, init_cached_hf_modules=False) -> None:
-            if init_cached_hf_modules:
-                from transformers.dynamic_module_utils import init_hf_modules
-                init_hf_modules()
-            self._worker: Optional[Worker] = None
+        def __init__(self, *args, **kwargs) -> None:
+            super().__init__(*args, **kwargs)
             # Since the compiled DAG runs a main execution
             # in a different thread that calls cuda.set_device.
             # The flag indicates is set_device is called on
             # that thread.
             self.compiled_dag_cuda_device_set = False
 
-        def init_worker(self, worker_init_fn: Callable[[], Worker]):
-            self._worker = worker_init_fn()
-
-        @property
-        def worker(self) -> Worker:
-            assert self._worker is not None
-            return self._worker
-
-        def __getattr__(self, name):
-            return getattr(self.worker, name)
-
-        def execute_method(self, method, *args, **kwargs):
-            try:
-                executor = getattr(self, method)
-                return executor(*args, **kwargs)
-            except Exception as e:
-                # exceptions in ray worker may cause deadlock
-                # see https://github.com/vllm-project/vllm/issues/3455
-                # print the error and inform the user to solve the error
-                msg = (f"Error executing method {method}. "
-                       "This might cause deadlock in distributed execution.")
-                logger.exception(msg)
-                raise e
-
         def get_node_ip(self) -> str:
             return get_ip()
 
@@ -58,9 +31,6 @@ try:
             gpu_ids = ray.get_gpu_ids()
             return node_id, gpu_ids
 
-        def set_cuda_visible_devices(self, device_ids) -> None:
-            set_cuda_visible_devices(device_ids)
-
         def execute_model_compiled_dag_remote(self, ignored):
             """Used only when compiled DAG is enabled."""
             import torch
@@ -77,7 +47,7 @@ except ImportError as e:
                    "For distributed inference, please install Ray with "
                    "`pip install ray`.")
     ray = None  # type: ignore
-    RayWorkerVllm = None  # type: ignore
+    RayWorkerWrapper = None  # type: ignore
 
 
 def initialize_ray_cluster(
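RayWorkerWrapper is now only a thin Ray-side shim over WorkerWrapperBase: the actor is created first, and the actual Worker is imported and constructed later, after Ray has pinned the GPUs for that actor. A hedged sketch of creating and querying such an actor, assuming Ray is installed, `ray.init()` can see a GPU, and a vllm checkout that contains this change (the pattern mirrors the executor hunks further down):

import ray

from vllm.engine.ray_utils import RayWorkerWrapper

ray.init()

# The actor only remembers the worker module/class names here; it does not
# import vllm.worker.worker yet, so CUDA_VISIBLE_DEVICES can still change.
worker = ray.remote(num_gpus=1)(RayWorkerWrapper).remote(
    worker_module_name="vllm.worker.worker",
    worker_class_name="Worker",
)

node_ip = ray.get(worker.get_node_ip.remote())
node_id, gpu_ids = ray.get(worker.get_node_and_gpu_ids.remote())
print(node_ip, node_id, gpu_ids)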
@@ -1,17 +1,16 @@
 import asyncio
-import copy
 import os
 import pickle
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
-from vllm.engine.ray_utils import RayWorkerVllm, ray
+from vllm.engine.ray_utils import RayWorkerWrapper, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
-                        make_async, set_cuda_visible_devices)
+                        make_async)
 
 if ray is not None:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -74,9 +73,9 @@ class RayGPUExecutor(ExecutorBase):
 
         # The driver dummy worker does not actually use any resources.
         # It holds the resource for the driver worker.
-        self.driver_dummy_worker: RayWorkerVllm = None
+        self.driver_dummy_worker: RayWorkerWrapper = None
         # The remaining workers are the actual ray actors.
-        self.workers: List[RayWorkerVllm] = []
+        self.workers: List[RayWorkerWrapper] = []
 
         if self.parallel_config.ray_workers_use_nsight:
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
@@ -97,13 +96,20 @@ class RayGPUExecutor(ExecutorBase):
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
+            )(RayWorkerWrapper).remote(
+                worker_module_name="vllm.worker.worker",
+                worker_class_name="Worker",
+            )
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
                 # If the worker is on the same node as the driver, we use it
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
+                self.driver_worker = RayWorkerWrapper(
+                    worker_module_name="vllm.worker.worker",
+                    worker_class_name="Worker",
+                )
             else:
                 # Else, added to the list of workers.
                 self.workers.append(worker)
@@ -115,82 +121,56 @@ class RayGPUExecutor(ExecutorBase):
                 "GPU node.")
 
         # Get the set of GPU IDs used on each node.
-        driver_node_id, driver_gpu_ids = ray.get(
-            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
-        worker_node_and_gpu_ids = ray.get(
-            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
+                                                    use_dummy_driver=True)
 
         node_workers = defaultdict(list)
         node_gpus = defaultdict(list)
 
-        node_workers[driver_node_id].append(0)
-        node_gpus[driver_node_id].extend(driver_gpu_ids)
-        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
-                                               start=1):
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
             node_workers[node_id].append(i)
             node_gpus[node_id].extend(gpu_ids)
         for node_id, gpu_ids in node_gpus.items():
             node_gpus[node_id] = sorted(gpu_ids)
 
         # Set CUDA_VISIBLE_DEVICES for the driver and workers.
-        set_cuda_visible_devices(node_gpus[driver_node_id])
-        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
-            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+        all_args_to_update_environment_variables = []
+        for (node_id, _) in worker_node_and_gpu_ids:
+            all_args_to_update_environment_variables.append([{
+                "CUDA_VISIBLE_DEVICES":
+                ",".join(map(str, node_gpus[node_id]))
+            }])
+        self._run_workers("update_environment_variables",
+                          all_args=all_args_to_update_environment_variables)
 
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        def collect_arg_helper_func(**kwargs):
+            # avoid writing `{"name": value}` manually
+            return kwargs
 
-        model_config = copy.deepcopy(self.model_config)
-        parallel_config = copy.deepcopy(self.parallel_config)
-        scheduler_config = copy.deepcopy(self.scheduler_config)
-        load_config = copy.deepcopy(self.load_config)
-        device_config = copy.deepcopy(self.device_config)
-        lora_config = copy.deepcopy(self.lora_config)
-        cache_config = copy.deepcopy(self.cache_config)
-        vision_language_config = copy.deepcopy(self.vision_language_config)
+        init_worker_all_kwargs = []
 
-        # Initialize the actual workers with the Worker class.
-        for rank, (worker, (node_id, _)) in enumerate(
-                zip(self.workers, worker_node_and_gpu_ids),
-                start=1,
-        ):
+        # Initialize the actual workers inside worker wrapper.
+        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
             local_rank = node_workers[node_id].index(rank)
-            worker.init_worker.remote(
-                lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config=model_config,
-                    parallel_config=parallel_config,
-                    scheduler_config=scheduler_config,
-                    device_config=device_config,
-                    cache_config=cache_config,
-                    load_config=load_config,
+            init_worker_all_kwargs.append(
+                collect_arg_helper_func(
+                    model_config=self.model_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config,
+                    device_config=self.device_config,
+                    cache_config=self.cache_config,
+                    load_config=self.load_config,
                     local_rank=local_rank,
                     rank=rank,
                     distributed_init_method=distributed_init_method,
-                    lora_config=lora_config,
-                    vision_language_config=vision_language_config,
+                    lora_config=self.lora_config,
+                    vision_language_config=self.vision_language_config,
+                    is_driver_worker=rank == 0,
                 ))
-
-        # Initialize the driver worker with the Worker class.
-        driver_rank = 0
-        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
-        self.driver_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            local_rank=driver_local_rank,
-            rank=driver_rank,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            load_config=self.load_config,
-            is_driver_worker=True,
-        )
+        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
         self._run_workers(
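For concreteness, here is what the heterogeneous argument list built above looks like for two Ray workers that both sit on a single node owning GPUs 0 and 1; the node id and GPU numbers are made up for illustration.

from collections import defaultdict

# Illustrative placement: both ranks on "node-A", which owns GPUs 0 and 1.
worker_node_and_gpu_ids = [("node-A", [0]), ("node-A", [1])]

node_gpus = defaultdict(list)
for node_id, gpu_ids in worker_node_and_gpu_ids:
    node_gpus[node_id].extend(gpu_ids)

all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids:
    all_args_to_update_environment_variables.append([{
        "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id]))
    }])

# One positional-argument list per worker, each containing a single dict:
# [[{'CUDA_VISIBLE_DEVICES': '0,1'}], [{'CUDA_VISIBLE_DEVICES': '0,1'}]]
print(all_args_to_update_environment_variables)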
@@ -279,13 +259,35 @@ class RayGPUExecutor(ExecutorBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
+        driver_args: Optional[Tuple[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
+        all_args: Optional[List[List[Any]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
+        use_dummy_driver: bool = False,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+        all_args and all_kwargs are used to pass heterogeneous arguments,
+        i.e. different arguments for each worker.
+        """
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # for mypy type checking
+        assert driver_args is not None
+        assert driver_kwargs is not None
+        if all_args is None:
+            all_args = [driver_args] + [args] * len(self.workers)
+        if all_kwargs is None:
+            all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
+
+        # for mypy type checking
+        assert all_args is not None
+        assert all_kwargs is not None
 
         if max_concurrent_workers:
             raise NotImplementedError(
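The docstring above introduces the all_args/all_kwargs convention, and the shapes are easy to get wrong, so here is a standalone sketch of how they are laid out: index 0 always describes the driver worker and indices 1..N the Ray workers. The kwargs shown are trimmed down and illustrative, not the full Worker constructor arguments.

num_workers = 2

# Homogeneous defaulting, mirroring the branch above where all_args and
# all_kwargs are left as None: every entry repeats the same *args / **kwargs.
args, kwargs = (), {}
all_args = [args] + [args] * num_workers
all_kwargs = [kwargs] + [kwargs] * num_workers
assert len(all_args) == len(all_kwargs) == num_workers + 1

# Heterogeneous: one kwargs dict per worker, so each process can receive its
# own rank / local_rank (illustrative, trimmed-down kwargs only).
all_kwargs = [
    {"rank": 0, "local_rank": 0, "is_driver_worker": True},   # driver
    {"rank": 1, "local_rank": 1, "is_driver_worker": False},  # Ray worker 1
    {"rank": 2, "local_rank": 0, "is_driver_worker": False},  # Ray worker 2
]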
@@ -299,8 +301,10 @@ class RayGPUExecutor(ExecutorBase):
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
-                worker.execute_method.remote(method, *args, **kwargs)
-                for worker in self.workers
+                worker.execute_method.remote(method, *worker_args,
+                                             **worker_kwargs)
+                for (worker, worker_args, worker_kwargs
+                     ) in zip(self.workers, all_args[1:], all_kwargs[1:])
             ]
 
         if driver_args is None:
@@ -309,9 +313,13 @@ class RayGPUExecutor(ExecutorBase):
             driver_kwargs = kwargs
 
         # Start the driver worker after all the ray workers.
-        driver_worker_output = getattr(self.driver_worker,
-                                       method)(*driver_args, **driver_kwargs)
+        if not use_dummy_driver:
+            driver_worker_output = self.driver_worker.execute_method(
+                method, *all_args[0], **all_kwargs[0])
+        else:
+            driver_worker_output = ray.get(
+                self.driver_dummy_worker.execute_method.remote(
+                    method, *all_args[0], **all_kwargs[0]))
         # Get the results of the ray workers.
         if self.workers:
             if use_ray_compiled_dag:
@@ -386,8 +394,12 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
             driver_kwargs = kwargs
 
         # Run the driver worker asynchronously.
-        driver_executor = make_async(getattr(self.driver_worker, method))
-        coros.append(driver_executor(*driver_args, **driver_kwargs))
+        def helper():
+            return self.driver_worker.execute_method(method, *driver_args,
+                                                     **driver_kwargs)
+
+        driver_executor = make_async(helper)
+        coros.append(driver_executor())
 
         # Run the ray workers asynchronously.
         for worker in self.workers:
@@ -271,8 +271,12 @@ def get_open_port() -> int:
         return s.getsockname()[1]
 
 
-def set_cuda_visible_devices(device_ids: List[int]) -> None:
-    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+def update_environment_variables(envs: Dict[str, str]):
+    for k, v in envs.items():
+        if k in os.environ:
+            logger.warning(f"Overwriting environment variable {k} "
+                           f"from '{os.environ[k]}' to '{v}'")
+        os.environ[k] = v
 
 
 def chunk_list(lst, chunk_size):
@@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
         merged_dict[key].extend(value)
 
     return dict(merged_dict)
+
+
+def init_cached_hf_modules():
+    """
+    Lazy initialization of the Hugging Face modules.
+    """
+    from transformers.dynamic_module_utils import init_hf_modules
+    init_hf_modules()
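A quick, self-contained illustration of the new `update_environment_variables` helper, assuming a vllm checkout that contains this change; the variable names are made up.

import os

from vllm.utils import update_environment_variables

os.environ["MY_FLAG"] = "old"
update_environment_variables({"MY_FLAG": "new", "OTHER_FLAG": "1"})

# MY_FLAG is overwritten (a warning is logged), OTHER_FLAG is created.
assert os.environ["MY_FLAG"] == "new"
assert os.environ["OTHER_FLAG"] == "1"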
@@ -138,7 +138,10 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
+        if self.model_config.trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
         self.model_runner = CPUModelRunner(model_config,
                                            parallel_config,
                                            scheduler_config,
@@ -29,6 +29,10 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
+        if self.model_config.trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
 
         self.model_runner = NeuronModelRunner(model_config, parallel_config,
                                               scheduler_config, device_config)
@@ -60,6 +60,10 @@ class Worker(WorkerBase):
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
 
+        if self.model_config.trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
         self.vision_language_config = vision_language_config
         if self.vision_language_config:
             assert not self.lora_config, (
@@ -1,8 +1,14 @@
+import importlib
+import os
 from abc import ABC, abstractmethod
 from typing import Dict, List, Tuple
 
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import update_environment_variables
 
+logger = init_logger(__name__)
+
 
 class WorkerBase(ABC):
@@ -82,3 +88,53 @@ class LoraNotSupportedWorkerBase(WorkerBase):
 
     def list_loras(self) -> List[int]:
         raise ValueError(f"{type(self)} does not support LoRA")
+
+
+class WorkerWrapperBase:
+    """
+    The whole point of this class is to lazily initialize the worker.
+    We first instantiate the WorkerWrapper, which remembers the worker module
+    and class name. Then, when we call `update_environment_variables`, and the
+    real initialization happens in `init_worker`.
+    """
+
+    def __init__(self,
+                 worker_module_name=None,
+                 worker_class_name=None) -> None:
+        self.worker_module_name = worker_module_name
+        self.worker_class_name = worker_class_name
+        self.worker = None
+
+    def update_environment_variables(self, envs: Dict[str, str]) -> None:
+        key = 'CUDA_VISIBLE_DEVICES'
+        if key in envs and key in os.environ:
+            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
+            # suppress the warning in `update_environment_variables`
+            del os.environ[key]
+        update_environment_variables(envs)
+
+    def init_worker(self, *args, **kwargs):
+        """
+        Actual initialization of the worker class.
+        Arguments are passed to the worker class constructor.
+        """
+        mod = importlib.import_module(self.worker_module_name)
+        worker_class = getattr(mod, self.worker_class_name)
+        self.worker = worker_class(*args, **kwargs)
+
+    def execute_method(self, method, *args, **kwargs):
+        try:
+            if hasattr(self, method):
+                executor = getattr(self, method)
+            else:
+                executor = getattr(self.worker, method)
+            return executor(*args, **kwargs)
+        except Exception as e:
+            # if the driver worker also execute methods,
+            # exceptions in the rest worker may cause deadlock in rpc like ray
+            # see https://github.com/vllm-project/vllm/issues/3455
+            # print the error and inform the user to solve the error
+            msg = (f"Error executing method {method}. "
+                   "This might cause deadlock in distributed execution.")
+            logger.exception(msg)
+            raise e
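To make the lazy-initialization contract of WorkerWrapperBase concrete, here is a hedged usage sketch. It assumes a vllm checkout containing this change; the device id and the omitted constructor arguments are illustrative.

from vllm.worker.worker_base import WorkerWrapperBase

# 1. Construct the wrapper. Nothing from vllm.worker.worker is imported yet,
#    so torch/CUDA have not been touched in this process.
wrapper = WorkerWrapperBase(worker_module_name="vllm.worker.worker",
                            worker_class_name="Worker")

# 2. Pin the visible GPUs before the worker (and therefore torch) is imported.
wrapper.update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})

# 3. Import and construct the real Worker; the keyword arguments are the
#    Worker constructor arguments (elided here, see the executor hunk above).
# wrapper.init_worker(model_config=..., parallel_config=..., rank=0, ...)

# 4. Methods not defined on the wrapper itself are forwarded to the worker.
# wrapper.execute_method("init_device")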