Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-06 06:08:44 +08:00)
[misc] fix cross-node TP (#12166)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 2b83503227
parent 7b98a65ae6
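Judging from the diff below, the bug was that a GPU-count sanity check lived in the CUDA platform layer's config check and therefore ran for every executor backend: under cross-node tensor parallelism the global world size legitimately exceeds any single node's GPU count, so the old assert rejected valid deployments. The fix moves the check into MultiprocessingDistributedExecutor, which is single-node by construction, and gates it on current_platform.is_cuda_alike(). A minimal sketch of the failure mode, with illustrative numbers (2 nodes x 8 GPUs, not taken from the source):

# Cross-node TP: tensor_parallel_size = 16 spans two 8-GPU nodes.
world_size = 16        # global world size across both nodes
local_gpu_count = 8    # per-node count, what cuda_device_count_stateless() sees

# The old platform-level check compared global against local, so it
# rejected this valid setup on every node:
print(world_size <= local_gpu_count)  # False -> the old assert would fire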
vllm/executor/mp_distributed_executor.py

@@ -1,4 +1,5 @@
 import asyncio
+import os
 from typing import Any, Callable, List, Optional, Union

 import cloudpickle
@@ -10,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async, run_method)
+from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
+                        get_distributed_init_method, get_ip, get_open_port,
+                        make_async, run_method, update_environment_variables)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
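The widened import pulls in cuda_device_count_stateless and update_environment_variables for the new check. The stateless variant exists because torch.cuda.device_count() may cache its answer for the lifetime of the process and ignore later changes to CUDA_VISIBLE_DEVICES. A small sketch of the difference, assuming a host with at least two visible GPUs:

import os

import torch

from vllm.utils import cuda_device_count_stateless

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(cuda_device_count_stateless())  # 2

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(cuda_device_count_stateless())  # 1: recomputed after the env change
print(torch.cuda.device_count())      # may still report 2 from its cache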
@@ -22,7 +24,39 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):

     uses_ray: bool = False

+    def _check_cuda(self) -> None:
+        """Check that the number of GPUs is sufficient for the parallel
+        configuration. Separate from _init_executor to reduce the number of
+        indented blocks.
+        """
+        parallel_config = self.parallel_config
+        world_size = parallel_config.world_size
+        tensor_parallel_size = parallel_config.tensor_parallel_size
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Report in terms of tensor_parallel_size for the common TP-only case.
+        if tensor_parallel_size > cuda_device_count:
+            raise RuntimeError(
+                f"please set tensor_parallel_size ({tensor_parallel_size}) "
+                f"to no more than the max local gpu count ({cuda_device_count})")
+
+        if world_size > cuda_device_count:
+            raise RuntimeError(
+                f"please ensure that world_size ({world_size}) "
+                f"is no more than the max local gpu count ({cuda_device_count})")
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
     def _init_executor(self) -> None:
+
+        from vllm.platforms import current_platform
+        if current_platform.is_cuda_alike():
+            self._check_cuda()
+
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
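A detail worth noting in _check_cuda above: when CUDA_VISIBLE_DEVICES is unset, the driver defaults it to the first world_size device indices, and worker processes inherit that mapping through the environment. A minimal sketch of the defaulting, using plain os.environ in place of vLLM's update_environment_variables helper (the world_size value is illustrative):

import os

world_size = 4  # illustrative; in vLLM this comes from parallel_config

# Default only if unset, so an operator-provided mapping is never overridden.
if "CUDA_VISIBLE_DEVICES" not in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(world_size)))

print(os.environ["CUDA_VISIBLE_DEVICES"])  # "0,1,2,3" when it started unset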
vllm/platforms/cuda.py

@@ -139,28 +139,6 @@ class CudaPlatformBase(Platform):
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

-        world_size = parallel_config.world_size
-        tensor_parallel_size = parallel_config.tensor_parallel_size
-
-        from vllm.utils import (cuda_device_count_stateless,
-                                update_environment_variables)
-
-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
-        cuda_device_count = cuda_device_count_stateless()
-        # Use confusing message for more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local gpu count ({cuda_device_count})")
-
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
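Two consequences of the relocation, as far as the diff alone shows: the platform-level config check no longer rejects configurations whose global world size exceeds the local GPU count (exactly the cross-node TP case), and the check changed from assert to an explicit RuntimeError, so it still fires when Python runs with -O. A standalone sketch of the latter point (this helper is illustrative, not vLLM code):

def check(world_size: int, cuda_device_count: int) -> None:
    # An explicit raise survives `python -O`; a bare assert would be stripped.
    if world_size > cuda_device_count:
        raise RuntimeError(
            f"please ensure that world_size ({world_size}) is no more "
            f"than the max local gpu count ({cuda_device_count})")

check(world_size=8, cuda_device_count=8)     # ok: equality is allowed
# check(world_size=16, cuda_device_count=8)  # raises, even under python -O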