[Core] Support async scheduling with uniproc executor (#24219)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
commit 4fdd6f5cbf
parent 8226dd56bf
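
This change lets the V1 engine's async scheduling run on the single-process ("uni") executor as well as the multiprocessing one; the test below is parametrized over both. A minimal usage sketch, assuming the `LLM` entry point forwards these engine arguments (the model name is a placeholder):

    from vllm import LLM, SamplingParams

    # Hedged example: async scheduling on the single-process executor.
    llm = LLM(
        model="facebook/opt-125m",           # placeholder model choice
        async_scheduling=True,               # overlap CPU scheduling with GPU execution
        distributed_executor_backend="uni",  # previously rejected for async scheduling
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)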
@@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs(
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("async_scheduling", [True, False])
+@pytest.mark.parametrize("model_executor", ["uni", "mp"])
 @pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
@@ -70,6 +72,8 @@ def test_models(
     backend: str,
     max_tokens: int,
     enforce_eager: bool,
+    async_scheduling: bool,
+    model_executor: str,
     enable_prompt_embeds: bool,
 ) -> None:
 
@@ -77,6 +81,12 @@ def test_models(
             "VLLM_USE_V1") and envs.VLLM_USE_V1:
         pytest.skip("enable_prompt_embeds is not supported in v1.")
 
+    if not envs.VLLM_USE_V1:
+        if async_scheduling:
+            pytest.skip("async_scheduling only supported in v1.")
+        if model_executor != "uni":
+            pytest.skip("only test uniproc executor for v0.")
+
     if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
         pytest.skip(
             f"{backend} does not support gemma2 with full context length.")
@@ -98,11 +108,15 @@ def test_models(
             prompt_embeds = hf_model.get_prompt_embeddings(
                 example_prompts)
 
-    with VllmRunner(model,
+    with VllmRunner(
+            model,
             max_model_len=8192,
             enforce_eager=enforce_eager,
             enable_prompt_embeds=enable_prompt_embeds,
-            gpu_memory_utilization=0.7) as vllm_model:
+            gpu_memory_utilization=0.7,
+            async_scheduling=async_scheduling,
+            distributed_executor_backend=model_executor,
+    ) as vllm_model:
         if enable_prompt_embeds:
             vllm_outputs = vllm_model.generate_greedy(
                 prompt_embeds, max_tokens)
@@ -257,9 +257,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         def execute_model(
             self,
             scheduler_output,
+            non_block=False,
         ) -> Future[ModelRunnerOutput]:
             """Make execute_model non-blocking."""
 
+            # DummyExecutor used only for testing async case.
+            assert non_block
+
             def _execute():
                 output = self.collective_rpc("execute_model",
                                              args=(scheduler_output, ))
@@ -1296,11 +1296,8 @@ class EngineArgs:
             # Async scheduling does not work with the uniprocess backend.
             if self.distributed_executor_backend is None:
                 self.distributed_executor_backend = "mp"
-                logger.info("Using mp-based distributed executor backend "
-                            "for async scheduling.")
-            if self.distributed_executor_backend == "uni":
-                raise ValueError("Async scheduling is not supported with "
-                                 "uni-process backend.")
+                logger.info("Defaulting to mp-based distributed executor "
+                            "backend for async scheduling.")
             if self.pipeline_parallel_size > 1:
                 raise ValueError("Async scheduling is not supported with "
                                  "pipeline-parallel-size > 1.")
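
With this hunk, async scheduling only defaults the executor backend to "mp" when none was chosen and keeps the pipeline-parallel restriction; an explicit "uni" backend no longer raises. A standalone sketch of the resulting validation flow (an illustrative re-statement, not the actual vLLM code path):

    def _validate_async_scheduling(distributed_executor_backend, pipeline_parallel_size):
        # Mirror of the logic above, for illustration only.
        if distributed_executor_backend is None:
            distributed_executor_backend = "mp"  # default for async scheduling
        # Note: "uni" is intentionally no longer rejected here.
        if pipeline_parallel_size > 1:
            raise ValueError("Async scheduling is not supported with "
                             "pipeline-parallel-size > 1.")
        return distributed_executor_backend

    assert _validate_async_scheduling(None, 1) == "mp"
    assert _validate_async_scheduling("uni", 1) == "uni"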
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
+from concurrent.futures import Future, ThreadPoolExecutor
+from functools import cached_property
 from multiprocessing import Lock
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -17,6 +18,7 @@ from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         run_method)
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.utils import get_and_update_mm_cache
+from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -31,15 +33,7 @@ class UniProcExecutor(ExecutorBase):
         """
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                                rpc_rank=0)
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        local_rank = 0
-        # set local rank as the device index if specified
-        device_info = self.vllm_config.device_config.device.__str__().split(
-            ":")
-        if len(device_info) > 1:
-            local_rank = int(device_info[1])
-        rank = 0
+        distributed_init_method, rank, local_rank = self._distributed_args()
         is_driver_worker = True
         kwargs = dict(
             vllm_config=self.vllm_config,
@@ -50,21 +44,56 @@ class UniProcExecutor(ExecutorBase):
         )
         self.mm_receiver_cache = worker_receiver_cache_from_config(
             self.vllm_config, MULTIMODAL_REGISTRY, Lock())
 
+        self.async_output_thread: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            self.async_output_thread = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="WorkerAsyncOutput")
+
         self.collective_rpc("init_worker", args=([kwargs], ))
         self.collective_rpc("init_device")
         self.collective_rpc("load_model")
 
+    def _distributed_args(self) -> tuple[str, int, int]:
+        """Return (distributed_init_method, rank, local_rank)."""
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        # set local rank as the device index if specified
+        device_info = self.vllm_config.device_config.device.__str__().split(
+            ":")
+        local_rank = int(device_info[1]) if len(device_info) > 1 else 0
+        return distributed_init_method, 0, local_rank
+
+    @cached_property
+    def max_concurrent_batches(self) -> int:
+        return 2 if self.scheduler_config.async_scheduling else 1
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
-                       kwargs: Optional[Dict] = None) -> List[Any]:
+                       kwargs: Optional[Dict] = None,
+                       non_block: bool = False) -> List[Any]:
         if kwargs is None:
             kwargs = {}
         if self.mm_receiver_cache is not None and method == "execute_model":
             get_and_update_mm_cache(self.mm_receiver_cache, args)
-        answer = run_method(self.driver_worker, method, args, kwargs)
-        return [answer]
+        if not non_block:
+            return [run_method(self.driver_worker, method, args, kwargs)]
+
+        try:
+            result = run_method(self.driver_worker, method, args, kwargs)
+            if isinstance(result, AsyncModelRunnerOutput):
+                if (async_thread := self.async_output_thread) is not None:
+                    return [async_thread.submit(result.get_output)]
+                result = result.get_output()
+            future = Future[Any]()
+            future.set_result(result)
+        except Exception as e:
+            future = Future[Any]()
+            future.set_exception(e)
+        return [future]
 
     def check_health(self) -> None:
         # UniProcExecutor will always be healthy as long as
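
When called with non_block=True, the uniproc executor now always hands back a single-element list holding a Future: an AsyncModelRunnerOutput is resolved on the dedicated one-thread pool, while any other (or failed) result is wrapped in an already-completed Future. A minimal sketch of that wrapping pattern using only the standard library (the worker function and payload are hypothetical stand-ins):

    from concurrent.futures import Future, ThreadPoolExecutor

    def run_worker_step():
        return {"sampled_token_ids": [[42]]}  # stand-in for ModelRunnerOutput

    def call_non_blocking(fn) -> Future:
        """Run fn synchronously but return its result through a Future."""
        future: Future = Future()
        try:
            future.set_result(fn())
        except Exception as e:
            future.set_exception(e)
        return future

    # Already-completed Future for the synchronous case.
    print(call_non_blocking(run_worker_step).result())

    # Deferred resolution on a single worker thread, analogous to submitting
    # AsyncModelRunnerOutput.get_output to the "WorkerAsyncOutput" pool.
    pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="WorkerAsyncOutput")
    print(pool.submit(run_worker_step).result())
    pool.shutdown()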
@@ -116,8 +145,9 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
         assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
             ("To get deterministic execution in V1, "
              "please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
-        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rpc_rank=0)
+        super()._init_executor()
+
+    def _distributed_args(self) -> tuple[str, int, int]:
         # engines are launched in torchrun-compatible launchers
         # so we can use the env:// method.
         # required env vars:
@@ -128,19 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
         distributed_init_method = "env://"
         rank = int(os.environ["RANK"])
         local_rank = int(os.environ["LOCAL_RANK"])
-        is_driver_worker = True
-        kwargs = dict(
-            vllm_config=self.vllm_config,
-            local_rank=local_rank,
-            rank=rank,
-            distributed_init_method=distributed_init_method,
-            is_driver_worker=is_driver_worker,
-        )
-        self.mm_receiver_cache = worker_receiver_cache_from_config(
-            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
-        self.collective_rpc("init_worker", args=([kwargs], ))
-        self.collective_rpc("init_device")
-        self.collective_rpc("load_model")
+        return distributed_init_method, rank, local_rank
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
@@ -159,6 +159,9 @@ class EngineCore:
             self.request_block_hasher = get_request_block_hasher(
                 block_size, caching_hash_fn)
 
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)
+
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
@@ -331,7 +334,8 @@ class EngineCore:
         model_executed = False
         if self.scheduler.has_requests():
             scheduler_output = self.scheduler.schedule()
-            future = self.model_executor.execute_model(scheduler_output)
+            future = self.model_executor.execute_model(scheduler_output,
+                                                       non_block=True)
             batch_queue.appendleft(
                 (future, scheduler_output))  # type: ignore[arg-type]
 
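
Passing non_block=True here lets step_with_batch_queue enqueue the in-flight batch and schedule the next one before the previous forward pass finishes. A simplified, hypothetical sketch of that overlap pattern (not the actual EngineCore logic):

    import time
    from collections import deque
    from concurrent.futures import ThreadPoolExecutor

    MAX_CONCURRENT_BATCHES = 2
    executor = ThreadPoolExecutor(max_workers=1)   # stand-in for the model executor
    batch_queue: deque = deque()

    def execute_model(batch_id):                   # stand-in for a forward pass
        time.sleep(0.01)
        return f"output-{batch_id}"

    for batch_id in range(4):
        future = executor.submit(execute_model, batch_id)   # non-blocking dispatch
        batch_queue.appendleft((future, batch_id))
        if len(batch_queue) == MAX_CONCURRENT_BATCHES:
            done_future, done_id = batch_queue.pop()         # oldest in-flight batch
            print(done_id, done_future.result())             # block only when queue is full

    while batch_queue:
        future, batch_id = batch_queue.pop()
        print(batch_id, future.result())
    executor.shutdown()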
@@ -534,9 +538,6 @@ class EngineCoreProc(EngineCore):
             assert addresses.coordinator_input is not None
             logger.info("Waiting for READY message from DP Coordinator...")
 
-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
-
         # Mark the startup heap as static so that it's ignored by GC.
         # Reduces pause times of oldest generation collections.
         gc.collect()
@@ -245,8 +245,8 @@ class InprocClient(EngineCoreClient):
         self.engine_core = EngineCore(*args, **kwargs)
 
     def get_output(self) -> EngineCoreOutputs:
-        outputs, _ = self.engine_core.step()
-        return outputs.get(0) or EngineCoreOutputs()
+        outputs, _ = self.engine_core.step_fn()
+        return outputs and outputs.get(0) or EngineCoreOutputs()
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return self.engine_core.get_supported_tasks()
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from concurrent.futures import Future
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@ from vllm.executor.uniproc_executor import (  # noqa
 from vllm.executor.uniproc_executor import (  # noqa
     UniProcExecutor as UniProcExecutorV0)
 from vllm.utils import resolve_obj_by_qualname
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
@@ -86,12 +87,22 @@ class Executor(ExecutorBase):
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
         return self.collective_rpc("get_kv_cache_spec")
 
+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict] = None,
+                       non_block: bool = False) -> list[Any]:
+        raise NotImplementedError
+
     def execute_model(
         self,
-        scheduler_output,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         output = self.collective_rpc("execute_model",
-                                     args=(scheduler_output, ))
+                                     args=(scheduler_output, ),
+                                     non_block=non_block)
         return output[0]
 
     def execute_dummy_batch(self) -> None:
@@ -11,7 +11,7 @@ import weakref
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
+from functools import cached_property, partial
 from multiprocessing.connection import Connection
 from multiprocessing.process import BaseProcess
 from multiprocessing.synchronize import Lock as LockType
@@ -37,6 +37,7 @@ from vllm.multimodal.cache import worker_receiver_cache_from_config
 from vllm.utils import (decorate_logs, get_distributed_init_method,
                         get_loopback_ip, get_mp_context, get_open_port,
                         set_process_title)
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.executor.abstract import Executor, FailureCallback
 from vllm.v1.executor.utils import get_and_update_mm_cache
 from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
@@ -174,9 +175,9 @@ class MultiprocExecutor(Executor):
 
     def execute_model(
         self,
-        scheduler_output,
+        scheduler_output: SchedulerOutput,
+        non_block: bool = False,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        non_block = self.max_concurrent_batches > 1
 
         if not self.has_connector:
             # get output only from a single worker (output_rank)
@@ -328,7 +329,7 @@ class MultiprocExecutor(Executor):
             self.collective_rpc("check_health", timeout=10)
             return
 
-    @property
+    @cached_property
     def max_concurrent_batches(self) -> int:
         if self.scheduler_config.async_scheduling:
             return 2
@@ -632,7 +633,8 @@ class WorkerProc:
                 result = (WorkerProc.ResponseStatus.FAILURE, str(output))
             else:
                 result = (WorkerProc.ResponseStatus.SUCCESS, output)
-            self.worker_response_mq.enqueue(result)
+            if (response_mq := self.worker_response_mq) is not None:
+                response_mq.enqueue(result)
 
     def handle_output(self, output: Any):
         """Handles output from the worker. If async scheduling is enabled,
@@ -66,11 +66,13 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
+        non_block: bool = False,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.
+            non_block: If True, the method will return a Future.

        Returns:
            The model runner output.
@@ -84,7 +86,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
         if not self.has_connector:
             # Get output only from a single worker (output_rank)
             # When PP is not used, we block here until the result is available.
-            if self.max_concurrent_batches == 1:
+            if not non_block:
                 return refs[0].get()
 
             # When PP is used, we return a FutureWrapper immediately so that
@@ -92,7 +94,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
             return FutureWrapper(refs)
 
         # Get output from all workers when connector is present
-        if self.max_concurrent_batches == 1:
+        if not non_block:
             # Block and get results from all workers
             outputs = [ref.get() for ref in refs]
             return self.kv_output_aggregator.aggregate(outputs)
@@ -106,4 +108,3 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
         if reconfig_request.new_data_parallel_rank == \
             ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
             self.shutdown()
-            return