Refactor AsyncLLMEngine (#880)

commit ce741ba3e4 (parent bf87484efa)
vllm/core/scheduler.py

@@ -1,6 +1,6 @@
 import enum
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.block_manager import BlockSpaceManager
@@ -87,17 +87,22 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
 
-    def abort_seq_group(self, request_id: str) -> None:
+    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        if isinstance(request_id, str):
+            request_id = (request_id, )
+        request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
             for seq_group in state_queue:
-                if seq_group.request_id == request_id:
+                if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
                     for seq in seq_group.seqs:
                         if seq.is_finished():
                             continue
                         self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
-                    return
+                    request_ids.remove(seq_group.request_id)
+                    if not request_ids:
+                        return
 
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
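A minimal sketch of what this Scheduler change buys callers: abort_seq_group now accepts either a single request id or an iterable of ids, so a batch of finished or cancelled requests can be aborted in one pass over the queues. The scheduler instance and the ids below are hypothetical; in vLLM the scheduler is owned by the engine.

    # Single id, as before.
    scheduler.abort_seq_group("request-0")

    # New: a batch of ids in one call; the scan over the queues stops early
    # once every id in the set has been found and freed.
    scheduler.abort_seq_group({"request-1", "request-2"})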
vllm/engine/async_llm_engine.py

@@ -1,6 +1,7 @@
 import asyncio
 import time
-from typing import Dict, List, Optional
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -12,7 +13,105 @@ from vllm.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
 
-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
+
+class AsyncStream:
+    """A stream of RequestOutputs for a request that can be
+    iterated over asynchronously."""
+
+    def __init__(self, request_id: str) -> None:
+        self.request_id = request_id
+        self._queue = asyncio.Queue()
+        self._finished = False
+
+    def put(self, item: RequestOutput) -> None:
+        if self._finished:
+            return
+        self._queue.put_nowait(item)
+
+    def finish(self) -> None:
+        self._queue.put_nowait(StopIteration)
+        self._finished = True
+
+    @property
+    def finished(self) -> bool:
+        return self._finished
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> RequestOutput:
+        result = await self._queue.get()
+        if result is StopIteration:
+            raise StopAsyncIteration
+        return result
+
+
+def _raise_exception_on_finish(task: asyncio.Task) -> None:
+    try:
+        task.result()
+    except Exception as e:
+        raise RuntimeError("Task finished unexpectedly.") from e
+    raise RuntimeError("Task finished unexpectedly.")
+
+
+class _AsyncLLMEngine(LLMEngine):
+    """Extension of LLMEngine to add async methods."""
+
+    async def step_async(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+        The workers are ran asynchronously if possible.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        (seq_group_metadata_list, scheduler_outputs,
+         early_return) = self._schedule()
+        if early_return is not None:
+            return early_return
+
+        # Execute the model.
+        output = await self._run_workers_async(
+            "execute_model",
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+            blocks_to_copy=scheduler_outputs.blocks_to_copy,
+        )
+
+        return self._process_worker_outputs(output, scheduler_outputs)
+
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        get_all_outputs: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        all_outputs = []
+        for worker in self.workers:
+            if self.parallel_config.worker_use_ray:
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
+
+            output = executor(*args, **kwargs)
+            all_outputs.append(output)
+
+        if self.parallel_config.worker_use_ray:
+            all_outputs = await asyncio.gather(*all_outputs)
+
+        if get_all_outputs:
+            return all_outputs
+
+        # Make sure all workers have the same results.
+        output = all_outputs[0]
+        for other_output in all_outputs[1:]:
+            assert output == other_output
+        return output
 
 
 class AsyncLLMEngine:
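For context, a small self-contained sketch of the producer/consumer contract AsyncStream establishes: the engine loop put()s each new output and calls finish() when the request completes, while the consumer simply iterates the stream. The strings below stand in for RequestOutput objects, and the import path assumes the class lands in vllm.engine.async_llm_engine as added above.

    import asyncio

    from vllm.engine.async_llm_engine import AsyncStream


    async def demo() -> None:
        stream = AsyncStream("demo-request")

        async def produce() -> None:
            # The background engine loop plays this role in vLLM.
            for item in ("out-1", "out-2", "out-3"):
                stream.put(item)
                await asyncio.sleep(0)
            stream.finish()  # ends the consumer's async-for below

        producer = asyncio.create_task(produce())
        async for output in stream:
            print(output)
        await producer


    asyncio.run(demo())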
@@ -37,49 +136,111 @@ class AsyncLLMEngine:
         *args, *kwargs: Arguments for LLMEngine.
     """
 
+    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
+
     def __init__(self,
                  worker_use_ray: bool,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
+                 start_engine_loop: bool = False,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        if not self.engine_use_ray:
-            engine_class = LLMEngine
-        elif self.worker_use_ray:
-            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
-        else:
-            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.engine = engine_class(*args, **kwargs)
-        # Request id -> request output.
-        self.request_outputs: Dict[str, RequestOutput] = {}
-        # Request id -> event to notify that there is new output.
-        self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_engine_running = False
-        self.kicking_request_id: Optional[str] = None
+        self.engine = self._init_engine(*args, **kwargs)
 
-    async def engine_step(self, kicking_request_id: Optional[str] = None):
+        # Request id -> stream.
+        self.request_streams: Dict[str, AsyncStream] = {}
+        self.finished_requests: Set[str] = set()
+        self.background_loop = None
+        if start_engine_loop:
+            self._start_background_loop()
+
+    def _start_background_loop(self) -> None:
+        """Start the background loop."""
+        if self.background_loop is not None:
+            raise RuntimeError("Background loop is already running.")
+        self.background_loop = asyncio.get_event_loop().create_task(
+            self.run_engine_loop())
+        self.background_loop.add_done_callback(_raise_exception_on_finish)
+
+    def _init_engine(self, *args,
+                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
+        if not self.engine_use_ray:
+            engine_class = self._engine_class
+        elif self.worker_use_ray:
+            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
+        else:
+            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
+        return engine_class(*args, **kwargs)
+
+    async def engine_step(self):
         """Kick the engine to process the waiting requests."""
-        self.is_engine_running = True
-        self.kicking_request_id = kicking_request_id
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
-            # Yield to the event loop to allow other coroutines to run
-            # while is_engine_running is True. This let the engine to add new
-            # requests into the queue.
-            await asyncio.sleep(0)
-            request_outputs = self.engine.step()
-        self.is_engine_running = False
-        self.kicking_request_id = None
+            request_outputs = await self.engine.step_async()
 
-        # Notify the waiting coroutines that there are new outputs ready.
+        # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
             request_id = request_output.request_id
-            self.request_outputs[request_id] = request_output
-            self.request_events[request_id].set()
+            self.request_streams[request_id].put(request_output)
+            if request_output.finished:
+                if self.log_requests:
+                    logger.info(f"Finished request {request_id}.")
+                self.request_streams[request_id].finish()
+                self.finished_requests.add(request_id)
+
+        await self._engine_abort(self.finished_requests)
+        for request_id in self.finished_requests:
+            del self.request_streams[request_id]
+        self.finished_requests.clear()
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_ids)
+        else:
+            self.engine.abort_request(request_ids)
+
+    async def run_engine_loop(self):
+        while True:
+            await self.engine_step()
+            await asyncio.sleep(0)
+
+    async def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+    ) -> AsyncStream:
+        if self.log_requests:
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {prompt!r}, "
+                        f"sampling params: {sampling_params}, "
+                        f"prompt token ids: {prompt_token_ids}.")
+
+        stream = AsyncStream(request_id)
+        self.request_streams[request_id] = stream
+
+        # Add the request into the vLLM engine's waiting queue.
+        if self.engine_use_ray:
+            await self.engine.add_request.remote(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
+        else:
+            self.engine.add_request(request_id,
+                                    prompt,
+                                    sampling_params,
+                                    prompt_token_ids=prompt_token_ids,
+                                    arrival_time=arrival_time)
+
+        return stream
 
     async def generate(
         self,
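A note on the background-loop design introduced above: run_engine_loop() is a plain while-True coroutine, and _raise_exception_on_finish is attached as a done-callback so an engine crash is not silently swallowed with the task object. A self-contained sketch of that pattern, independent of vLLM (the names below are illustrative only):

    import asyncio


    def _reraise(task: asyncio.Task) -> None:
        # Mirrors _raise_exception_on_finish: re-raise the task's exception so
        # the event loop's exception handler reports it instead of dropping it.
        exc = task.exception()
        if exc is not None:
            raise RuntimeError("Background loop died.") from exc


    async def main() -> None:
        async def loop() -> None:
            for _ in range(3):
                await asyncio.sleep(0)  # one "engine step" per iteration
            raise ValueError("simulated engine failure")

        task = asyncio.get_event_loop().create_task(loop())
        task.add_done_callback(_reraise)
        await asyncio.sleep(0.1)  # let the loop run and fail


    asyncio.run(main())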
@@ -108,76 +269,19 @@ class AsyncLLMEngine:
         # Preprocess the request.
         arrival_time = time.time()
 
-        # Create an event to notify us that there is new output from the
-        # vLLM engine.
-        request_event = asyncio.Event()
-        self.request_events[request_id] = request_event
+        try:
+            stream = await self.add_request(request_id,
+                                            prompt,
+                                            sampling_params,
+                                            prompt_token_ids=prompt_token_ids,
+                                            arrival_time=arrival_time)
 
-        if self.log_requests:
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {prompt!r}, "
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {prompt_token_ids}.")
-
-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
-
-        # The vLLM engine does not have a background loop that keeps
-        # processing incoming requests. Therefore, we need to keep kicking
-        # the engine to process the requests.
-        while True:
-            if request_id not in self.request_events:
-                # The request has been aborted.
-                return
-
-            # Kick the engine if the engine is not running.
-            if not self.is_engine_running:
-                try:
-                    await self.engine_step(request_id)
-                except RuntimeError as e:
-                    await self.abort(request_id)
-                    raise e
-
-            # Wait for new output. The group_event will be set in engine_step
-            # when there is new output available for the sequence group.
-            # Added a timeout to prevent deadlock.
-            try:
-                await asyncio.wait_for(request_event.wait(),
-                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
-            except asyncio.TimeoutError:
-                continue
-            # Reset the event to wait for the next output.
-            request_event.clear()
-
-            # Decode and return new outputs.
-            request_output = self.request_outputs[request_id]
-            yield request_output
-
-            # Once finished, release the resources of the sequence group.
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
-
-                del self.request_outputs[request_id]
-                del self.request_events[request_id]
-                # Kick the engine if the engine is not running. This is to
-                # prevent that there are still requests in engine's waiting
-                # queue to be executed.
-                if not self.is_engine_running:
-                    await self.engine_step()
-                break
+            async for request_output in stream:
+                yield request_output
+        except Exception as e:
+            # If there is an exception, abort the request.
+            self._abort(request_id)
+            raise e
 
     async def abort(self, request_id: str) -> None:
         """Abort a request.
@@ -188,28 +292,27 @@ class AsyncLLMEngine:
         Args:
             request_id: The unique id of the request.
         """
-        if request_id not in self.request_events:
+        return self._abort(request_id)
+
+    def _abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.
+
+        Args:
+            request_id: The unique id of the request.
+        """
+        if request_id not in self.request_streams or self.request_streams[
+                request_id].finished:
             # The request has already finished or been aborted.
             return
 
         if self.log_requests:
             logger.info(f"Aborted request {request_id}.")
 
-        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_id)
-        else:
-            self.engine.abort_request(request_id)
-
-        if request_id in self.request_events:
-            del self.request_events[request_id]
-        if request_id in self.request_outputs:
-            del self.request_outputs[request_id]
-
-        # To prevent deadlock when a request is aborted while the engine is
-        # running.
-        if self.kicking_request_id == request_id:
-            self.is_engine_running = False
-            self.kicking_request_id = None
+        self.request_streams[request_id].finish()
+        self.finished_requests.add(request_id)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
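Taken together, the new flow is: add_request() registers an AsyncStream, the background run_engine_loop() task pushes outputs into it, and generate() just iterates the stream. A hedged sketch of a caller, assuming an engine constructed elsewhere with start_engine_loop=True; the generate() argument order (prompt, sampling_params, request_id) is only partially visible in this diff, so treat it as an assumption:

    from vllm import SamplingParams


    async def complete(engine: "AsyncLLMEngine", prompt: str) -> str:
        params = SamplingParams(temperature=0.0, max_tokens=16)
        final = None
        # Each yielded RequestOutput reflects the request's progress so far.
        async for request_output in engine.generate(prompt, params,
                                                    request_id="demo-0"):
            final = request_output
        return final.outputs[0].text if final is not None else ""

A request can also be cancelled from outside with `await engine.abort("demo-0")`, which is a no-op if it has already finished.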
vllm/engine/llm_engine.py

@@ -1,17 +1,18 @@
-import time
 import copy
+import time
 from functools import partial
-from typing import Any, List, Optional, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
-from vllm.core.scheduler import Scheduler
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
+from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
+                           SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -135,7 +136,8 @@ class LLMEngine:
             get_all_outputs=True,
         )
 
-    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
@@ -150,6 +152,7 @@ class LLMEngine:
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
+                **ray_remote_kwargs,
             )(RayWorker).remote()
             self.workers.append(worker)
 
@@ -268,11 +271,11 @@ class LLMEngine:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
 
-    def abort_request(self, request_id: str) -> None:
-        """Aborts a request with the given ID.
+    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a request(s) with the given ID.
 
         Args:
-            request_id: The ID of the request to abort.
+            request_id: The ID(s) of the request to abort.
         """
         self.scheduler.abort_seq_group(request_id)
 
@@ -288,35 +291,21 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
-    def step(self) -> List[RequestOutput]:
-        """Performs one decoding iteration and returns newly generated results.
-
-        This function performs one decoding iteration of the engine. It first
-        schedules the sequences to be executed in the next iteration and the
-        token blocks to be swapped in/out/copy. Then, it executes the model
-        and updates the scheduler with the model outputs. Finally, it decodes
-        the sequences and returns the newly generated results.
-        """
+    def _schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               Optional[List[RequestOutput]]]:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if scheduler_outputs.is_empty():
-            if not scheduler_outputs.ignored_seq_groups:
-                # Nothing to do.
-                return []
-            # If there are ignored seq groups, we need to return them as the
-            # request outputs.
-            return [
+            return seq_group_metadata_list, scheduler_outputs, [
                 RequestOutput.from_seq_group(seq_group)
                 for seq_group in scheduler_outputs.ignored_seq_groups
             ]
+        return seq_group_metadata_list, scheduler_outputs, None
 
-        # Execute the model.
-        output = self._run_workers(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+    def _process_worker_outputs(
+            self, output,
+            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
         # Update the scheduler with the model outputs.
         seq_groups = self.scheduler.update(output)
 
@@ -339,6 +328,31 @@ class LLMEngine:
                                    scheduler_outputs.num_batched_tokens)
         return request_outputs
 
+    def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration of the engine. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copy. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
+        (seq_group_metadata_list, scheduler_outputs,
+         early_return) = self._schedule()
+        if early_return is not None:
+            return early_return
+
+        # Execute the model.
+        output = self._run_workers(
+            "execute_model",
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+            blocks_to_copy=scheduler_outputs.blocks_to_copy,
+        )
+
+        return self._process_worker_outputs(output, scheduler_outputs)
+
     def _log_system_stats(
         self,
         prompt_run: bool,
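The point of splitting the old step() into _schedule() and _process_worker_outputs() is that the synchronous LLMEngine.step() and the new _AsyncLLMEngine.step_async() now share everything except how the workers are invoked. From the outside the synchronous driver is unchanged; a minimal sketch, assuming the engine is constructed elsewhere and that the method whose "Returns True if there are unfinished requests" docstring appears above is LLMEngine.has_unfinished_requests():

    from vllm import SamplingParams


    def run_to_completion(engine: "LLMEngine") -> None:
        params = SamplingParams(max_tokens=8)
        engine.add_request("sync-0", "Hello", params)
        engine.add_request("sync-1", "Bonjour", params)

        # abort_request() now also takes an iterable of ids, mirroring
        # Scheduler.abort_seq_group above:
        #   engine.abort_request(["sync-0", "sync-1"])

        while engine.has_unfinished_requests():
            for request_output in engine.step():
                if request_output.finished:
                    print(request_output.request_id, "finished")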