[DOC] Add additional comments for LLMEngine and AsyncLLMEngine (#1011)

Jiaxiang 2024-01-12 11:26:49 +08:00 committed by GitHub
parent 50376faa7b
commit 6549aef245
9 changed files with 242 additions and 15 deletions

View File

@@ -9,11 +9,15 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import os
import sys
from sphinx.ext import autodoc
import logging
sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
logger = logging.getLogger(__name__)
# -- Project information -----------------------------------------------------
@@ -21,7 +25,6 @@ project = 'vLLM'
copyright = '2023, vLLM Team'
author = 'the vLLM Team'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
@@ -32,6 +35,8 @@ extensions = [
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
]
# Add any paths that contain templates here, relative to this directory.
@@ -55,7 +60,6 @@ html_title = project
html_theme = 'sphinx_book_theme'
html_logo = 'assets/logos/vllm-logo-text-light.png'
html_theme_options = {
-'logo_only': True,
'path_to_docs': 'docs/source',
'repository_url': 'https://github.com/vllm-project/vllm',
'use_repository_button': True,
@@ -64,4 +68,29 @@ html_theme_options = {
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
# html_static_path = ['_static']
# Mock out external dependencies here.
autodoc_mock_imports = [
"torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
"vllm.cuda_utils", "vllm._C"
]
for mock_target in autodoc_mock_imports:
if mock_target in sys.modules:
logger.info(
f"Potentially problematic mock target ({mock_target}) found; "
"autodoc_mock_imports cannot mock modules that have already "
"been loaded into sys.modules when the sphinx build starts.")
class MockedClassDocumenter(autodoc.ClassDocumenter):
"""Remove note about base class when a class is derived from object."""
def add_line(self, line: str, source: str, *lineno: int) -> None:
if line == " Bases: :py:class:`object`":
return
super().add_line(line, source, *lineno)
autodoc.ClassDocumenter = MockedClassDocumenter
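
As a side note on the conf.py changes above: `autodoc_mock_imports` makes Sphinx substitute mock objects for the listed modules, so the API reference can be built on machines without CUDA, `torch`, or the compiled `vllm._C` extension; the warning loop exists because a module that is already present in `sys.modules` when the build starts can no longer be mocked. Below is a minimal, standalone sketch of the same idea, using `unittest.mock` directly rather than Sphinx and `torch` as the example target; it is an illustration, not part of the commit.

import sys
from unittest import mock

def install_mocks(module_names):
    """Register MagicMock stand-ins so `import <name>` succeeds without the real package."""
    for name in module_names:
        if name in sys.modules:
            # Same caveat the conf.py loop logs: an already-imported module
            # cannot be replaced cleanly.
            continue
        sys.modules[name] = mock.MagicMock(name=name)

install_mocks(["torch"])

import torch                       # resolves to the MagicMock registered above
print(type(torch))                 # <class 'unittest.mock.MagicMock'>
print(torch.cuda.is_available())   # another MagicMock; no GPU or driver needed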

View File

@@ -0,0 +1,7 @@
AsyncLLMEngine
=================================
.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
:members: generate, abort
:show-inheritance:

View File

@@ -0,0 +1,13 @@
vLLM Engine
=================================
.. automodule:: vllm.engine
.. currentmodule:: vllm.engine
.. toctree::
:maxdepth: 2
:caption: Engines
llm_engine
async_llm_engine

View File

@@ -0,0 +1,6 @@
LLMEngine
=================================
.. autoclass:: vllm.engine.llm_engine.LLMEngine
:members: add_request, abort_request, step, _init_cache
:show-inheritance:

View File

@@ -85,4 +85,16 @@ Documentation
:maxdepth: 1
:caption: Quantization
quantization/auto_awq
.. toctree::
:maxdepth: 2
:caption: Developer Documentation
dev/engine/engine_index
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`

View File

@@ -88,6 +88,18 @@ class Scheduler:
self.waiting.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a sequence group with the given ID.
Check if the sequence group with the given ID
is present in any of the state queues.
If present, remove the sequence group from the state queue.
Also, if any of the sequences in the sequence group is not finished,
free the sequence with status `FINISHED_ABORTED`.
Otherwise, do nothing.
Args:
request_id: The ID(s) of the sequence group to abort.
"""
if isinstance(request_id, str):
request_id = (request_id, )
request_ids = set(request_id)
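
To make the new `abort_seq_group` docstring concrete, here is a small self-contained toy, not vLLM code: the `Toy*` classes below are invented stand-ins, and `running`/`swapped` are assumed to exist alongside the `waiting` queue that appears in this hunk. It walks the state queues, drops the matching group, and marks its unfinished sequences as aborted, mirroring the behaviour the docstring describes.

from dataclasses import dataclass, field
from typing import Iterable, List, Union

@dataclass
class ToySequence:
    finished: bool = False
    status: str = "RUNNING"

@dataclass
class ToySequenceGroup:
    request_id: str
    seqs: List[ToySequence] = field(default_factory=list)

class ToyScheduler:
    def __init__(self) -> None:
        self.waiting: List[ToySequenceGroup] = []
        self.running: List[ToySequenceGroup] = []
        self.swapped: List[ToySequenceGroup] = []

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        if isinstance(request_id, str):
            request_id = (request_id,)
        request_ids = set(request_id)
        for state_queue in (self.waiting, self.running, self.swapped):
            for seq_group in list(state_queue):   # copy: we mutate the queue
                if seq_group.request_id in request_ids:
                    state_queue.remove(seq_group)
                    for seq in seq_group.seqs:
                        if not seq.finished:      # free only unfinished sequences
                            seq.status = "FINISHED_ABORTED"

sched = ToyScheduler()
sched.waiting.append(ToySequenceGroup("req-0", [ToySequence()]))
sched.abort_seq_group("req-0")
assert not sched.waiting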

View File

@@ -253,7 +253,8 @@ class AsyncLLMEngine:
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
-*args, *kwargs: Arguments for LLMEngine.
*args: Arguments for LLMEngine.
*kwargs: Arguments for LLMEngine.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
@@ -428,6 +429,49 @@
Yields:
The output `RequestOutput` objects from the LLMEngine for the
request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "prompt": "What is LLM?",
>>> "stream": False, # assume the non-streaming case
>>> "temperature": 0.0,
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.generate(
>>> example_input["prompt"],
>>> SamplingParams(temperature=example_input["temperature"]),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
""" """
# Preprocess the request. # Preprocess the request.
# This should not be used for logging, as it is monotonic time. # This should not be used for logging, as it is monotonic time.
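
The "Details" section of the new `generate` docstring describes a background loop that feeds per-request streams. The following standalone toy mirrors that RequestTracker/AsyncStream pattern with plain asyncio; none of it is vLLM API, and `ToyStream`/`ToyAsyncEngine` are invented names. A background task pushes outputs into a per-request queue while `generate` yields from it.

import asyncio

class ToyStream:
    """Per-request output stream, loosely analogous to AsyncStream."""
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
    def put(self, item) -> None:
        self._queue.put_nowait(item)
    def finish(self) -> None:
        self._queue.put_nowait(StopAsyncIteration)
    def __aiter__(self):
        return self
    async def __anext__(self):
        item = await self._queue.get()
        if item is StopAsyncIteration:
            raise StopAsyncIteration
        return item

class ToyAsyncEngine:
    """Tracks per-request streams and steps a fake engine in the background."""
    def __init__(self) -> None:
        self._streams = {}          # request_id -> (stream, outputs left)
        self._loop_task = None
    async def _engine_loop(self) -> None:
        # Stand-in for repeatedly calling engine_step() on the real engine.
        while self._streams:
            for request_id, (stream, remaining) in list(self._streams.items()):
                if remaining == 0:
                    stream.finish()
                    del self._streams[request_id]
                else:
                    stream.put(f"{request_id}: output {remaining}")
                    self._streams[request_id] = (stream, remaining - 1)
            await asyncio.sleep(0)
    async def generate(self, prompt: str, request_id: str):
        stream = ToyStream()
        self._streams[request_id] = (stream, 3)   # pretend the request yields 3 outputs
        if self._loop_task is None or self._loop_task.done():
            self._loop_task = asyncio.create_task(self._engine_loop())
        async for output in stream:
            yield output

async def main() -> None:
    engine = ToyAsyncEngine()
    async for out in engine.generate("What is LLM?", "req-0"):
        print(out)

asyncio.run(main())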

View File

@@ -257,7 +257,26 @@ class LLMEngine:
self.cache_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
-"""Profiles the memory usage and initializes the KV cache."""
"""Profiles the memory usage and initializes the KV cache.
The engine first profiles the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
More details can be found in the
:meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
of :class:`~vllm.worker.worker.Worker`.
Afterwards, as there may be multiple workers,
we take the minimum number of blocks across all workers
to ensure this can be applied to all of them.
Finally, the engine will initialize the KV cache
with the calculated number of blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
@@ -334,6 +353,30 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Example:
>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>> str(request_id),
>>> example_prompt,
>>> sampling_params)
>>> # continue the request processing
>>> ...
""" """
if arrival_time is None: if arrival_time is None:
arrival_time = time.monotonic() arrival_time = time.monotonic()
@@ -358,6 +401,17 @@
Args:
request_id: The ID(s) of the request to abort.
Details:
- Refer to the
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class :class:`~vllm.core.scheduler.Scheduler`.
Example:
>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
""" """
self.scheduler.abort_seq_group(request_id) self.scheduler.abort_seq_group(request_id)
@@ -617,11 +671,53 @@ class LLMEngine:
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
-This function performs one decoding iteration of the engine. It first
-schedules the sequences to be executed in the next iteration and the
-token blocks to be swapped in/out/copy. Then, it executes the model
-and updates the scheduler with the model outputs. Finally, it decodes
-the sequences and returns the newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
    :alt: Overview of the step function
    :align: center

    Overview of the step function.
Details:
- Step 1: Schedules the sequences to be executed in the next
iteration and the token blocks to be swapped in/out/copy.
- Depending on the scheduling policy,
sequences may be `preempted/reordered`.
- A Sequence Group (SG) refers to a group of sequences
that are generated from the same prompt.
- Step 2: Calls the workers to execute the model.
- Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs.
- Updates the scheduled sequence groups with model outputs
based on their `sampling parameters` (`use_beam_search` or not).
- Frees the finished sequence groups.
- Finally, it creates and returns the newly generated results.
Example:
>>> # Please see the example/ folder for more detailed examples.
>>>
>>> # initialize engine and request arguments
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> example_inputs = [(0, "What is LLM?",
>>> SamplingParams(temperature=0.0))]
>>>
>>> # Start the engine with an event loop
>>> while True:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params)
>>>
>>> # continue the request processing
>>> request_outputs = engine.step()
>>> for request_output in request_outputs:
>>> if request_output.finished:
>>> # return or show the request output
>>>
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

View File

@@ -87,6 +87,14 @@ class Worker:
gpu_memory_utilization: float,
cpu_swap_space: int,
) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated.
Args:
block_size: The size of the cache block.
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
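
To complement the new `profile_num_available_blocks` docstring, the following hypothetical calculation shows how the byte size of one KV-cache block follows from `block_size` and the model shape, and how `cpu_swap_space` (given in bytes) then yields a CPU block count. The layer, head, and dtype figures are illustrative values, not taken from this diff.

def cache_block_bytes(block_size: int, num_layers: int, num_kv_heads: int,
                      head_size: int, dtype_bytes: int) -> int:
    # Each block stores `block_size` tokens of both keys and values
    # for every layer and KV head.
    return 2 * block_size * num_layers * num_kv_heads * head_size * dtype_bytes

block_bytes = cache_block_bytes(block_size=16, num_layers=32,
                                num_kv_heads=32, head_size=128,
                                dtype_bytes=2)          # fp16
cpu_swap_space = 4 * 2**30                              # 4 GiB, in bytes
num_cpu_blocks = cpu_swap_space // block_bytes
print(block_bytes, num_cpu_blocks)                      # 8388608 512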