mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 13:25:01 +08:00
[executor] feat: add supports_pp attr to executors (#21786)
Signed-off-by: Haibin Lin <haibin.lin@bytedance.com>
This commit is contained in:
parent
7de45db9a5
commit
24d1dffbeb
@ -1490,14 +1490,18 @@ class EngineArgs:
|
|||||||
and _warn_or_fallback("Engine in background thread")):
|
and _warn_or_fallback("Engine in background thread")):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if (self.pipeline_parallel_size > 1
|
if self.pipeline_parallel_size > 1:
|
||||||
and self.distributed_executor_backend
|
supports_pp = getattr(self.distributed_executor_backend,
|
||||||
not in (ParallelConfig.distributed_executor_backend, "ray",
|
'supports_pp', False)
|
||||||
"mp", "external_launcher")):
|
if not supports_pp and self.distributed_executor_backend not in (
|
||||||
name = "Pipeline Parallelism without Ray distributed executor " \
|
ParallelConfig.distributed_executor_backend, "ray", "mp",
|
||||||
"or multiprocessing executor or external launcher"
|
"external_launcher"):
|
||||||
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
|
name = "Pipeline Parallelism without Ray distributed " \
|
||||||
return False
|
"executor or multiprocessing executor or external " \
|
||||||
|
"launcher"
|
||||||
|
_raise_or_fallback(feature_name=name,
|
||||||
|
recommend_to_remove=False)
|
||||||
|
return False
|
||||||
|
|
||||||
# The platform may be supported on V1, but off by default for now.
|
# The platform may be supported on V1, but off by default for now.
|
||||||
if not current_platform.default_v1( # noqa: SIM103
|
if not current_platform.default_v1( # noqa: SIM103
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class ExecutorBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
uses_ray: bool # whether the executor uses Ray for orchestration.
|
uses_ray: bool # whether the executor uses Ray for orchestration.
|
||||||
|
supports_pp: bool = False # whether the executor supports PP
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -41,6 +41,8 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class MultiprocExecutor(Executor):
|
class MultiprocExecutor(Executor):
|
||||||
|
|
||||||
|
supports_pp: bool = True
|
||||||
|
|
||||||
def _init_executor(self) -> None:
|
def _init_executor(self) -> None:
|
||||||
# Call self.shutdown at exit to clean up
|
# Call self.shutdown at exit to clean up
|
||||||
# and ensure workers will be terminated.
|
# and ensure workers will be terminated.
|
||||||
|
|||||||
@ -43,6 +43,8 @@ class FutureWrapper(Future):
|
|||||||
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
|
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
|
||||||
"""Ray distributed executor using Ray Compiled Graphs."""
|
"""Ray distributed executor using Ray Compiled Graphs."""
|
||||||
|
|
||||||
|
supports_pp: bool = True
|
||||||
|
|
||||||
def _init_executor(self) -> None:
|
def _init_executor(self) -> None:
|
||||||
super()._init_executor()
|
super()._init_executor()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user