Support torchrun and SPMD-style offline inference (#12071)

Signed-off-by: youkaichao <youkaichao@gmail.com>

parent dd7c9ad870
commit bf53e0c70b
@@ -463,6 +463,7 @@ steps:
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
   commands:
+  - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
examples/offline_inference/torchrun_example.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""
Experimental support for tensor-parallel inference with torchrun;
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation and use case for this example.
Run the script with `torchrun --nproc-per-node=2 torchrun_example.py`;
the argument 2 should match the `tensor_parallel_size` below.
See `tests/distributed/test_torchrun_example.py` for the unit test.
"""

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this LLM engine/instance only creates one worker.
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
)

outputs = llm.generate(prompts, sampling_params)

# all ranks will have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
"""
Further tips:

1. To communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.

```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
```

2. To communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```

3. To access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
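Building on tip 1 above, here is a minimal sketch of rank-0-only post-processing that could follow the generation loop in the example. It assumes the script is run under torchrun as described; the `outputs.json` filename and the final barrier are illustrative additions, not part of this commit.

```python
# Hypothetical continuation of torchrun_example.py: persist results once.
# Relies on `outputs` from the generation loop above.
import json

import torch.distributed as dist

from vllm.distributed.parallel_state import get_world_group

cpu_group = get_world_group().cpu_group  # GLOO group for control messages
torch_rank = dist.get_rank(group=cpu_group)

if torch_rank == 0:
    # Every rank holds identical outputs, so writing from rank 0 is enough.
    with open("outputs.json", "w") as f:
        json.dump([o.outputs[0].text for o in outputs], f)

# Keep all ranks alive until rank 0 has finished writing.
dist.barrier(group=cpu_group)
```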
tests/distributed/test_torchrun_example.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# unit test for `examples/offline_inference/torchrun_example.py`

import random

import torch.distributed as dist

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group

# Create prompts
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="external_launcher",
          gpu_memory_utilization=random.uniform(0.7, 0.9),
          swap_space=random.randint(1, 4))

outputs = llm.generate(prompts, sampling_params)

cpu_group = get_world_group().cpu_group

torch_rank = dist.get_rank(group=cpu_group)


def test_consistent_across_ranks(obj):
    if torch_rank == 0:
        dist.broadcast_object_list([obj], src=0, group=cpu_group)
    else:
        container = [None]
        dist.broadcast_object_list(container, src=0, group=cpu_group)
        assert container[0] == obj


test_consistent_across_ranks(
    llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
    llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

# all ranks should have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    test_consistent_across_ranks(prompt)
    test_consistent_across_ranks(generated_text)
    print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
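As the CI change at the top indicates, this test appears to be launched directly with `torchrun --nproc-per-node=2 distributed/test_torchrun_example.py` rather than collected by pytest, so every rank executes the module top to bottom and runs the same consistency checks.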
@@ -22,7 +22,7 @@ class DummyWorkerWrapper(WorkerWrapperBase):
             # simulate error case
             raise worker_input

-        return self.rank, input
+        return self.rpc_rank, input


 def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
@@ -1338,14 +1338,15 @@ class ParallelConfig:
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
         if self.distributed_executor_backend not in (
-                "ray", "mp", "uni", None) and not (isinstance(
+                "ray", "mp", "uni",
+                "external_launcher", None) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(
                         self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
                 "Unrecognized distributed executor backend "
                 f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', or custom ExecutorBase"
-                " subclass.")
+                "values are 'ray', 'mp' 'uni', 'external_launcher' or"
+                " custom ExecutorBase subclass.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
@@ -388,7 +388,7 @@ class EngineArgs:
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
-            choices=['ray', 'mp'],
+            choices=['ray', 'mp', 'uni', 'external_launcher'],
             default=EngineArgs.distributed_executor_backend,
             help='Backend to use for distributed model '
             'workers, either "ray" or "mp" (multiprocessing). If the product '
@@ -457,6 +457,11 @@ class LLMEngine:
             # JAX-style, single-process, multi-device executor.
             from vllm.executor.uniproc_executor import UniProcExecutor
             executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # executor with external launcher
+            from vllm.executor.uniproc_executor import (  # noqa
+                ExecutorWithExternalLauncher)
+            executor_class = ExecutorWithExternalLauncher
         else:
             from vllm.executor.uniproc_executor import UniProcExecutor
             executor_class = UniProcExecutor
@@ -172,7 +172,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
                 )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
             else:
                 worker = ray.remote(
                     num_cpus=0,
@@ -181,7 +181,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
                 )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
             worker_metadata.append(
                 RayWorkerMetaData(worker=worker, created_rank=rank))
             rank += 1
@@ -204,7 +204,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
                 # as the resource holder for the driver process.
                 self.driver_dummy_worker = worker
                 self.driver_worker = RayWorkerWrapper(
-                    vllm_config=self.vllm_config, rank=0)
+                    vllm_config=self.vllm_config, rpc_rank=0)
                 worker_metadata.pop(i)
                 break
@@ -1,5 +1,10 @@
+import os
 from typing import Any, Dict, List, Optional, Tuple

+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -16,7 +21,7 @@ class UniProcExecutor(ExecutorBase):
         """Initialize the worker and load the model.
         """
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rank=0)
+                                               rpc_rank=0)
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         local_rank = 0
@@ -55,3 +60,77 @@


 UniProcExecutorAsync = UniProcExecutor
+
+
+class ExecutorWithExternalLauncher(UniProcExecutor):
+    """An executor that uses external launchers to launch engines,
+    specially designed for torchrun-compatible launchers, for
+    offline inference with tensor parallelism.
+
+    See https://github.com/vllm-project/vllm/issues/11400 for
+    the motivation, and examples/offline_inference/torchrun_example.py
+    for the usage example.
+
+    The key idea: although it is tensor-parallel inference, we only
+    create one worker per executor; users will launch multiple
+    engines with torchrun-compatible launchers, and all these engines
+    work together to process the same prompts. When scheduling is
+    deterministic, all the engines will generate the same outputs,
+    and they don't need to synchronize the states with each other.
+    """
+    uses_ray: bool = False
+
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        """
+        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
+            ("ExecutorWithExternalLauncher does not "
+             "support pipeline parallelism.")
+        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
+            ("ExecutorWithExternalLauncher needs deterministic "
+             "execution, so it "
+             "does not support delay_factor in scheduling")
+        assert not envs.VLLM_USE_V1, \
+            ("V1 architecture cannot guarantee deterministic execution, "
+             "so it is not supported in ExecutorWithExternalLauncher.")
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        # engines are launched in torchrun-compatible launchers,
+        # so we can use the env:// method.
+        # required env vars:
+        # - RANK
+        # - MASTER_ADDR
+        # - MASTER_PORT
+        distributed_init_method = "env://"
+        rank = int(os.environ["RANK"])
+        local_rank = rank
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+        )
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """
+        Determine the number of available KV blocks.
+        Add an additional all_reduce to get the min across all ranks.
+        Note that even if we have the same `gpu_memory_utilization` and
+        `swap_space`, the available memory in every rank might still
+        differ because NCCL can take different amounts of memory in
+        different ranks. Therefore, it is necessary to test if all ranks
+        agree on the same KV cache configuration.
+        """
+        a, b = super().determine_num_available_blocks()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
+        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
+        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return a_tensor.item(), b_tensor.item()
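To illustrate the agreement step in `determine_num_available_blocks` above, here is a small standalone sketch of a MIN all-reduce over a CPU (gloo) group. The script name and block counts are made up, and it uses a plain `init_process_group` instead of vLLM's `get_world_group`.

```python
# Toy demonstration of agreeing on the smallest KV-block count across ranks.
# Hypothetical file; run with: torchrun --nproc-per-node=2 min_blocks_demo.py
import os

import torch
import torch.distributed as dist

# Stand-in for vLLM's CPU group (get_world_group().cpu_group).
dist.init_process_group(backend="gloo")
rank = int(os.environ["RANK"])

# Pretend each rank profiled a different amount of free memory.
num_gpu_blocks = torch.tensor([1000 + 100 * rank], dtype=torch.int64)

dist.all_reduce(num_gpu_blocks, op=dist.ReduceOp.MIN)
# Every rank now holds the same (smallest) value, so all engines build
# identical KV cache configurations.
print(f"rank {rank}: agreed num_gpu_blocks = {num_gpu_blocks.item()}")

dist.destroy_process_group()
```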
@@ -940,8 +940,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         return self.base_layer.soft_cap

     @property
-    def use_gather(self):
-        return self.base_layer.use_gather
+    def use_all_gather(self):
+        return self.base_layer.use_all_gather

     @property
     def org_vocab_size(self):
@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn

 import vllm.envs as envs
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,8 +45,10 @@ class LogitsProcessor(nn.Module):
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-        self.use_gather = not current_platform.is_tpu(
-        ) and not envs.VLLM_USE_V1
+        parallel_config = get_current_vllm_config().parallel_config
+        self.use_all_gather = current_platform.is_tpu() \
+            or envs.VLLM_USE_V1 \
+            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa

     def forward(
         self,
@@ -88,16 +91,17 @@ class LogitsProcessor(nn.Module):
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        if self.use_gather:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
-        else:
+        if self.use_all_gather:
             # Gather is not supported for some devices such as TPUs.
             # Use all-gather instead.
             # NOTE(woosuk): Here, the outputs of every device should not be None
             # because XLA requires strict SPMD among all devices. Every device
             # should execute the same operations after gathering the logits.
             logits = tensor_model_parallel_all_gather(logits)
+        else:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
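The switch from `use_gather` to `use_all_gather` matters here because with `external_launcher` every rank acts as a driver for its own engine and therefore needs the full logits. A rough standalone contrast of the two collectives follows (toy tensors and a plain gloo group rather than vLLM's tensor-parallel wrappers; the script layout is an assumption):

```python
# Toy contrast of gather vs. all-gather; run with torchrun --nproc-per-node=2.
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

# Each rank holds one shard of the "logits".
shard = torch.full((2,), float(rank))

# gather: only the destination rank receives the concatenated result,
# which is why the original code noted "None may be returned for rank > 0".
gather_list = ([torch.empty_like(shard) for _ in range(world_size)]
               if rank == 0 else None)
dist.gather(shard, gather_list=gather_list, dst=0)

# all-gather: every rank receives the full result, which is what every rank
# needs when each one drives its own engine (external_launcher, TPU, V1).
full = [torch.empty_like(shard) for _ in range(world_size)]
dist.all_gather(full, shard)
print(f"rank {rank}: {torch.cat(full).tolist()}")

dist.destroy_process_group()
```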
@@ -246,7 +246,7 @@ class WorkerProc:
         ready_path: str,
     ):
         self.rank = rank
-        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
+        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
         # TODO: move `init_worker` to executor level as a collective rpc call
         all_kwargs: List[Dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
@@ -55,9 +55,6 @@ class Worker(LocalOrDistributedWorkerBase):
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.is_driver_worker = is_driver_worker
-        if is_driver_worker:
-            assert rank % self.parallel_config.tensor_parallel_size == 0, \
-                "Driver worker should be rank 0 of tensor parallel group."
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
@@ -461,7 +461,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):

 class WorkerWrapperBase:
     """
-    The whole point of this class is to lazily initialize the worker.
+    This class represents one process in an executor/engine. It is responsible
+    for lazily initializing the worker and handling the worker's lifecycle.
     We first instantiate the WorkerWrapper, which remembers the worker module
     and class name. Then, when we call `update_environment_variables`, and the
     real initialization happens in `init_worker`.
@@ -470,9 +471,19 @@ class WorkerWrapperBase:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        rank: int = 0,
+        rpc_rank: int = 0,
     ) -> None:
-        self.rank = rank
+        """
+        Initialize the worker wrapper with the given vllm_config and rpc_rank.
+        Note: rpc_rank is the rank of the worker in the executor. In most cases,
+        it is also the rank of the worker in the distributed group. However,
+        when multiple executors work together, they can be different.
+        e.g. in the case of SPMD-style offline inference with TP=2,
+        users can launch 2 engines/executors, each with only 1 worker.
+        All workers have rpc_rank=0, but they have different ranks in the TP
+        group.
+        """
+        self.rpc_rank = rpc_rank
         self.vllm_config = vllm_config
         self.worker: Optional[WorkerBase] = None
         if vllm_config.model_config is not None:
@@ -485,16 +496,16 @@ class WorkerWrapperBase:

     def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
         """
-        Adjust the rank based on the given mapping.
+        Adjust the rpc_rank based on the given mapping.
         It is only used during the initialization of the executor,
-        to adjust the rank of workers after we create all workers.
+        to adjust the rpc_rank of workers after we create all workers.
         """
-        if self.rank in rank_mapping:
-            self.rank = rank_mapping[self.rank]
+        if self.rpc_rank in rank_mapping:
+            self.rpc_rank = rank_mapping[self.rpc_rank]

     def update_environment_variables(self, envs_list: List[Dict[str,
                                                                  str]]) -> None:
-        envs = envs_list[self.rank]
+        envs = envs_list[self.rpc_rank]
         key = 'CUDA_VISIBLE_DEVICES'
         if key in envs and key in os.environ:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
@@ -507,7 +518,7 @@ class WorkerWrapperBase:
         Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
         """
-        kwargs = all_kwargs[self.rank]
+        kwargs = all_kwargs[self.rpc_rank]
         enable_trace_function_call_for_thread(self.vllm_config)

         # see https://github.com/NVIDIA/nccl/issues/1234
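A tiny sketch of the rpc_rank vs. distributed-rank distinction described in the new docstring. The class below is a toy stand-in, not the real `WorkerWrapperBase`, and it assumes the process was started by a torchrun-compatible launcher so that `RANK` is set.

```python
# Toy illustration only: rpc_rank indexes a worker within its own executor,
# while the distributed rank comes from the launcher (e.g. torchrun's RANK).
import os


class ToyWorkerWrapper:
    def __init__(self, rpc_rank: int = 0):
        self.rpc_rank = rpc_rank  # position inside this executor


# SPMD-style offline inference with TP=2: torchrun starts two processes, each
# builds one engine/executor with a single worker, so rpc_rank is always 0 ...
wrapper = ToyWorkerWrapper(rpc_rank=0)

# ... while the rank inside the tensor-parallel group differs per process.
tp_rank = int(os.environ.get("RANK", "0"))
print(f"rpc_rank={wrapper.rpc_rank}, distributed rank={tp_rank}")
```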