[Core] Improve choice of Python multiprocessing method (#8823)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: youkaichao <youkaichao@126.com>
Russell Bryant authored on 2024-09-28 21:17:07 -04:00, committed by GitHub
commit d1537039ce, parent cc276443b5
4 changed files with 52 additions and 9 deletions
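
Because the `spawn` default below is applied only in the CLI entrypoint, library users who hit the fork/CUDA conflict can opt in themselves by setting the environment variable before vLLM creates worker processes. A hedged usage sketch (the model name and tensor_parallel_size are just examples, not part of this commit):

import os

# Must be set before vLLM spawns workers; "spawn" is the safe choice
# once CUDA may already be initialized in this process.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM

# The guard matters under "spawn": worker processes re-import this module.
if __name__ == "__main__":
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)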

vllm/executor/multiproc_gpu_executor.py

@@ -15,8 +15,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.triton_utils import maybe_set_triton_cache_manager
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
-                        get_distributed_init_method, get_open_port,
-                        get_vllm_instance_id, make_async,
+                        cuda_is_initialized, get_distributed_init_method,
+                        get_open_port, get_vllm_instance_id, make_async,
                         update_environment_variables)
 
 logger = init_logger(__name__)
@@ -122,6 +122,13 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
             "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
         })
 
+        if (cuda_is_initialized()
+                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
+            logger.warning("CUDA was previously initialized. We must use "
+                           "the `spawn` multiprocessing start method. Setting "
+                           "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
         cuda_device_count = cuda_device_count_stateless()
         # Use confusing message for more common TP-only case.
         assert tensor_parallel_size <= cuda_device_count, (
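
The guard above exists because a CUDA context cannot survive fork(): PyTorch raises "Cannot re-initialize CUDA in forked subprocess" when a forked child touches the GPU. A minimal standalone repro of that failure mode (illustrative sketch, not part of this commit):

import multiprocessing

import torch

def child():
    # The forked child inherits the parent's CUDA state and cannot reuse it.
    torch.zeros(1, device="cuda")  # RuntimeError under the "fork" method

if __name__ == "__main__":
    torch.zeros(1, device="cuda")  # parent initializes CUDA first
    p = multiprocessing.get_context("fork").Process(target=child)
    p.start()
    p.join()

Switching the context to "spawn" (or never touching CUDA before forking) lets the child start cleanly, which is exactly the condition the new cuda_is_initialized() helper detects.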

vllm/executor/multiproc_worker_utils.py

@@ -27,9 +27,6 @@ RESET = '\033[0;0m'
 
 JOIN_TIMEOUT_S = 2
 
-mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
-mp = multiprocessing.get_context(mp_method)
-
 
 @dataclass
 class Result(Generic[T]):
@@ -77,7 +74,7 @@ class ResultHandler(threading.Thread):
 
     def __init__(self) -> None:
         super().__init__(daemon=True)
-        self.result_queue = mp.Queue()
+        self.result_queue = get_mp_context().Queue()
         self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
 
     def run(self):
@@ -147,10 +144,11 @@ class ProcessWorkerWrapper:
 
     def __init__(self, result_handler: ResultHandler,
                  worker_factory: Callable[[], Any]) -> None:
-        self._task_queue = mp.Queue()
+        self.mp = get_mp_context()
+        self._task_queue = self.mp.Queue()
         self.result_queue = result_handler.result_queue
         self.tasks = result_handler.tasks
-        self.process: BaseProcess = mp.Process(  # type: ignore[attr-defined]
+        self.process: BaseProcess = self.mp.Process(  # type: ignore[attr-defined]
             target=_run_worker_process,
             name="VllmWorkerProcess",
             kwargs=dict(
@@ -204,7 +202,7 @@ def _run_worker_process(
     """Worker process event loop"""
 
     # Add process-specific prefix to stdout and stderr
-    process_name = mp.current_process().name
+    process_name = get_mp_context().current_process().name
     pid = os.getpid()
     _add_prefix(sys.stdout, process_name, pid)
     _add_prefix(sys.stderr, process_name, pid)
@@ -269,3 +267,8 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
 
     file.start_new_line = True  # type: ignore[attr-defined]
     file.write = write_with_prefix  # type: ignore[method-assign]
+
+
+def get_mp_context():
+    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
+    return multiprocessing.get_context(mp_method)
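
Replacing the module-level mp = multiprocessing.get_context(...) with a get_mp_context() helper changes when the start method is read: at call time instead of import time, so the "spawn" override the executor writes into os.environ is actually honored. A standalone sketch of the difference (hypothetical module, not vLLM code; "fork" assumed as the prior default):

import multiprocessing
import os

# Import-time binding: the start method is frozen when the module loads,
# so a later change to the environment variable has no effect.
_frozen = multiprocessing.get_context(
    os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork"))

def get_ctx():
    # Call-time binding: each call sees the current environment value.
    return multiprocessing.get_context(
        os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork"))

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
print(_frozen.get_start_method())    # still "fork"
print(get_ctx().get_start_method())  # "spawn"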

vllm/scripts.py

@@ -12,8 +12,11 @@ from openai.types.chat import ChatCompletionMessageParam
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
 
+logger = init_logger(__name__)
+
 
 def register_signal_handlers():
@@ -114,7 +117,30 @@ def _add_query_options(
     return parser
 
 
+def env_setup():
+    # The safest multiprocessing method is `spawn`, as the default `fork` method
+    # is not compatible with some accelerators. The default method will be
+    # changing in future versions of Python, so we should use it explicitly when
+    # possible.
+    #
+    # We only set it here in the CLI entrypoint, because changing to `spawn`
+    # could break some existing code using vLLM as a library. `spawn` will cause
+    # unexpected behavior if the code is not protected by
+    # `if __name__ == "__main__":`.
+    #
+    # References:
+    # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
+    # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
+    # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
+    # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
+    if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
+        logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def main():
+    env_setup()
+
     parser = FlexibleArgumentParser(description="vLLM CLI")
     subparsers = parser.add_subparsers(required=True)
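
The comment's warning about code not protected by if __name__ == "__main__": is concrete: under "spawn" the child re-imports the parent's main module, so any unguarded top-level code runs again in every worker (and unguarded process creation recurses until Python errors out). A standalone illustration, not part of this commit:

import multiprocessing

def work():
    print("worker ran")

# Top-level code placed outside this guard would re-execute in every
# spawned child, because "spawn" re-imports the __main__ module.
if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=work)
    p.start()
    p.join()

This is why the commit defaults to "spawn" only in the CLI entrypoint, where vLLM controls __main__, rather than for library users.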

vllm/utils.py

@@ -1091,6 +1091,13 @@ def cuda_device_count_stateless() -> int:
 
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
+
+def cuda_is_initialized() -> bool:
+    """Check if CUDA is initialized."""
+    if not torch.cuda._is_compiled():
+        return False
+    return torch.cuda.is_initialized()
+
 
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
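
A quick sketch of the new helper's behavior (assuming a CUDA-enabled PyTorch build; the `_is_compiled()` check above short-circuits to False on CPU-only builds):

import torch

from vllm.utils import cuda_is_initialized

# Importing torch does not initialize CUDA by itself...
print(cuda_is_initialized())  # False

# ...but the first GPU allocation does, which is the state the
# executor checks before choosing a multiprocessing start method.
if torch.cuda.is_available():
    torch.zeros(1, device="cuda")
    print(cuda_is_initialized())  # True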