[Bugfix] fix RAY_CGRAPH_get_timeout is not set successfully (#19725)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey 2025-06-19 01:26:16 +08:00 committed by GitHub
parent 8b6e1d639c
commit 12575cfa7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -557,8 +557,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _compiled_ray_dag(self, enable_asyncio: bool):
assert self.parallel_config.use_ray
self._check_ray_cgraph_installation()
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
# Note: we should set this env var before importing
# ray.dag, otherwise it will not take effect.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
from ray.dag import InputNode, MultiOutputNode
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
@ -570,15 +579,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
#