mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 20:45:01 +08:00
[Core] Eliminate parallel worker per-step task scheduling overhead (#4894)
This commit is contained in:
parent
97b030005c
commit
eb6d3c264d
@ -234,6 +234,14 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
# Log stats.
|
# Log stats.
|
||||||
self.do_log_stats(scheduler_outputs, output)
|
self.do_log_stats(scheduler_outputs, output)
|
||||||
|
|
||||||
|
if not request_outputs:
|
||||||
|
# Stop the execute model loop in parallel workers until there are
|
||||||
|
# more requests to process. This avoids waiting indefinitely in
|
||||||
|
# torch.distributed ops which may otherwise timeout, and unblocks
|
||||||
|
# the RPC thread in the workers so that they can process any other
|
||||||
|
# queued control plane messages, such as add/remove lora adapters.
|
||||||
|
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||||
|
|
||||||
return request_outputs
|
return request_outputs
|
||||||
|
|
||||||
async def encode_request_async(
|
async def encode_request_async(
|
||||||
|
|||||||
@ -692,6 +692,14 @@ class LLMEngine:
|
|||||||
# Log stats.
|
# Log stats.
|
||||||
self.do_log_stats(scheduler_outputs, output)
|
self.do_log_stats(scheduler_outputs, output)
|
||||||
|
|
||||||
|
if not request_outputs:
|
||||||
|
# Stop the execute model loop in parallel workers until there are
|
||||||
|
# more requests to process. This avoids waiting indefinitely in
|
||||||
|
# torch.distributed ops which may otherwise timeout, and unblocks
|
||||||
|
# the RPC thread in the workers so that they can process any other
|
||||||
|
# queued control plane messages, such as add/remove lora adapters.
|
||||||
|
self.model_executor.stop_remote_worker_execution_loop()
|
||||||
|
|
||||||
return request_outputs
|
return request_outputs
|
||||||
|
|
||||||
def do_log_stats(
|
def do_log_stats(
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
|
import asyncio
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||||
from vllm.executor.gpu_executor import GPUExecutor
|
from vllm.executor.gpu_executor import GPUExecutor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import SamplerOutput
|
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -13,6 +14,16 @@ logger = init_logger(__name__)
|
|||||||
class DistributedGPUExecutor(GPUExecutor):
|
class DistributedGPUExecutor(GPUExecutor):
|
||||||
"""Abstract superclass of multi-GPU executor implementations."""
|
"""Abstract superclass of multi-GPU executor implementations."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# This is non-None when the execute model loop is running
|
||||||
|
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||||
|
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
|
||||||
|
# Updated by implementations that require additional args to be passed
|
||||||
|
# to the _run_workers execute_model call
|
||||||
|
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
"""Determine the number of available KV blocks.
|
"""Determine the number of available KV blocks.
|
||||||
|
|
||||||
@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
|
|||||||
num_gpu_blocks=num_gpu_blocks,
|
num_gpu_blocks=num_gpu_blocks,
|
||||||
num_cpu_blocks=num_cpu_blocks)
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
|
|
||||||
def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
|
def execute_model(
|
||||||
all_outputs = self._run_workers("execute_model",
|
self,
|
||||||
driver_args=args,
|
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||||
driver_kwargs=kwargs)
|
if self.parallel_worker_tasks is None:
|
||||||
|
self.parallel_worker_tasks = self._run_workers(
|
||||||
|
"start_worker_execution_loop",
|
||||||
|
async_run_remote_workers_only=True,
|
||||||
|
**self.extra_execute_model_run_workers_kwargs)
|
||||||
|
|
||||||
# Only the driver worker returns the sampling results.
|
# Only the driver worker returns the sampling results.
|
||||||
return all_outputs[0]
|
return self._driver_execute_model(execute_model_req)
|
||||||
|
|
||||||
|
def stop_remote_worker_execution_loop(self) -> None:
|
||||||
|
if self.parallel_worker_tasks is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._driver_execute_model()
|
||||||
|
parallel_worker_tasks = self.parallel_worker_tasks
|
||||||
|
self.parallel_worker_tasks = None
|
||||||
|
# Ensure that workers exit model loop cleanly
|
||||||
|
# (this will raise otherwise)
|
||||||
|
self._wait_for_tasks_completion(parallel_worker_tasks)
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||||
@ -88,39 +114,84 @@ class DistributedGPUExecutor(GPUExecutor):
|
|||||||
pattern=pattern,
|
pattern=pattern,
|
||||||
max_size=max_size)
|
max_size=max_size)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _driver_execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
"""Run execute_model in the driver worker.
|
||||||
|
|
||||||
|
Passing None will cause the driver to stop the model execution
|
||||||
|
loop running in each of the remote workers.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
*args,
|
*args,
|
||||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
async_run_remote_workers_only: bool = False,
|
||||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
max_concurrent_workers: Optional[int] = None,
|
max_concurrent_workers: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Runs the given method on all workers."""
|
"""Runs the given method on all workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_run_remote_workers_only: If True the method will be run only
|
||||||
|
in the remote workers, not the driver worker. It will also be
|
||||||
|
run asynchronously and return a list of futures rather than
|
||||||
|
blocking on the results.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||||
|
"""Wait for futures returned from _run_workers() with
|
||||||
|
async_run_remote_workers_only to complete."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
|
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
|
||||||
|
|
||||||
@abstractmethod
|
async def execute_model_async(
|
||||||
async def _run_workers_async(
|
|
||||||
self,
|
self,
|
||||||
method: str,
|
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||||
*args,
|
if self.parallel_worker_tasks is None:
|
||||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
# Start model execution loop running in the parallel workers
|
||||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
self.parallel_worker_tasks = asyncio.create_task(
|
||||||
**kwargs,
|
self._start_worker_execution_loop())
|
||||||
) -> Any:
|
|
||||||
"""Runs the given method on all workers."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def execute_model_async(self, *args,
|
|
||||||
**kwargs) -> List[SamplerOutput]:
|
|
||||||
all_outputs = await self._run_workers_async("execute_model",
|
|
||||||
driver_args=args,
|
|
||||||
driver_kwargs=kwargs)
|
|
||||||
|
|
||||||
# Only the driver worker returns the sampling results.
|
# Only the driver worker returns the sampling results.
|
||||||
return all_outputs[0]
|
return await self._driver_execute_model_async(execute_model_req)
|
||||||
|
|
||||||
|
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||||
|
if self.parallel_worker_tasks is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._driver_execute_model_async()
|
||||||
|
parallel_worker_tasks = self.parallel_worker_tasks
|
||||||
|
self.parallel_worker_tasks = None
|
||||||
|
# Ensure that workers exit model loop cleanly
|
||||||
|
# (this will raise otherwise)
|
||||||
|
await parallel_worker_tasks
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _driver_execute_model_async(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
"""Execute the model asynchronously in the driver worker.
|
||||||
|
|
||||||
|
Passing None will cause the driver to stop the model execution
|
||||||
|
loop running in each of the remote workers.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def _start_worker_execution_loop(self):
|
||||||
|
"""Run execution loop on all workers. It guarantees all workers run
|
||||||
|
the loop or None of them is running the loop. Loop can be stopped by
|
||||||
|
`stop_remote_worker_execution_loop`.
|
||||||
|
The API is idempotent (guarantee only 1 loop run at any moment)."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@ -74,6 +74,10 @@ class ExecutorBase(ABC):
|
|||||||
"""Executes at least one model step on the given sequences."""
|
"""Executes at least one model step on the given sequences."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def stop_remote_worker_execution_loop(self) -> None:
|
||||||
|
"""Releases parallel workers from model loop."""
|
||||||
|
return
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -109,6 +113,10 @@ class ExecutorAsyncBase(ExecutorBase):
|
|||||||
"""Executes one model step on the given sequences."""
|
"""Executes one model step on the given sequences."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||||
|
"""Releases parallel workers from model loop."""
|
||||||
|
return
|
||||||
|
|
||||||
async def check_health_async(self) -> None:
|
async def check_health_async(self) -> None:
|
||||||
"""Checks if the executor is healthy. If not, it should raise an
|
"""Checks if the executor is healthy. If not, it should raise an
|
||||||
exception."""
|
exception."""
|
||||||
|
|||||||
@ -1,13 +1,14 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||||
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||||
ResultHandler, WorkerMonitor)
|
ResultHandler, WorkerMonitor)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||||
get_vllm_instance_id, make_async)
|
get_vllm_instance_id, make_async)
|
||||||
|
|
||||||
@ -71,16 +72,34 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
|||||||
None)) is not None:
|
None)) is not None:
|
||||||
worker_monitor.close()
|
worker_monitor.close()
|
||||||
|
|
||||||
|
def _driver_execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
"""Run execute_model in the driver worker.
|
||||||
|
|
||||||
|
Passing None will cause the driver to stop the model execution
|
||||||
|
loop running in each of the remote workers.
|
||||||
|
"""
|
||||||
|
return self.driver_worker.execute_model(
|
||||||
|
execute_model_req=execute_model_req)
|
||||||
|
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
*args,
|
*args,
|
||||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
async_run_remote_workers_only: bool = False,
|
||||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
max_concurrent_workers: Optional[int] = None,
|
max_concurrent_workers: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Runs the given method on all workers."""
|
"""Runs the given method on all workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_run_remote_workers_only: If True the method will be run only
|
||||||
|
in the remote workers, not the driver worker. It will also be
|
||||||
|
run asynchronously and return a list of futures rather than
|
||||||
|
blocking on the results.
|
||||||
|
"""
|
||||||
|
|
||||||
if max_concurrent_workers:
|
if max_concurrent_workers:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -92,15 +111,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
|||||||
for worker in self.workers
|
for worker in self.workers
|
||||||
]
|
]
|
||||||
|
|
||||||
if driver_args is None:
|
if async_run_remote_workers_only:
|
||||||
driver_args = args
|
# Just return futures
|
||||||
if driver_kwargs is None:
|
return worker_outputs
|
||||||
driver_kwargs = kwargs
|
|
||||||
|
|
||||||
# Start the driver worker after all the ray workers.
|
|
||||||
driver_worker_method = getattr(self.driver_worker, method)
|
driver_worker_method = getattr(self.driver_worker, method)
|
||||||
driver_worker_output = driver_worker_method(*driver_args,
|
driver_worker_output = driver_worker_method(*args, **kwargs)
|
||||||
**driver_kwargs)
|
|
||||||
|
|
||||||
# Get the results of the workers.
|
# Get the results of the workers.
|
||||||
return [driver_worker_output
|
return [driver_worker_output
|
||||||
@ -111,30 +127,29 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
|||||||
if not self.worker_monitor.is_alive():
|
if not self.worker_monitor.is_alive():
|
||||||
raise RuntimeError("Worker processes are not running")
|
raise RuntimeError("Worker processes are not running")
|
||||||
|
|
||||||
|
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||||
|
"""Wait for futures returned from _run_workers() with
|
||||||
|
async_run_remote_workers_only to complete."""
|
||||||
|
for result in parallel_worker_tasks:
|
||||||
|
result.get()
|
||||||
|
|
||||||
|
|
||||||
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
|
class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
|
||||||
DistributedGPUExecutorAsync):
|
DistributedGPUExecutorAsync):
|
||||||
|
|
||||||
async def _run_workers_async(
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.driver_exec_model = make_async(self.driver_worker.execute_model)
|
||||||
|
|
||||||
|
async def _driver_execute_model_async(
|
||||||
self,
|
self,
|
||||||
method: str,
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
*args,
|
) -> List[SamplerOutput]:
|
||||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
return await self.driver_exec_model(execute_model_req)
|
||||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Any:
|
|
||||||
"""Runs the given method on all workers."""
|
|
||||||
if driver_args is None:
|
|
||||||
driver_args = args
|
|
||||||
if driver_kwargs is None:
|
|
||||||
driver_kwargs = kwargs
|
|
||||||
|
|
||||||
driver_executor = make_async(getattr(self.driver_worker, method))
|
async def _start_worker_execution_loop(self):
|
||||||
|
coros = [
|
||||||
# Run all the workers asynchronously.
|
worker.execute_method_async("start_worker_execution_loop")
|
||||||
coros = [driver_executor(*driver_args, **driver_kwargs)] + [
|
|
||||||
worker.execute_method_async(method, *args, **kwargs)
|
|
||||||
for worker in self.workers
|
for worker in self.workers
|
||||||
]
|
]
|
||||||
|
|
||||||
return await asyncio.gather(*coros)
|
return await asyncio.gather(*coros)
|
||||||
|
|||||||
@ -42,6 +42,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
self.forward_dag = None
|
self.forward_dag = None
|
||||||
if USE_RAY_COMPILED_DAG:
|
if USE_RAY_COMPILED_DAG:
|
||||||
self.forward_dag = self._compiled_ray_dag()
|
self.forward_dag = self._compiled_ray_dag()
|
||||||
|
self.extra_execute_model_run_workers_kwargs[
|
||||||
|
"use_ray_compiled_dag"] = True
|
||||||
|
|
||||||
def _configure_ray_workers_use_nsight(self,
|
def _configure_ray_workers_use_nsight(self,
|
||||||
ray_remote_kwargs) -> Dict[str, Any]:
|
ray_remote_kwargs) -> Dict[str, Any]:
|
||||||
@ -171,23 +173,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
max_concurrent_workers=self.parallel_config.
|
max_concurrent_workers=self.parallel_config.
|
||||||
max_parallel_loading_workers)
|
max_parallel_loading_workers)
|
||||||
|
|
||||||
def execute_model(
|
def _driver_execute_model(
|
||||||
self,
|
self,
|
||||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
all_outputs = self._run_workers(
|
) -> List[SamplerOutput]:
|
||||||
"execute_model",
|
"""Run execute_model in the driver worker.
|
||||||
driver_kwargs={"execute_model_req": execute_model_req},
|
|
||||||
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
|
|
||||||
|
|
||||||
# Only the driver worker returns the sampling results.
|
Passing None will cause the driver to stop the model execution
|
||||||
return all_outputs[0]
|
loop running in each of the remote workers.
|
||||||
|
"""
|
||||||
|
return self.driver_worker.execute_method("execute_model",
|
||||||
|
execute_model_req)
|
||||||
|
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
*args,
|
*args,
|
||||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
async_run_remote_workers_only: bool = False,
|
||||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||||
use_dummy_driver: bool = False,
|
use_dummy_driver: bool = False,
|
||||||
@ -198,9 +200,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
"""Runs the given method on all workers. Can be used in the following
|
"""Runs the given method on all workers. Can be used in the following
|
||||||
ways:
|
ways:
|
||||||
|
|
||||||
|
- async_run_remote_workers_only: If True the method will be run only
|
||||||
|
in the remote workers, not the driver worker. It will also be
|
||||||
|
run asynchronously and return a list of futures rather than blocking
|
||||||
|
on the results.
|
||||||
- args/kwargs: All workers share the same args/kwargs
|
- args/kwargs: All workers share the same args/kwargs
|
||||||
- args/kwargs and driver_args/driver_kwargs: Driver worker has
|
|
||||||
different args
|
|
||||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||||
individually
|
individually
|
||||||
"""
|
"""
|
||||||
@ -209,11 +213,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"max_concurrent_workers is not supported yet.")
|
"max_concurrent_workers is not supported yet.")
|
||||||
|
|
||||||
if driver_args is None:
|
|
||||||
driver_args = args if all_args is None else all_args[0]
|
|
||||||
if driver_kwargs is None:
|
|
||||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
|
||||||
|
|
||||||
count = len(self.workers)
|
count = len(self.workers)
|
||||||
all_worker_args = repeat(args, count) if all_args is None \
|
all_worker_args = repeat(args, count) if all_args is None \
|
||||||
else islice(all_args, 1, None)
|
else islice(all_args, 1, None)
|
||||||
@ -225,6 +224,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
# input. TODO(sang): Fix it.
|
# input. TODO(sang): Fix it.
|
||||||
assert self.forward_dag is not None
|
assert self.forward_dag is not None
|
||||||
output_channels = self.forward_dag.execute(1)
|
output_channels = self.forward_dag.execute(1)
|
||||||
|
ray_worker_outputs = []
|
||||||
else:
|
else:
|
||||||
# Start the ray workers first.
|
# Start the ray workers first.
|
||||||
ray_worker_outputs = [
|
ray_worker_outputs = [
|
||||||
@ -234,6 +234,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if async_run_remote_workers_only:
|
||||||
|
# Just return futures
|
||||||
|
return ray_worker_outputs
|
||||||
|
|
||||||
|
driver_args = args if all_args is None else all_args[0]
|
||||||
|
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||||
|
|
||||||
# Start the driver worker after all the ray workers.
|
# Start the driver worker after all the ray workers.
|
||||||
if not use_dummy_driver:
|
if not use_dummy_driver:
|
||||||
driver_worker_output = self.driver_worker.execute_method(
|
driver_worker_output = self.driver_worker.execute_method(
|
||||||
@ -260,6 +267,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
|||||||
|
|
||||||
return [driver_worker_output] + ray_worker_outputs
|
return [driver_worker_output] + ray_worker_outputs
|
||||||
|
|
||||||
|
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||||
|
"""Wait for futures returned from _run_workers() with
|
||||||
|
async_run_remote_workers_only to complete."""
|
||||||
|
ray.get(parallel_worker_tasks)
|
||||||
|
|
||||||
def _compiled_ray_dag(self):
|
def _compiled_ray_dag(self):
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
required_version = "2.9"
|
required_version = "2.9"
|
||||||
@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.driver_executor = make_async(self.driver_worker.execute_method)
|
self.driver_exec_method = make_async(self.driver_worker.execute_method)
|
||||||
|
|
||||||
async def _run_workers_async(
|
async def _driver_execute_model_async(
|
||||||
self,
|
self,
|
||||||
method: str,
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
*args,
|
) -> List[SamplerOutput]:
|
||||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
return await self.driver_exec_method("execute_model",
|
||||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
execute_model_req)
|
||||||
**kwargs,
|
|
||||||
) -> Any:
|
|
||||||
"""Runs the given method on all workers."""
|
|
||||||
coros = []
|
|
||||||
|
|
||||||
if driver_args is None:
|
async def _start_worker_execution_loop(self):
|
||||||
driver_args = args
|
coros = [
|
||||||
if driver_kwargs is None:
|
worker.execute_method.remote("start_worker_execution_loop")
|
||||||
driver_kwargs = kwargs
|
for worker in self.workers
|
||||||
|
]
|
||||||
coros.append(
|
return await asyncio.gather(*coros)
|
||||||
self.driver_executor(method, *driver_args, **driver_kwargs))
|
|
||||||
|
|
||||||
# Run the ray workers asynchronously.
|
|
||||||
for worker in self.workers:
|
|
||||||
coros.append(worker.execute_method.remote(method, *args, **kwargs))
|
|
||||||
|
|
||||||
all_outputs = await asyncio.gather(*coros)
|
|
||||||
return all_outputs
|
|
||||||
|
|||||||
@ -47,7 +47,9 @@ class NGramWorker(LoraNotSupportedWorkerBase):
|
|||||||
# NGram don't need gpu sampler
|
# NGram don't need gpu sampler
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
|
def execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
|
||||||
"""NGram doesn't depend on model execution, just pass this function"""
|
"""NGram doesn't depend on model execution, just pass this function"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -231,35 +231,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
|
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
|
||||||
num_cpu_blocks=num_cpu_blocks)
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
|
|
||||||
def _broadcast_control_flow_decision(
|
|
||||||
self,
|
|
||||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
|
||||||
disable_all_speculation: bool = False) -> Tuple[int, bool]:
|
|
||||||
"""Broadcast how many lookahead slots are scheduled for this step, and
|
|
||||||
whether all speculation is disabled, to all non-driver workers.
|
|
||||||
|
|
||||||
This is required as if the number of draft model runs changes
|
|
||||||
dynamically, the non-driver workers won't know unless we perform a
|
|
||||||
communication to inform then.
|
|
||||||
|
|
||||||
Returns the broadcasted num_lookahead_slots and disable_all_speculation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.rank == self._driver_rank:
|
|
||||||
assert execute_model_req is not None
|
|
||||||
|
|
||||||
broadcast_dict = dict(
|
|
||||||
num_lookahead_slots=execute_model_req.num_lookahead_slots,
|
|
||||||
disable_all_speculation=disable_all_speculation,
|
|
||||||
)
|
|
||||||
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
|
|
||||||
else:
|
|
||||||
assert execute_model_req is None
|
|
||||||
broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
|
|
||||||
|
|
||||||
return (broadcast_dict["num_lookahead_slots"],
|
|
||||||
broadcast_dict["disable_all_speculation"])
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -267,25 +238,40 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
) -> List[SamplerOutput]:
|
) -> List[SamplerOutput]:
|
||||||
"""Perform speculative decoding on the input batch.
|
"""Perform speculative decoding on the input batch.
|
||||||
"""
|
"""
|
||||||
|
if self.rank != self._driver_rank:
|
||||||
|
self._run_non_driver_rank()
|
||||||
|
return []
|
||||||
|
|
||||||
|
if execute_model_req is None:
|
||||||
|
# This signals that there's no more requests to process for now.
|
||||||
|
# All workers are running infinite loop with broadcast_tensor_dict,
|
||||||
|
# and it stops the loop when the driver broadcasts an empty input.
|
||||||
|
# Send an empty input to notify all other workers to stop their
|
||||||
|
# execution loop.
|
||||||
|
broadcast_tensor_dict({}, src=0)
|
||||||
|
return []
|
||||||
|
|
||||||
disable_all_speculation = False
|
|
||||||
if self.rank == self._driver_rank:
|
|
||||||
disable_all_speculation = self._should_disable_all_speculation(
|
disable_all_speculation = self._should_disable_all_speculation(
|
||||||
execute_model_req)
|
execute_model_req)
|
||||||
|
num_lookahead_slots = execute_model_req.num_lookahead_slots
|
||||||
|
|
||||||
(num_lookahead_slots,
|
# Broadcast how many lookahead slots are scheduled for this step, and
|
||||||
disable_all_speculation) = self._broadcast_control_flow_decision(
|
# whether all speculation is disabled, to all non-driver workers.
|
||||||
execute_model_req, disable_all_speculation)
|
|
||||||
|
|
||||||
if self.rank == self._driver_rank:
|
# This is required as if the number of draft model runs changes
|
||||||
assert execute_model_req is not None
|
# dynamically, the non-driver workers won't know unless we perform a
|
||||||
assert execute_model_req.seq_group_metadata_list is not None, (
|
# communication to inform then.
|
||||||
"speculative decoding requires non-None seq_group_metadata_list"
|
broadcast_dict = dict(
|
||||||
|
num_lookahead_slots=num_lookahead_slots,
|
||||||
|
disable_all_speculation=disable_all_speculation,
|
||||||
)
|
)
|
||||||
|
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
|
||||||
|
|
||||||
|
assert execute_model_req.seq_group_metadata_list is not None, (
|
||||||
|
"speculative decoding requires non-None seq_group_metadata_list")
|
||||||
|
|
||||||
self._maybe_disable_speculative_tokens(
|
self._maybe_disable_speculative_tokens(
|
||||||
disable_all_speculation,
|
disable_all_speculation, execute_model_req.seq_group_metadata_list)
|
||||||
execute_model_req.seq_group_metadata_list)
|
|
||||||
|
|
||||||
# If no spec tokens, call the proposer and scorer workers normally.
|
# If no spec tokens, call the proposer and scorer workers normally.
|
||||||
# Used for prefill.
|
# Used for prefill.
|
||||||
@ -296,9 +282,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
|
|
||||||
return self._run_speculative_decoding_step(execute_model_req,
|
return self._run_speculative_decoding_step(execute_model_req,
|
||||||
num_lookahead_slots)
|
num_lookahead_slots)
|
||||||
else:
|
|
||||||
self._run_non_driver_rank(num_lookahead_slots)
|
@torch.inference_mode()
|
||||||
return []
|
def start_worker_execution_loop(self) -> None:
|
||||||
|
"""Execute model loop to perform speculative decoding
|
||||||
|
in parallel worker."""
|
||||||
|
while self._run_non_driver_rank():
|
||||||
|
pass
|
||||||
|
|
||||||
def _should_disable_all_speculation(
|
def _should_disable_all_speculation(
|
||||||
self, execute_model_req: ExecuteModelRequest) -> bool:
|
self, execute_model_req: ExecuteModelRequest) -> bool:
|
||||||
@ -346,13 +336,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
sampler_output.logprobs = None
|
sampler_output.logprobs = None
|
||||||
return [sampler_output]
|
return [sampler_output]
|
||||||
|
|
||||||
def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
|
def _run_non_driver_rank(self) -> bool:
|
||||||
"""Run proposer and verifier model in non-driver workers. This is used
|
"""Run proposer and verifier model in non-driver workers. This is used
|
||||||
for both speculation cases (num_lookahead_slots>0) and non-speculation
|
for both speculation cases (num_lookahead_slots>0) and non-speculation
|
||||||
cases (e.g. prefill).
|
cases (e.g. prefill).
|
||||||
|
|
||||||
|
Returns True iff there are remaining sequences to process.
|
||||||
"""
|
"""
|
||||||
# In non-driver workers the input is None
|
assert self.rank != self._driver_rank
|
||||||
execute_model_req = None
|
|
||||||
|
data = broadcast_tensor_dict(src=self._driver_rank)
|
||||||
|
if not data:
|
||||||
|
return False
|
||||||
|
num_lookahead_slots = data["num_lookahead_slots"]
|
||||||
|
|
||||||
# Even if num_lookahead_slots is zero, we want to run the proposer model
|
# Even if num_lookahead_slots is zero, we want to run the proposer model
|
||||||
# as it may have KV.
|
# as it may have KV.
|
||||||
@ -360,9 +356,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
# We run the proposer once per lookahead slot. In the future we should
|
# We run the proposer once per lookahead slot. In the future we should
|
||||||
# delegate how many times it runs to the proposer.
|
# delegate how many times it runs to the proposer.
|
||||||
for _ in range(max(num_lookahead_slots, 1)):
|
for _ in range(max(num_lookahead_slots, 1)):
|
||||||
self.proposer_worker.execute_model(execute_model_req)
|
self.proposer_worker.execute_model()
|
||||||
|
|
||||||
self.scorer_worker.execute_model(execute_model_req)
|
self.scorer_worker.execute_model()
|
||||||
|
return True
|
||||||
|
|
||||||
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
||||||
def _run_speculative_decoding_step(
|
def _run_speculative_decoding_step(
|
||||||
|
|||||||
@ -47,7 +47,7 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
) -> Optional[PoolerOutput]:
|
) -> Optional[PoolerOutput]:
|
||||||
(input_tokens, input_positions, attn_metadata, pooling_metadata,
|
(input_tokens, input_positions, attn_metadata, pooling_metadata,
|
||||||
@ -84,10 +84,11 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
|
|
||||||
def prepare_input_tensors(
|
def prepare_input_tensors(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
||||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||||
if self.is_driver_worker:
|
if self.is_driver_worker:
|
||||||
|
assert seq_group_metadata_list is not None
|
||||||
# Prepare input tensors.
|
# Prepare input tensors.
|
||||||
(
|
(
|
||||||
input_tokens,
|
input_tokens,
|
||||||
|
|||||||
@ -609,10 +609,11 @@ class ModelRunner:
|
|||||||
|
|
||||||
def prepare_input_tensors(
|
def prepare_input_tensors(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
||||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||||
if self.is_driver_worker:
|
if self.is_driver_worker:
|
||||||
|
assert seq_group_metadata_list is not None
|
||||||
# Prepare input tensors.
|
# Prepare input tensors.
|
||||||
(
|
(
|
||||||
input_tokens,
|
input_tokens,
|
||||||
@ -676,7 +677,7 @@ class ModelRunner:
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
(input_tokens, input_positions, attn_metadata, sampling_metadata,
|
(input_tokens, input_positions, attn_metadata, sampling_metadata,
|
||||||
|
|||||||
@ -226,27 +226,27 @@ class Worker(WorkerBase):
|
|||||||
self,
|
self,
|
||||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
) -> List[Union[SamplerOutput, PoolerOutput]]:
|
) -> List[Union[SamplerOutput, PoolerOutput]]:
|
||||||
|
if not self.is_driver_worker:
|
||||||
|
self._execute_model_non_driver()
|
||||||
|
return []
|
||||||
|
|
||||||
if execute_model_req is None:
|
if execute_model_req is None:
|
||||||
seq_group_metadata_list = None
|
# This signals that there's no more requests to process for now.
|
||||||
else:
|
# All workers are running infinite loop with broadcast_tensor_dict,
|
||||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
# and it stops the loop when the driver broadcasts an empty input.
|
||||||
|
# Send an empty input to notify all other workers to stop their
|
||||||
|
# execution loop.
|
||||||
|
broadcast_tensor_dict({}, src=0)
|
||||||
|
return []
|
||||||
|
|
||||||
blocks_to_swap_in: torch.Tensor
|
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||||
blocks_to_swap_out: torch.Tensor
|
|
||||||
blocks_to_copy: torch.Tensor
|
|
||||||
if self.is_driver_worker:
|
|
||||||
assert seq_group_metadata_list is not None
|
|
||||||
assert execute_model_req is not None
|
|
||||||
num_seq_groups = len(seq_group_metadata_list)
|
num_seq_groups = len(seq_group_metadata_list)
|
||||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
||||||
# they contain parameters to launch cudamemcpyasync.
|
# they contain parameters to launch cudamemcpyasync.
|
||||||
blocks_to_swap_in = torch.tensor(
|
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
|
||||||
execute_model_req.blocks_to_swap_in,
|
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int64).view(-1, 2)
|
dtype=torch.int64).view(-1, 2)
|
||||||
blocks_to_swap_out = torch.tensor(
|
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
|
||||||
execute_model_req.blocks_to_swap_out,
|
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int64).view(-1, 2)
|
dtype=torch.int64).view(-1, 2)
|
||||||
# `blocks_to_copy` is a gpu tensor. The src and tgt of
|
# `blocks_to_copy` is a gpu tensor. The src and tgt of
|
||||||
@ -262,12 +262,6 @@ class Worker(WorkerBase):
|
|||||||
"blocks_to_copy": blocks_to_copy,
|
"blocks_to_copy": blocks_to_copy,
|
||||||
}
|
}
|
||||||
broadcast_tensor_dict(data, src=0)
|
broadcast_tensor_dict(data, src=0)
|
||||||
else:
|
|
||||||
data = broadcast_tensor_dict(src=0)
|
|
||||||
num_seq_groups = data["num_seq_groups"]
|
|
||||||
blocks_to_swap_in = data["blocks_to_swap_in"]
|
|
||||||
blocks_to_swap_out = data["blocks_to_swap_out"]
|
|
||||||
blocks_to_copy = data["blocks_to_copy"]
|
|
||||||
|
|
||||||
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
|
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
|
||||||
|
|
||||||
@ -282,6 +276,39 @@ class Worker(WorkerBase):
|
|||||||
# to conform to interface.
|
# to conform to interface.
|
||||||
return [output]
|
return [output]
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def start_worker_execution_loop(self) -> None:
|
||||||
|
"""Execute model loop in parallel worker.
|
||||||
|
|
||||||
|
You can stop the loop by executing a driver worker with an empty output.
|
||||||
|
See `stop_remote_worker_execution_loop` for more details.
|
||||||
|
"""
|
||||||
|
while self._execute_model_non_driver():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _execute_model_non_driver(self) -> bool:
|
||||||
|
"""Execute model in parallel worker.
|
||||||
|
|
||||||
|
Returns True iff there are remaining sequences to process.
|
||||||
|
"""
|
||||||
|
assert not self.is_driver_worker
|
||||||
|
data = broadcast_tensor_dict(src=0)
|
||||||
|
if not data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
num_seq_groups = data.get("num_seq_groups", 0)
|
||||||
|
blocks_to_swap_in = data.get("blocks_to_swap_in")
|
||||||
|
blocks_to_swap_out = data.get("blocks_to_swap_out")
|
||||||
|
blocks_to_copy = data.get("blocks_to_copy")
|
||||||
|
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
|
||||||
|
|
||||||
|
# If there is no input, we don't need to execute the model.
|
||||||
|
if num_seq_groups == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.model_runner.execute_model(None, self.gpu_cache)
|
||||||
|
return True
|
||||||
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
return self.model_runner.add_lora(lora_request)
|
return self.model_runner.add_lora(lora_request)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -49,7 +49,8 @@ class WorkerBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
"""Executes at least one model step on the given sequences, unless no
|
"""Executes at least one model step on the given sequences, unless no
|
||||||
sequences are provided."""
|
sequences are provided."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user