Rename servers to engines (#152)

Zhuohan Li 2023-06-17 17:25:21 +08:00 committed by GitHub
parent bab8f3dd0d
commit e5464ee484
15 changed files with 165 additions and 174 deletions

View File

@@ -14,7 +14,7 @@ def main(args: argparse.Namespace):
     # Process all the requests in a single batch if possible.
     # NOTE(woosuk): If the request cannot be processed in a single batch,
-    # the server will automatically process the request in multiple batches.
+    # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
         tensor_parallel_size=args.tensor_parallel_size,

View File

@@ -2,7 +2,7 @@
 On the server side, run one of the following commands:
     (CacheFlow backend)
-    python -m cacheflow.entrypoints.simple_fastapi_frontend \
+    python -m cacheflow.entrypoints.api_server \
         --disable-log-requests --model <your_model>

     (TGI backend)

View File

@@ -84,7 +84,7 @@ def run_cacheflow(
         seed=seed,
     )
-    # Add the requests to the server.
+    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
@@ -103,7 +103,7 @@ def run_cacheflow(
     start = time.time()
     # FIXME(woosuk): Do use internal method.
-    llm._run_server(use_tqdm=True)
+    llm._run_engine(use_tqdm=True)
     end = time.time()
     return end - start

View File

@@ -1,9 +1,9 @@
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput, CompletionOutput
+from cacheflow.outputs import CompletionOutput, RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import initialize_cluster

 __version__ = "0.1.0"
@@ -13,6 +13,6 @@ __all__ = [
     "RequestOutput",
     "CompletionOutput",
     "LLMEngine",
-    "ServerArgs",
+    "EngineArgs",
     "initialize_cluster",
 ]
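
As a quick smoke test of the renamed public surface (every name below comes from the updated __init__.py hunk above; this is an illustrative sketch, not part of the commit):

# Sketch: the renamed symbols are importable from the package root.
import cacheflow
from cacheflow import (LLM, CompletionOutput, EngineArgs, LLMEngine,
                       RequestOutput, SamplingParams, initialize_cluster)

assert cacheflow.__version__ == "0.1.0"
# "ServerArgs" is gone from __all__; "EngineArgs" takes its place.
assert "EngineArgs" in cacheflow.__all__
assert "ServerArgs" not in cacheflow.__all__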

View File

@@ -216,7 +216,7 @@ class Scheduler:
         if not self.log_stats:
             return scheduler_outputs, prompt_group_ids

-        # TODO(woosuk): Move the below code to server.
+        # TODO(woosuk): Move the below code to the engine.
         now = time.time()
         if num_batched_tokens > 0:
             self.num_input_tokens.append((now, num_batched_tokens))

View File

@@ -8,8 +8,8 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,

 @dataclass
-class ServerArgs:
-    """Arguments for CacheFlow servers."""
+class EngineArgs:
+    """Arguments for CacheFlow engine."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False
@@ -33,12 +33,12 @@ class ServerArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        """Shared CLI arguments for CacheFlow servers."""
+        """Shared CLI arguments for CacheFlow engine."""
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
         parser.add_argument('--download-dir', type=str,
-                            default=ServerArgs.download_dir,
+                            default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
                                  'default to the default cache dir of '
                                  'huggingface')
@@ -49,7 +49,7 @@ class ServerArgs:
         parser.add_argument('--use-dummy-weights', action='store_true',
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
-        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+        parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
                             choices=['auto', 'half', 'bfloat16', 'float'],
                             help='data type for model weights and activations. '
                                  'The "auto" option will use FP16 precision '
@@ -60,46 +60,46 @@ class ServerArgs:
                             help='use Ray for distributed serving, will be '
                                  'automatically set when using more than 1 GPU')
         parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                            default=ServerArgs.pipeline_parallel_size,
+                            default=EngineArgs.pipeline_parallel_size,
                             help='number of pipeline stages')
         parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                            default=ServerArgs.tensor_parallel_size,
+                            default=EngineArgs.tensor_parallel_size,
                             help='number of tensor parallel replicas')
         # KV cache arguments
         parser.add_argument('--block-size', type=int,
-                            default=ServerArgs.block_size,
+                            default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+        parser.add_argument('--seed', type=int, default=EngineArgs.seed,
                             help='random seed')
         parser.add_argument('--swap-space', type=int,
-                            default=ServerArgs.swap_space,
+                            default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU')
         parser.add_argument('--gpu-memory-utilization', type=float,
-                            default=ServerArgs.gpu_memory_utilization,
+                            default=EngineArgs.gpu_memory_utilization,
                             help='the percentage of GPU memory to be used for'
                                  'the model executor')
         parser.add_argument('--max-num-batched-tokens', type=int,
-                            default=ServerArgs.max_num_batched_tokens,
+                            default=EngineArgs.max_num_batched_tokens,
                             help='maximum number of batched tokens per '
                                  'iteration')
         parser.add_argument('--max-num-seqs', type=int,
-                            default=ServerArgs.max_num_seqs,
+                            default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
         parser.add_argument('--disable-log-stats', action='store_true',
                             help='disable logging statistics')
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
-        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
-        return server_args
+        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return engine_args

-    def create_server_configs(
+    def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
@@ -117,19 +117,19 @@ class ServerArgs:

 @dataclass
-class AsyncServerArgs(ServerArgs):
-    """Arguments for asynchronous CacheFlow servers."""
-    server_use_ray: bool = False
+class AsyncEngineArgs(EngineArgs):
+    """Arguments for asynchronous CacheFlow engine."""
+    engine_use_ray: bool = False
     disable_log_requests: bool = False

     @staticmethod
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        parser = ServerArgs.add_cli_args(parser)
-        parser.add_argument('--server-use-ray', action='store_true',
-                            help='use Ray to start the LLM server in a '
-                                 'separate process as the web server process.')
+        parser = EngineArgs.add_cli_args(parser)
+        parser.add_argument('--engine-use-ray', action='store_true',
+                            help='use Ray to start the LLM engine in a '
+                                 'separate process as the server process.')
         parser.add_argument('--disable-log-requests', action='store_true',
                             help='disable logging requests')
         return parser
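
A hedged sketch of how the renamed argument classes are driven from a CLI, using only the methods visible in this file (add_cli_args, from_cli_args, create_engine_configs); the flag values are illustrative, and building the configs assumes the model's HuggingFace config is resolvable:

# Sketch: wiring the renamed EngineArgs / AsyncEngineArgs into argparse.
import argparse

from cacheflow.engine.arg_utils import AsyncEngineArgs

parser = argparse.ArgumentParser(description="engine args demo")
# Registers the shared EngineArgs flags plus --engine-use-ray and
# --disable-log-requests.
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--engine-use-ray"])

engine_args = AsyncEngineArgs.from_cli_args(args)
# create_engine_configs() (formerly create_server_configs) returns the four
# config objects in the order declared above.
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())
print(engine_args.engine_use_ray, parallel_config.worker_use_ray)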

View File

@@ -2,12 +2,12 @@ import asyncio
 import time
 from typing import Dict, List, Optional

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster, ray
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import ray, initialize_cluster

 logger = init_logger(__name__)
@@ -29,44 +29,44 @@ class AsyncLLMEngine:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
             `parallel_config.worker_use_ray`.
-        server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
+        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMEngine.
     """

-    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+    def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
                  log_requests: bool = True, *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
-        self.server_use_ray = server_use_ray
+        self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        if not self.server_use_ray:
-            server_class = LLMEngine
+        if not self.engine_use_ray:
+            engine_class = LLMEngine
         elif self.worker_use_ray:
-            server_class = ray.remote(num_cpus=0)(LLMEngine).remote
+            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
         else:
-            server_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.server = server_class(*args, **kwargs)
+            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
+        self.engine = engine_class(*args, **kwargs)

         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_server_running = False
+        self.is_engine_running = False
         self.kicking_request_id: Optional[str] = None

-    async def server_step(self, kicking_request_id: Optional[str] = None):
-        """Kick the server to process the waiting requests."""
-        self.is_server_running = True
+    async def engine_step(self, kicking_request_id: Optional[str] = None):
+        """Kick the engine to process the waiting requests."""
+        self.is_engine_running = True
         self.kicking_request_id = kicking_request_id
-        if self.server_use_ray:
-            request_outputs = await self.server.step.remote()
+        if self.engine_use_ray:
+            request_outputs = await self.engine.step.remote()
         else:
             # Yield to the event loop to allow other coroutines to run
-            # while is_server_running is True. This let the server to add new
+            # while is_engine_running is True. This let the engine to add new
             # requests into the queue.
             await asyncio.sleep(0)
-            request_outputs = self.server.step()
-        self.is_server_running = False
+            request_outputs = self.engine.step()
+        self.is_engine_running = False
         self.kicking_request_id = None

         # Notify the waiting coroutines that there are new outputs ready.
@@ -104,7 +104,7 @@ class AsyncLLMEngine:
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
-        # cacheflow server.
+        # cacheflow engine.
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event
@@ -114,31 +114,31 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")

-        # Add the request into the cacheflow server's waiting queue.
-        if self.server_use_ray:
-            await self.server.add_request.remote(
+        # Add the request into the cacheflow engine's waiting queue.
+        if self.engine_use_ray:
+            await self.engine.add_request.remote(
                 request_id, prompt, sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)
         else:
-            self.server.add_request(
+            self.engine.add_request(
                 request_id, prompt, sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)

-        # The cacheflow server does not have a background loop that keeps
+        # The cacheflow engine does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
-        # the server to process the requests.
+        # the engine to process the requests.
         while True:
             if request_id not in self.request_events:
                 # The request has been aborted.
                 return

-            # Kick the server if the server is not running.
-            if not self.is_server_running:
-                await self.server_step(request_id)
+            # Kick the engine if the engine is not running.
+            if not self.is_engine_running:
+                await self.engine_step(request_id)

-            # Wait for new output. The group_event will be set in server_step
+            # Wait for new output. The group_event will be set in engine_step
             # when there is new output available for the sequence group.
             # Added a timeout to prevent deadlock.
             try:
@@ -160,11 +160,11 @@ class AsyncLLMEngine:
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
-                # Kick the server if the server is not running. This is to
-                # prevent that there are still requests in server's waiting
+                # Kick the engine if the engine is not running. This is to
+                # prevent that there are still requests in engine's waiting
                 # queue to be executed.
-                if not self.is_server_running:
-                    await self.server_step()
+                if not self.is_engine_running:
+                    await self.engine_step()
                 break

     async def abort(self, request_id: str) -> None:
@@ -183,36 +183,36 @@ class AsyncLLMEngine:
         if self.log_requests:
             logger.info(f"Aborted request {request_id}.")

-        if self.server_use_ray:
-            await self.server.abort_request.remote(request_id)
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_id)
         else:
-            self.server.abort_request(request_id)
+            self.engine.abort_request(request_id)

         if request_id in self.request_events:
             del self.request_events[request_id]
         if request_id in self.request_outputs:
             del self.request_outputs[request_id]

-        # To prevent deadlock when a request is aborted while the server is
+        # To prevent deadlock when a request is aborted while the engine is
         # running.
         if self.kicking_request_id == request_id:
-            self.is_server_running = False
+            self.is_engine_running = False
             self.kicking_request_id = None

     @classmethod
-    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
-        """Creates an async LLM server from the server arguments."""
-        # Create the server configs.
-        server_configs = server_args.create_server_configs()
-        parallel_config = server_configs[2]
+    def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(
-            parallel_config, server_args.server_use_ray)
-        # Create the LLM server.
-        server = cls(server_args.worker_use_ray,
-                     server_args.server_use_ray,
-                     not server_args.disable_log_requests,
-                     *server_configs,
-                     distributed_init_method, devices,
-                     log_stats=not server_args.disable_log_stats)
-        return server
+            parallel_config, engine_args.engine_use_ray)
+        # Create the async LLM engine.
+        engine = cls(engine_args.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     not engine_args.disable_log_requests,
+                     *engine_configs,
+                     distributed_init_method, devices,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine
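
Both HTTP frontends later in this commit consume the renamed async API the same way; a minimal sketch of driving AsyncLLMEngine directly, assuming generate() is the async generator the entrypoints iterate with `async for` (argument values are placeholders):

# Sketch: driving the renamed AsyncLLMEngine without an HTTP frontend.
import asyncio

from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.sampling_params import SamplingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # generate() keeps kicking engine_step() until the request finishes or
    # is aborted; the frontends below consume it exactly like this.
    final_output = None
    async for request_output in engine.generate(
            "The future of AI is", SamplingParams(temperature=0.0), "request-0"):
        final_output = request_output
    print(final_output)


asyncio.run(main())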

View File

@@ -4,13 +4,13 @@ from typing import Any, List, Optional
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
+from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
+                                              get_tokenizer)
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
-from cacheflow.server.tokenizer_utils import (get_tokenizer,
-                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -19,9 +19,9 @@ logger = init_logger(__name__)

 class LLMEngine:
-    """An LLM server that receives requests and generates texts.
+    """An LLM engine that receives requests and generates texts.

-    This is the main class for the CacheFlow LLM server. It receives requests
+    This is the main class for the CacheFlow LLM engine. It receives requests
     from clients and generates texts from the LLM. It includes a tokenizer, a
     language model (possibly distributed across multiple GPUs), and GPU memory
     space allocated for intermediate states (aka KV cache). This class utilizes
@@ -31,8 +31,8 @@ class LLMEngine:
     The `LLM` class wraps this class for offline batched inference and the
     `AsyncLLMEngine` class wraps this class for online serving.

-    NOTE: The config arguments are derived from the `ServerArgs` class. For the
-    comprehensive list of arguments, see `ServerArgs`.
+    NOTE: The config arguments are derived from the `EngineArgs` class. For the
+    comprehensive list of arguments, see `EngineArgs`.

     Args:
         model_config: The configuration related to the LLM model.
@@ -58,7 +58,7 @@ class LLMEngine:
         log_stats: bool,
     ) -> None:
         logger.info(
-            "Initializing an LLM server with config: "
+            "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
@@ -135,17 +135,17 @@ class LLMEngine:
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
-        """Creates an LLM server from the server arguments."""
-        # Create the server configs.
-        server_configs = server_args.create_server_configs()
-        parallel_config = server_configs[2]
+    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
-        # Create the LLM server.
-        server = cls(*server_configs, distributed_init_method, devices,
-                     log_stats=not server_args.disable_log_stats)
-        return server
+        # Create the LLM engine.
+        engine = cls(*engine_configs, distributed_init_method, devices,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine

     def add_request(
         self,
@@ -155,10 +155,10 @@ class LLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
-        """Add a request to the server's request pool.
+        """Add a request to the engine's request pool.

         The request is added to the request pool and will be processed by the
-        scheduler as `server.step()` is called. The exact scheduling policy is
+        scheduler as `engine.step()` is called. The exact scheduling policy is
         determined by the scheduler.

         Args:
@@ -211,7 +211,7 @@ class LLMEngine:
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.

-        This function performs one decoding iteration for the server. It first
+        This function performs one decoding iteration of the engine. It first
         schedules the sequences to be executed in the next iteration and the
         token blocks to be swapped in/out/copy. Then, it executes the model
         and updates the scheduler with the model outputs. Finally, it decodes

View File

@@ -13,15 +13,15 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), devi

 def initialize_cluster(
     parallel_config: ParallelConfig,
-    server_use_ray: bool = False,
-    ray_server_address: Optional[str] = None,
+    engine_use_ray: bool = False,
+    ray_address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
     """Initialize the distributed cluster probably with Ray.

     Args:
         parallel_config: The configurations for parallel execution.
-        server_use_ray: Whether to use Ray for async server.
-        ray_server_address: The address of the Ray cluster. If None, uses
+        engine_use_ray: Whether to use Ray for async engine.
+        ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.

     Returns:
@@ -31,13 +31,13 @@ def initialize_cluster(
         each worker in each pipeline stage. Each device ID is a tuple of
         (rank, node resource, device id).
     """
-    if parallel_config.worker_use_ray or server_use_ray:
+    if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
             raise ImportError(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=ray_server_address)
+        ray.init(address=ray_address)

     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
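
For reference, a sketch of calling initialize_cluster with the renamed keyword arguments on a single node without Ray (the EngineArgs detour is just one way to obtain a ParallelConfig):

# Sketch: the renamed initialize_cluster signature, single node, no Ray.
from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.ray_utils import initialize_cluster

engine_configs = EngineArgs(model="facebook/opt-125m").create_engine_configs()
parallel_config = engine_configs[2]

# engine_use_ray (formerly server_use_ray) and ray_address (formerly
# ray_server_address) are the renamed parameters from this hunk.
distributed_init_method, devices = initialize_cluster(
    parallel_config, engine_use_ray=False, ray_address=None)
print(distributed_init_method, devices)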

View File

@@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
-    results_generator = server.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(prompt, sampling_params, request_id)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
             yield (json.dumps(ret) + "\0").encode("utf-8")

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     if stream:
         background_tasks = BackgroundTasks()
@@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await server.abort(request_id)
+            await engine.abort(request_id)
             return Response(status_code=499)
         final_output = request_output
@@ -75,11 +75,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
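
A hedged client-side sketch against the renamed entrypoint. The route is assumed to be POST /generate (the decorator is not shown in this hunk), the host/port match the argparse defaults above, and max_tokens is assumed to be a SamplingParams field:

# Sketch: querying the renamed API server. Launch it first with
#   python -m cacheflow.entrypoints.api_server --model facebook/opt-125m
import json

import requests  # third-party HTTP client, used only for illustration

payload = {
    "prompt": "San Francisco is a",
    "temperature": 0.0,
    "max_tokens": 64,   # assumed SamplingParams field, not shown in this diff
    "stream": False,    # popped by the handler above before SamplingParams(**...)
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(json.dumps(response.json(), indent=2))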

View File

@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter
@@ -21,7 +21,7 @@ class LLM:
     NOTE: This class is intended to be used for offline inference. For online
     serving, use the `AsyncLLMEngine` class instead.

-    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

     Args:
         model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        server_args = ServerArgs(
+        engine_args = EngineArgs(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMEngine.from_server_args(server_args)
+        self.llm_engine = LLMEngine.from_engine_args(engine_args)
         self.request_counter = Counter()

     def get_tokenizer(
         self,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_server.tokenizer
+        return self.llm_engine.tokenizer

     def generate(
         self,
@@ -99,7 +99,7 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

-        # Add requests to the server.
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
@@ -111,7 +111,7 @@ class LLM:
             else:
                 token_ids = prompt_token_ids[i]
             self._add_request(prompt, sampling_params, token_ids)
-        return self._run_server(use_tqdm)
+        return self._run_engine(use_tqdm)

     def _add_request(
         self,
@@ -120,18 +120,18 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_server.add_request(request_id, prompt, sampling_params,
-                                    prompt_token_ids)
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
+                                    prompt_token_ids)

-    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.llm_server.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")

-        # Run the server.
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.llm_server.has_unfinished_requests():
-            step_outputs = self.llm_server.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished():
                     outputs.append(output)
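
The user-facing offline API is untouched by this commit; only the private llm_server/_run_server names become llm_engine/_run_engine. A short sketch of the unchanged call pattern, mirroring the benchmark script at the top of this commit (model name and prompts are illustrative):

# Sketch: offline batched inference through the LLM wrapper; the rename is
# invisible here because only private attributes and methods changed.
from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", seed=0)
sampling_params = SamplingParams(n=1, temperature=0.8)
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"], sampling_params)
for output in outputs:
    print(output)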

View File

@@ -10,29 +10,20 @@ import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn

-from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
+from cacheflow.engine.tokenizer_utils import get_tokenizer
+from cacheflow.entrypoints.openai.protocol import (
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from cacheflow.logger import init_logger
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import random_uuid
-from cacheflow.entrypoints.openai.protocol import (
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    ErrorResponse,
-    LogProbs,
-    ModelCard,
-    ModelList,
-    ModelPermission,
-    UsageInfo,
-)

 TIMEOUT_KEEP_ALIVE = 5  # seconds
@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the cacheflow server does not currently support
+        - echo (since the cacheflow engine does not currently support
           getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
-        - logit_bias (to be supported in cacheflow server)
+        - logit_bias (to be supported in cacheflow engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
         return error_check_ret

     if request.echo:
-        # We do not support echo since the cacheflow server does not
+        # We do not support echo since the cacheflow engine does not
         # currently support getting the logprobs of prompt tokens.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
                                      "suffix is not currently supported")

     if request.logit_bias is not None:
-        # TODO: support logit_bias in cacheflow server.
+        # TODO: support logit_bias in cacheflow engine.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = server.generate(prompt, sampling_params,
+    result_generator = engine.generate(prompt, sampling_params,
                                        request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
                  not request.use_beam_search)

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     def create_stream_response_json(index: int,
                                     text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
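
Finally, a hedged client sketch for the OpenAI-compatible frontend. The docstring above says it mimics the OpenAI Completion API, so the /v1/completions path, the port, and the response shape are assumptions based on that API rather than on this diff:

# Sketch: a completion request against the OpenAI-compatible server.
# Assumed route: POST /v1/completions (mirrors the OpenAI Completion API).
import requests

payload = {
    "model": "facebook/opt-125m",  # should match --model or --served-model-name
    "prompt": "The engine formerly known as the server",
    "max_tokens": 32,
    "temperature": 0.0,
    # "echo", "suffix", and "logit_bias" would be rejected, per the checks above.
}
response = requests.post("http://localhost:8000/v1/completions", json=payload)
print(response.json()["choices"][0]["text"])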

View File

@@ -1,12 +1,12 @@
 import argparse

-from cacheflow import ServerArgs, LLMEngine, SamplingParams
+from cacheflow import EngineArgs, LLMEngine, SamplingParams


 def main(args: argparse.Namespace):
-    # Parse the CLI argument and initialize the server.
-    server_args = ServerArgs.from_cli_args(args)
-    server = LLMEngine.from_server_args(server_args)
+    # Parse the CLI argument and initialize the engine.
+    engine_args = EngineArgs.from_cli_args(args)
+    engine = LLMEngine.from_engine_args(engine_args)

     # Test the following prompts.
     test_prompts = [
@@ -19,27 +19,27 @@ def main(args: argparse.Namespace):
         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

-    # Run the server by calling `server.step()` manually.
+    # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            server.add_request(str(request_id), prompt, sampling_params)
+            engine.add_request(str(request_id), prompt, sampling_params)
             request_id += 1

-        request_outputs = server.step()
+        request_outputs = engine.step()
         for request_output in request_outputs:
             if request_output.finished():
                 print(request_output)

-        if not (server.has_unfinished_requests() or test_prompts):
+        if not (engine.has_unfinished_requests() or test_prompts):
             break


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Demo on using the LLMEngine class synchronously')
-    parser = ServerArgs.add_cli_args(parser)
+        description='Demo on using the LLMEngine class directly')
+    parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)