Add docstrings for LLMServer and related classes and examples (#142)

Zhuohan Li 2023-06-07 18:25:20 +08:00 committed by GitHub
parent e38074b1e6
commit 4298374265
10 changed files with 212 additions and 18 deletions

View File

@@ -12,6 +12,20 @@ _GiB = 1 << 30
class ModelConfig:
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
download_dir: Directory to download and load the weights; defaults to
the default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading.
This can increase the disk usage by up to 2x.
use_dummy_weights: Use dummy values for model weights (for profiling).
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
"""
def __init__(
self,
@@ -68,7 +82,14 @@ class ModelConfig:
class CacheConfig:
"""Configuration for the KV cache.
Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
CacheFlow execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
"""
def __init__(
self,
block_size: int,
@@ -111,7 +132,15 @@ class CacheConfig:
class ParallelConfig:
"""Configuration for the distributed execution.
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
"""
def __init__(
self,
pipeline_parallel_size: int,
@@ -134,7 +163,14 @@ class ParallelConfig:
class SchedulerConfig:
"""Scheduler configuration.
Args:
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
"""
def __init__(
self,
max_num_batched_tokens: int,
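For orientation, here is a minimal sketch of constructing the four config objects documented above. The keyword names are taken from the docstrings in this hunk; any constructor parameters not visible in the diff are assumptions, and the numeric values are only illustrative.

```python
# Hypothetical usage sketch; keyword names follow the docstrings above and may
# not cover the full constructor signatures in cacheflow.config.
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)

model_config = ModelConfig(
    model="facebook/opt-125m",   # HF model name or path
    download_dir=None,           # default huggingface cache directory
    use_np_weights=False,        # optional numpy weight copy for faster loading
    use_dummy_weights=False,     # real weights, not profiling dummies
    dtype="auto",                # FP16 for FP32/FP16 models, BF16 for BF16 models
    seed=0,                      # random seed for reproducibility
)
cache_config = CacheConfig(
    block_size=16,               # tokens per KV cache block
    gpu_memory_utilization=0.9,  # fraction of GPU memory for CacheFlow
    swap_space=4,                # CPU swap space per GPU, in GiB
)
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=1,
    worker_use_ray=False,        # forced to True when either size is > 1
)
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560,  # tokens per iteration
    max_num_seqs=256,             # sequences per iteration
)
```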

View File

@@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],
@app.post("/v1/completions")
async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the cacheflow server does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported in cacheflow server)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")

View File

@@ -18,6 +18,12 @@ app = FastAPI()
@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
""" Stream the results of the generation request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)
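A sketch of a client for this /generate endpoint. The JSON body carries the prompt plus sampling parameter fields; `n` and `temperature` are assumptions based on the SamplingParams arguments used elsewhere in this commit, and the server is assumed to be listening on localhost:8000.

```python
# Sketch of a /generate client; field names beyond "prompt" are forwarded to
# SamplingParams, per the docstring above.
import requests

payload = {
    "prompt": "San Francisco is a",
    "n": 1,
    "temperature": 0.8,
}
response = requests.post("http://localhost:8000/generate", json=payload)
# The endpoint streams results; requests collects the full body by default.
print(response.text)
```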

View File

@@ -9,6 +9,7 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
@dataclass
class ServerArgs:
"""Arguments for CacheFlow servers."""
model: str
download_dir: Optional[str] = None
use_np_weights: bool = False
@@ -117,6 +118,7 @@ class ServerArgs:
@dataclass
class AsyncServerArgs(ServerArgs):
"""Arguments for asynchronous CacheFlow servers."""
server_use_ray: bool = False
@staticmethod
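A hedged sketch of turning these dataclasses into running servers via the `from_server_args` classmethods documented later in this commit. The module paths for `LLMServer` and `AsyncLLMServer` are assumptions inferred from the imports shown elsewhere in the diff.

```python
# Assumed module paths; only fields visible in this diff are used.
from cacheflow.server.arg_utils import AsyncServerArgs, ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.llm_server import LLMServer

# Synchronous server for offline batched inference.
server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))

# Asynchronous server for online serving (optionally an LLMServer Ray actor).
async_server = AsyncLLMServer.from_server_args(
    AsyncServerArgs(model="facebook/opt-125m", server_use_ray=False))
```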

View File

@@ -1,6 +1,6 @@
import asyncio
import time
from typing import Dict, Optional
from typing import Dict, List, Optional
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
@@ -15,7 +15,25 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncLLMServer:
"""An asynchronous wrapper for LLMServer.
This class is used to wrap the LLMServer class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMServer is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMServer to the caller.
NOTE: For the comprehensive list of arguments, see `LLMServer`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
server_use_ray: Whether to make LLMServer a Ray actor. If so, the
async frontend will be executed in a separate process from the
model workers.
*args, **kwargs: Arguments for LLMServer.
"""
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
*args, **kwargs) -> None:
self.worker_use_ray = worker_use_ray
@@ -35,6 +53,7 @@ class AsyncLLMServer:
self.kicking_request_id: Optional[str] = None
async def server_step(self, kicking_request_id: Optional[str] = None):
"""Kick the server to process the waiting requests."""
self.is_server_running = True
self.kicking_request_id = kicking_request_id
if self.server_use_ray:
@@ -54,8 +73,31 @@ class AsyncLLMServer:
self.request_outputs[request_id] = request_output
self.request_events[request_id].set()
async def generate(self, prompt: str, sampling_params: SamplingParams,
request_id: str) -> RequestOutput:
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
) -> RequestOutput:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMServer and streams the outputs
from the LLMServer to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
Yields:
The output `RequestOutput` objects from the LLMServer for the
request.
"""
# Preprocess the request.
arrival_time = time.time()
@@ -66,20 +108,29 @@ class AsyncLLMServer:
logger.info(f"Received request {request_id}: "
f"prompt: {prompt!r}, "
f"sampling params: {sampling_params}.")
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")
# Add the request into the cacheflow server's waiting queue.
if self.server_use_ray:
await self.server.add_request.remote(
request_id, prompt, sampling_params, arrival_time=arrival_time)
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.server.add_request(
request_id, prompt, sampling_params, arrival_time=arrival_time)
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while True:
if request_id not in self.request_events:
# The request has been aborted.
return
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step(request_id)
@@ -113,6 +164,14 @@ class AsyncLLMServer:
break
async def abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
if request_id not in self.request_events:
# The request has already finished or been aborted.
return
@@ -137,6 +196,7 @@ class AsyncLLMServer:
@classmethod
def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
"""Creates an async LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
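A sketch of driving `AsyncLLMServer.generate` from asyncio, following the docstrings above. The module path is an assumption, and `generate` is treated as an async generator because its docstring documents a Yields section; both are unverified details of this snapshot of the code.

```python
# Hedged usage sketch for AsyncLLMServer; paths and streaming behaviour are
# assumptions based on the docstrings in this diff.
import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer


async def main() -> None:
    server = AsyncLLMServer.from_server_args(
        AsyncServerArgs(model="facebook/opt-125m"))
    sampling_params = SamplingParams(temperature=0.8)
    # generate() yields RequestOutput objects until the request finishes.
    async for request_output in server.generate(
            "A robot may not injure a human being",
            sampling_params,
            request_id="0"):
        print(request_output)
    # An in-flight request could instead be cancelled with:
    #     await server.abort("0")


asyncio.run(main())
```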

View File

@@ -8,7 +8,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@ logger = init_logger(__name__)
class LLMServer:
"""An LLM server that receives requests and generates texts.
This is the main class for the CacheFlow LLM server. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMServer` class wraps this class for online serving.
NOTE: The config arguments are derived from the `ServerArgs` class. For the
comprehensive list of arguments, see `ServerArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
stage_devices: The list of devices for each stage. Each stage is a list
of (rank, node_resource, device) tuples.
log_stats: Whether to log statistics.
"""
def __init__(
self,
@@ -27,7 +54,7 @@ class LLMServer:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
stage_devices: List[List[DeviceID]],
log_stats: bool,
) -> None:
logger.info(
@@ -83,6 +110,7 @@ class LLMServer:
self.cache_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
@@ -108,6 +136,7 @@ class LLMServer:
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
"""Creates an LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ class LLMServer:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the server's request pool.
The request is added to the request pool and will be processed by the
scheduler as `server.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current time.
"""
if arrival_time is None:
arrival_time = time.time()
if prompt_token_ids is None:
@@ -148,15 +193,30 @@ class LLMServer:
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None:
"""Aborts a request with the given ID.
Args:
request_id: The ID of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration for the server. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copied. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
# Nothing to do.
@@ -188,7 +248,7 @@ class LLMServer:
return request_outputs
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Decode the sequence outputs.
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ class LLMServer:
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Stop the sequences.
"""Stop the finished sequences."""
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ class LLMServer:
*args,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
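For the synchronous path, a minimal sketch of the add_request / step loop described in the docstrings above, mirroring the simple server example at the end of this commit. The module path is an assumption.

```python
# Hedged sketch of offline use of LLMServer, per the docstrings in this diff.
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))
server.add_request(
    request_id="0",
    prompt="A robot may not injure a human being",
    sampling_params=SamplingParams(n=1, temperature=0.0),
)

# There is no background loop; keep calling step() until the request pool
# is empty. Each step() returns the newly generated RequestOutput objects.
while server.has_unfinished_requests():
    for request_output in server.step():
        print(request_output)
```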

View File

@@ -14,15 +14,30 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
def initialize_cluster(
parallel_config: ParallelConfig,
server_use_ray: bool = False,
address: Optional[str] = None,
ray_server_address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
server_use_ray: Whether to use Ray for async server.
ray_server_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
A tuple of (`distributed_init_method`, `all_stage_devices`). The
`distributed_init_method` is the address for initializing the
distributed backend. `all_stage_devices` includes device IDs for
each worker in each pipeline stage. Each device ID is a tuple of
(rank, node resource, device id).
"""
if parallel_config.worker_use_ray or server_use_ray:
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
ray.init(address=ray_server_address)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
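A sketch of calling `initialize_cluster` directly, following the docstring above. The ParallelConfig keyword names are taken from the config.py hunk earlier in this commit and are otherwise assumptions; with a single GPU and no Ray, the call should return a local distributed init method and a single stage of device IDs.

```python
# Hedged sketch of initialize_cluster usage.
from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1,
                                 worker_use_ray=False)
distributed_init_method, all_stage_devices = initialize_cluster(
    parallel_config)

print(distributed_init_method)  # address used to init the distributed backend
for stage_id, stage_devices in enumerate(all_stage_devices):
    # Each device ID is a (rank, node resource, device id) tuple.
    print(stage_id, stage_devices)
```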

View File

@@ -15,6 +15,7 @@ def get_tokenizer(
*args,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
config = AutoConfig.from_pretrained(model_name)
if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
# LLaMA fast tokenizer causes protobuf errors in some environments.
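A small usage sketch: `get_tokenizer` returns a standard HuggingFace tokenizer (with the LLaMA fast-tokenizer special case handled above), so the usual encode/decode calls apply. Extra positional and keyword arguments are forwarded to the tokenizer and are not covered here.

```python
# Sketch of get_tokenizer usage; the model name is only an example.
from cacheflow.server.tokenizer_utils import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
token_ids = tokenizer.encode("A robot may not injure a human being")
print(token_ids)
print(tokenizer.decode(token_ids))
```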

View File

@@ -1,14 +1,15 @@
import openai
# Modify OpenAI's API key and API base to use CacheFlow's API server.
openai.api_key = "EMPTY" openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1" openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m" model = "facebook/opt-125m"
# list models # Test list models API
models = openai.Model.list() models = openai.Model.list()
print(models) print("Models:", models)
# create a completion
# Test completion API
stream = True
completion = openai.Completion.create(
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
@@ -19,4 +20,4 @@ if stream:
for c in completion:
print(c)
else:
print("completion:", completion)
print("Completion result:", completion)

View File

@@ -19,7 +19,7 @@ def main(args: argparse.Namespace):
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
]
# Run the server.
# Run the server by calling `server.step()` manually.
request_id = 0
while True:
# To test iteration-level scheduling, we add one request at each step.