diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py
index 13f2761b0db06..aa54bd66bed67 100644
--- a/vllm/entrypoints/cli/main.py
+++ b/vllm/entrypoints/cli/main.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # The CLI entrypoint to vLLM.
 
-import os
 import signal
 import sys
 
@@ -9,11 +8,9 @@ import vllm.entrypoints.cli.benchmark.main
 import vllm.entrypoints.cli.openai
 import vllm.entrypoints.cli.serve
 import vllm.version
-from vllm.logger import init_logger
+from vllm.entrypoints.utils import cli_env_setup
 from vllm.utils import FlexibleArgumentParser
 
-logger = init_logger(__name__)
-
 CMD_MODULES = [
     vllm.entrypoints.cli.openai,
     vllm.entrypoints.cli.serve,
@@ -30,29 +27,8 @@ def register_signal_handlers():
     signal.signal(signal.SIGTSTP, signal_handler)
 
 
-def env_setup():
-    # The safest multiprocessing method is `spawn`, as the default `fork` method
-    # is not compatible with some accelerators. The default method will be
-    # changing in future versions of Python, so we should use it explicitly when
-    # possible.
-    #
-    # We only set it here in the CLI entrypoint, because changing to `spawn`
-    # could break some existing code using vLLM as a library. `spawn` will cause
-    # unexpected behavior if the code is not protected by
-    # `if __name__ == "__main__":`.
-    #
-    # References:
-    # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
-    # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
-    # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
-    # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
-    if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
-        logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
-        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
-
 def main():
-    env_setup()
+    cli_env_setup()
 
     parser = FlexibleArgumentParser(description="vLLM CLI")
     parser.add_argument('-v',
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 18d75a04ab0f3..2a61259896a37 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -82,7 +82,8 @@ from vllm.entrypoints.openai.serving_tokenization import (
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import load_aware_call, with_cancellation
+from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
+                                    with_cancellation)
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.transformers_utils.config import (
@@ -1106,6 +1107,7 @@ if __name__ == "__main__":
     # NOTE(simon):
     # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
     # entrypoints.
+    cli_env_setup()
     parser = FlexibleArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index 773f52fa38f88..b88c2b3a080fd 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -2,11 +2,16 @@
 
 import asyncio
 import functools
+import os
 
 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.background import BackgroundTask, BackgroundTasks
 
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 async def listen_for_disconnect(request: Request) -> None:
     """Returns if a disconnect message is received"""
@@ -108,3 +113,24 @@ def load_aware_call(func):
         return response
 
     return wrapper
+
+
+def cli_env_setup():
+    # The safest multiprocessing method is `spawn`, as the default `fork` method
+    # is not compatible with some accelerators. The default method will be
+    # changing in future versions of Python, so we should use it explicitly when
+    # possible.
+    #
+    # We only set it here in the CLI entrypoint, because changing to `spawn`
+    # could break some existing code using vLLM as a library. `spawn` will cause
+    # unexpected behavior if the code is not protected by
+    # `if __name__ == "__main__":`.
+    #
+    # References:
+    # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
+    # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
+    # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
+    # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
+    if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
+        logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
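
Usage note (not part of the patch): as the comment in cli_env_setup() states, 'spawn' is only forced for the CLI entrypoints, so code that imports vLLM as a library keeps the interpreter's default start method. A minimal, hypothetical sketch of a library-style script that opts into the same behavior; the script name and model are illustrative, only the VLLM_WORKER_MULTIPROC_METHOD variable and the LLM/SamplingParams API come from vLLM itself:

# example_offline_inference.py -- hypothetical, not part of this diff.
import os

# Mirror what cli_env_setup() does: choose 'spawn' before vLLM creates any
# worker processes, but respect an explicit user override.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

from vllm import LLM, SamplingParams  # noqa: E402

if __name__ == "__main__":
    # 'spawn' re-imports this module in each worker process, so all work must
    # stay under this guard to avoid re-executing it unexpectedly.
    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)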