Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 13:45:45 +08:00)
Rename servers to engines (#152)
This commit is contained in:
parent bab8f3dd0d
commit e5464ee484
@@ -14,7 +14,7 @@ def main(args: argparse.Namespace):
 
 # Process all the requests in a single batch if possible.
 # NOTE(woosuk): If the request cannot be processed in a single batch,
-# the server will automatically process the request in multiple batches.
+# the engine will automatically process the request in multiple batches.
 llm = LLM(
 model=args.model,
 tensor_parallel_size=args.tensor_parallel_size,
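For orientation, and not part of the commit itself: a minimal sketch of the offline batched call this snippet builds toward, using only constructor arguments and methods that appear elsewhere in this diff. The model name and prompts are illustrative, and it assumes LLM and SamplingParams stay re-exported from the cacheflow package as the __init__ hunk below suggests.

from cacheflow import LLM, SamplingParams

# Illustrative model and prompts.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
sampling_params = SamplingParams(n=1, temperature=0.0)
# A single call may still be split into multiple batches by the engine,
# as the NOTE above explains.
outputs = llm.generate(["Hello, my name is", "The capital of France is"],
                       sampling_params)
for output in outputs:
    print(output)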
@@ -2,7 +2,7 @@
 
 On the server side, run one of the following commands:
 (CacheFlow backend)
-python -m cacheflow.entrypoints.simple_fastapi_frontend \
+python -m cacheflow.entrypoints.api_server \
 --disable-log-requests --model <your_model>
 
 (TGI backend)
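Outside the diff: once the frontend module is renamed to cacheflow.entrypoints.api_server, the benchmark above talks to it over plain HTTP. A rough client sketch follows; the /generate route, the "prompt" field, and the response shape are assumptions here, since only the default host/port and the "stream" field appear in this commit.

import requests

response = requests.post(
    "http://localhost:8000/generate",   # assumed route on the default port
    json={"prompt": "San Francisco is a", "stream": False,
          "temperature": 0.0, "n": 1},
)
print(response.json())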
@@ -84,7 +84,7 @@ def run_cacheflow(
 seed=seed,
 )
 
-# Add the requests to the server.
+# Add the requests to the engine.
 for prompt, _, output_len in requests:
 sampling_params = SamplingParams(
 n=n,
@@ -103,7 +103,7 @@ def run_cacheflow(
 
 start = time.time()
 # FIXME(woosuk): Do use internal method.
-llm._run_server(use_tqdm=True)
+llm._run_engine(use_tqdm=True)
 end = time.time()
 return end - start
 
@@ -1,9 +1,9 @@
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput, CompletionOutput
+from cacheflow.outputs import CompletionOutput, RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import initialize_cluster
 
 __version__ = "0.1.0"
 
@@ -13,6 +13,6 @@ __all__ = [
 "RequestOutput",
 "CompletionOutput",
 "LLMEngine",
-"ServerArgs",
+"EngineArgs",
 "initialize_cluster",
 ]
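Not part of the diff: after this rename the package-level names line up as below. The model name is just the CLI default from the argument parser elsewhere in this commit, and the LLM entrypoint class and SamplingParams remain exported as well.

from cacheflow import (CompletionOutput, EngineArgs, LLMEngine, RequestOutput,
                       SamplingParams, initialize_cluster)

engine_args = EngineArgs(model="facebook/opt-125m")  # illustrative model name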
@@ -216,7 +216,7 @@ class Scheduler:
 if not self.log_stats:
 return scheduler_outputs, prompt_group_ids
 
-# TODO(woosuk): Move the below code to server.
+# TODO(woosuk): Move the below code to the engine.
 now = time.time()
 if num_batched_tokens > 0:
 self.num_input_tokens.append((now, num_batched_tokens))
@@ -8,8 +8,8 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
 
 
 @dataclass
-class ServerArgs:
+class EngineArgs:
-"""Arguments for CacheFlow servers."""
+"""Arguments for CacheFlow engine."""
 model: str
 download_dir: Optional[str] = None
 use_np_weights: bool = False
@@ -33,12 +33,12 @@ class ServerArgs:
 def add_cli_args(
 parser: argparse.ArgumentParser,
 ) -> argparse.ArgumentParser:
-"""Shared CLI arguments for CacheFlow servers."""
+"""Shared CLI arguments for CacheFlow engine."""
 # Model arguments
 parser.add_argument('--model', type=str, default='facebook/opt-125m',
 help='name or path of the huggingface model to use')
 parser.add_argument('--download-dir', type=str,
-default=ServerArgs.download_dir,
+default=EngineArgs.download_dir,
 help='directory to download and load the weights, '
 'default to the default cache dir of '
 'huggingface')
@@ -49,7 +49,7 @@ class ServerArgs:
 parser.add_argument('--use-dummy-weights', action='store_true',
 help='use dummy values for model weights')
 # TODO(woosuk): Support FP32.
-parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
 choices=['auto', 'half', 'bfloat16', 'float'],
 help='data type for model weights and activations. '
 'The "auto" option will use FP16 precision '
@@ -60,46 +60,46 @@ class ServerArgs:
 help='use Ray for distributed serving, will be '
 'automatically set when using more than 1 GPU')
 parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-default=ServerArgs.pipeline_parallel_size,
+default=EngineArgs.pipeline_parallel_size,
 help='number of pipeline stages')
 parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-default=ServerArgs.tensor_parallel_size,
+default=EngineArgs.tensor_parallel_size,
 help='number of tensor parallel replicas')
 # KV cache arguments
 parser.add_argument('--block-size', type=int,
-default=ServerArgs.block_size,
+default=EngineArgs.block_size,
 choices=[8, 16, 32],
 help='token block size')
 # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+parser.add_argument('--seed', type=int, default=EngineArgs.seed,
 help='random seed')
 parser.add_argument('--swap-space', type=int,
-default=ServerArgs.swap_space,
+default=EngineArgs.swap_space,
 help='CPU swap space size (GiB) per GPU')
 parser.add_argument('--gpu-memory-utilization', type=float,
-default=ServerArgs.gpu_memory_utilization,
+default=EngineArgs.gpu_memory_utilization,
 help='the percentage of GPU memory to be used for'
 'the model executor')
 parser.add_argument('--max-num-batched-tokens', type=int,
-default=ServerArgs.max_num_batched_tokens,
+default=EngineArgs.max_num_batched_tokens,
 help='maximum number of batched tokens per '
 'iteration')
 parser.add_argument('--max-num-seqs', type=int,
-default=ServerArgs.max_num_seqs,
+default=EngineArgs.max_num_seqs,
 help='maximum number of sequences per iteration')
 parser.add_argument('--disable-log-stats', action='store_true',
 help='disable logging statistics')
 return parser
 
 @classmethod
-def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
 # Get the list of attributes of this dataclass.
 attrs = [attr.name for attr in dataclasses.fields(cls)]
 # Set the attributes from the parsed arguments.
-server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
-return server_args
+return engine_args
 
-def create_server_configs(
+def create_engine_configs(
 self,
 ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
 # Initialize the configs.
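A sketch of the renamed CLI round trip defined above, not part of the diff; the argument values are illustrative.

import argparse

from cacheflow.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--tensor-parallel-size", "1"])
engine_args = EngineArgs.from_cli_args(args)
# create_engine_configs() splits the flat arguments into the four configs.
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())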
@@ -117,19 +117,19 @@ class ServerArgs:
 
 
 @dataclass
-class AsyncServerArgs(ServerArgs):
+class AsyncEngineArgs(EngineArgs):
-"""Arguments for asynchronous CacheFlow servers."""
+"""Arguments for asynchronous CacheFlow engine."""
-server_use_ray: bool = False
+engine_use_ray: bool = False
 disable_log_requests: bool = False
 
 @staticmethod
 def add_cli_args(
 parser: argparse.ArgumentParser,
 ) -> argparse.ArgumentParser:
-parser = ServerArgs.add_cli_args(parser)
+parser = EngineArgs.add_cli_args(parser)
-parser.add_argument('--server-use-ray', action='store_true',
+parser.add_argument('--engine-use-ray', action='store_true',
-help='use Ray to start the LLM server in a '
+help='use Ray to start the LLM engine in a '
-'separate process as the web server process.')
+'separate process as the server process.')
 parser.add_argument('--disable-log-requests', action='store_true',
 help='disable logging requests')
 return parser
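Likewise for the async variant, and again outside the diff: the renamed --engine-use-ray flag rides on top of the shared EngineArgs flags. The model name is illustrative.

import argparse

from cacheflow.engine.arg_utils import AsyncEngineArgs

parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--engine-use-ray"])
engine_args = AsyncEngineArgs.from_cli_args(args)
assert engine_args.engine_use_ray  # the flag maps onto the renamed field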
@@ -2,12 +2,12 @@ import asyncio
 import time
 from typing import Dict, List, Optional
 
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster, ray
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import ray, initialize_cluster
 
 logger = init_logger(__name__)
 
@@ -29,44 +29,44 @@ class AsyncLLMEngine:
 worker_use_ray: Whether to use Ray for model workers. Required for
 distributed execution. Should be the same as
 `parallel_config.worker_use_ray`.
-server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
+engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
 async frontend will be executed in a separate process as the
 model workers.
 log_requests: Whether to log the requests.
 *args, *kwargs: Arguments for LLMEngine.
 """
-def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
 log_requests: bool = True, *args, **kwargs) -> None:
 self.worker_use_ray = worker_use_ray
-self.server_use_ray = server_use_ray
+self.engine_use_ray = engine_use_ray
 self.log_requests = log_requests
-if not self.server_use_ray:
+if not self.engine_use_ray:
-server_class = LLMEngine
+engine_class = LLMEngine
 elif self.worker_use_ray:
-server_class = ray.remote(num_cpus=0)(LLMEngine).remote
+engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
 else:
-server_class = ray.remote(num_gpus=1)(LLMEngine).remote
+engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
-self.server = server_class(*args, **kwargs)
+self.engine = engine_class(*args, **kwargs)
 # Request id -> request output.
 self.request_outputs: Dict[str, RequestOutput] = {}
 # Request id -> event to notify that there is new output.
 self.request_events: Dict[str, asyncio.Event] = {}
-self.is_server_running = False
+self.is_engine_running = False
 self.kicking_request_id: Optional[str] = None
 
-async def server_step(self, kicking_request_id: Optional[str] = None):
+async def engine_step(self, kicking_request_id: Optional[str] = None):
-"""Kick the server to process the waiting requests."""
+"""Kick the engine to process the waiting requests."""
-self.is_server_running = True
+self.is_engine_running = True
 self.kicking_request_id = kicking_request_id
-if self.server_use_ray:
+if self.engine_use_ray:
-request_outputs = await self.server.step.remote()
+request_outputs = await self.engine.step.remote()
 else:
 # Yield to the event loop to allow other coroutines to run
-# while is_server_running is True. This let the server to add new
+# while is_engine_running is True. This let the engine to add new
 # requests into the queue.
 await asyncio.sleep(0)
-request_outputs = self.server.step()
+request_outputs = self.engine.step()
-self.is_server_running = False
+self.is_engine_running = False
 self.kicking_request_id = None
 
 # Notify the waiting coroutines that there are new outputs ready.
@@ -104,7 +104,7 @@ class AsyncLLMEngine:
 arrival_time = time.time()
 
 # Create an event to notify us that there is new output from the
-# cacheflow server.
+# cacheflow engine.
 request_event = asyncio.Event()
 self.request_events[request_id] = request_event
 
@@ -114,31 +114,31 @@ class AsyncLLMEngine:
 f"sampling params: {sampling_params}, "
 f"prompt token ids: {prompt_token_ids}.")
 
-# Add the request into the cacheflow server's waiting queue.
+# Add the request into the cacheflow engine's waiting queue.
-if self.server_use_ray:
+if self.engine_use_ray:
-await self.server.add_request.remote(
+await self.engine.add_request.remote(
 request_id, prompt, sampling_params,
 prompt_token_ids=prompt_token_ids,
 arrival_time=arrival_time)
 else:
-self.server.add_request(
+self.engine.add_request(
 request_id, prompt, sampling_params,
 prompt_token_ids=prompt_token_ids,
 arrival_time=arrival_time)
 
-# The cacheflow server does not have a background loop that keeps
+# The cacheflow engine does not have a background loop that keeps
 # processing incoming requests. Therefore, we need to keep kicking
-# the server to process the requests.
+# the engine to process the requests.
 while True:
 if request_id not in self.request_events:
 # The request has been aborted.
 return
 
-# Kick the server if the server is not running.
+# Kick the engine if the engine is not running.
-if not self.is_server_running:
+if not self.is_engine_running:
-await self.server_step(request_id)
+await self.engine_step(request_id)
 
-# Wait for new output. The group_event will be set in server_step
+# Wait for new output. The group_event will be set in engine_step
 # when there is new output available for the sequence group.
 # Added a timeout to prevent deadlock.
 try:
@@ -160,11 +160,11 @@ class AsyncLLMEngine:
 
 del self.request_outputs[request_id]
 del self.request_events[request_id]
-# Kick the server if the server is not running. This is to
+# Kick the engine if the engine is not running. This is to
-# prevent that there are still requests in server's waiting
+# prevent that there are still requests in engine's waiting
 # queue to be executed.
-if not self.is_server_running:
+if not self.is_engine_running:
-await self.server_step()
+await self.engine_step()
 break
 
 async def abort(self, request_id: str) -> None:
@@ -183,36 +183,36 @@ class AsyncLLMEngine:
 if self.log_requests:
 logger.info(f"Aborted request {request_id}.")
 
-if self.server_use_ray:
+if self.engine_use_ray:
-await self.server.abort_request.remote(request_id)
+await self.engine.abort_request.remote(request_id)
 else:
-self.server.abort_request(request_id)
+self.engine.abort_request(request_id)
 
 if request_id in self.request_events:
 del self.request_events[request_id]
 if request_id in self.request_outputs:
 del self.request_outputs[request_id]
 
-# To prevent deadlock when a request is aborted while the server is
+# To prevent deadlock when a request is aborted while the engine is
 # running.
 if self.kicking_request_id == request_id:
-self.is_server_running = False
+self.is_engine_running = False
 self.kicking_request_id = None
 
 @classmethod
-def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
+def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
-"""Creates an async LLM server from the server arguments."""
+"""Creates an async LLM engine from the engine arguments."""
-# Create the server configs.
+# Create the engine configs.
-server_configs = server_args.create_server_configs()
+engine_configs = engine_args.create_engine_configs()
-parallel_config = server_configs[2]
+parallel_config = engine_configs[2]
 # Initialize the cluster.
 distributed_init_method, devices = initialize_cluster(
-parallel_config, server_args.server_use_ray)
+parallel_config, engine_args.engine_use_ray)
-# Create the LLM server.
+# Create the async LLM engine.
-server = cls(server_args.worker_use_ray,
+engine = cls(engine_args.worker_use_ray,
-server_args.server_use_ray,
+engine_args.engine_use_ray,
-not server_args.disable_log_requests,
+not engine_args.disable_log_requests,
-*server_configs,
+*engine_configs,
 distributed_init_method, devices,
-log_stats=not server_args.disable_log_stats)
+log_stats=not engine_args.disable_log_stats)
-return server
+return engine
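Not part of the diff: a minimal caller-side sketch of the renamed async engine, mirroring how the API server uses it later in this commit. The model name and prompt are illustrative, and generate() is assumed to be an async generator of RequestOutput objects, as the API server code implies.

import asyncio

from cacheflow import SamplingParams
from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine

engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"))

async def demo() -> None:
    final_output = None
    # Each yielded RequestOutput carries the text generated so far.
    async for request_output in engine.generate(
            "Hello, my name is", SamplingParams(temperature=0.0), "request-0"):
        final_output = request_output
    print(final_output)

asyncio.run(demo())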
@@ -4,13 +4,13 @@ from typing import Any, List, Optional
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
 SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
+from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
+get_tokenizer)
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
-from cacheflow.server.tokenizer_utils import (get_tokenizer,
-detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -19,9 +19,9 @@ logger = init_logger(__name__)
 
 
 class LLMEngine:
-"""An LLM server that receives requests and generates texts.
+"""An LLM engine that receives requests and generates texts.
 
-This is the main class for the CacheFlow LLM server. It receives requests
+This is the main class for the CacheFlow LLM engine. It receives requests
 from clients and generates texts from the LLM. It includes a tokenizer, a
 language model (possibly distributed across multiple GPUs), and GPU memory
 space allocated for intermediate states (aka KV cache). This class utilizes
@@ -31,8 +31,8 @@ class LLMEngine:
 The `LLM` class wraps this class for offline batched inference and the
 `AsyncLLMEngine` class wraps this class for online serving.
 
-NOTE: The config arguments are derived from the `ServerArgs` class. For the
+NOTE: The config arguments are derived from the `EngineArgs` class. For the
-comprehensive list of arguments, see `ServerArgs`.
+comprehensive list of arguments, see `EngineArgs`.
 
 Args:
 model_config: The configuration related to the LLM model.
@@ -58,7 +58,7 @@ class LLMEngine:
 log_stats: bool,
 ) -> None:
 logger.info(
-"Initializing an LLM server with config: "
+"Initializing an LLM engine with config: "
 f"model={model_config.model!r}, "
 f"dtype={model_config.dtype}, "
 f"use_dummy_weights={model_config.use_dummy_weights}, "
@@ -135,17 +135,17 @@ class LLMEngine:
 self._run_workers("init_cache_engine", cache_config=self.cache_config)
 
 @classmethod
-def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
+def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
-"""Creates an LLM server from the server arguments."""
+"""Creates an LLM engine from the engine arguments."""
-# Create the server configs.
+# Create the engine configs.
-server_configs = server_args.create_server_configs()
+engine_configs = engine_args.create_engine_configs()
-parallel_config = server_configs[2]
+parallel_config = engine_configs[2]
 # Initialize the cluster.
 distributed_init_method, devices = initialize_cluster(parallel_config)
-# Create the LLM server.
+# Create the LLM engine.
-server = cls(*server_configs, distributed_init_method, devices,
+engine = cls(*engine_configs, distributed_init_method, devices,
-log_stats=not server_args.disable_log_stats)
+log_stats=not engine_args.disable_log_stats)
-return server
+return engine
 
 def add_request(
 self,
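Not part of the diff: driving the renamed LLMEngine by hand, in the same way as the example script at the end of this commit; the model name and prompt are illustrative.

from cacheflow import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("0", "Hello, my name is", SamplingParams(temperature=0.0))
# Step the engine until the request finishes; each step returns the newly
# generated RequestOutput objects.
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished():
            print(request_output)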
@@ -155,10 +155,10 @@ class LLMEngine:
 prompt_token_ids: Optional[List[int]] = None,
 arrival_time: Optional[float] = None,
 ) -> None:
-"""Add a request to the server's request pool.
+"""Add a request to the engine's request pool.
 
 The request is added to the request pool and will be processed by the
-scheduler as `server.step()` is called. The exact scheduling policy is
+scheduler as `engine.step()` is called. The exact scheduling policy is
 determined by the scheduler.
 
 Args:
@@ -211,7 +211,7 @@ class LLMEngine:
 def step(self) -> List[RequestOutput]:
 """Performs one decoding iteration and returns newly generated results.
 
-This function performs one decoding iteration for the server. It first
+This function performs one decoding iteration of the engine. It first
 schedules the sequences to be executed in the next iteration and the
 token blocks to be swapped in/out/copy. Then, it executes the model
 and updates the scheduler with the model outputs. Finally, it decodes
@@ -13,15 +13,15 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), devi
 
 def initialize_cluster(
 parallel_config: ParallelConfig,
-server_use_ray: bool = False,
+engine_use_ray: bool = False,
-ray_server_address: Optional[str] = None,
+ray_address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
 """Initialize the distributed cluster probably with Ray.
 
 Args:
 parallel_config: The configurations for parallel execution.
-server_use_ray: Whether to use Ray for async server.
+engine_use_ray: Whether to use Ray for async engine.
-ray_server_address: The address of the Ray cluster. If None, uses
+ray_address: The address of the Ray cluster. If None, uses
 the default Ray cluster address.
 
 Returns:
@@ -31,13 +31,13 @@ def initialize_cluster(
 each worker in each pipeline stage. Each device ID is a tuple of
 (rank, node resource, device id).
 """
-if parallel_config.worker_use_ray or server_use_ray:
+if parallel_config.worker_use_ray or engine_use_ray:
 if ray is None:
 raise ImportError(
 "Ray is not installed. Please install Ray to use distributed "
 "serving.")
 # Connect to a ray cluster.
-ray.init(address=ray_server_address)
+ray.init(address=ray_address)
 
 if not parallel_config.worker_use_ray:
 # Initialize cluster locally.
@@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn
 
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
 stream = request_dict.pop("stream", False)
 sampling_params = SamplingParams(**request_dict)
 request_id = random_uuid()
-results_generator = server.generate(prompt, sampling_params, request_id)
+results_generator = engine.generate(prompt, sampling_params, request_id)
 
 # Streaming case
 async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
 yield (json.dumps(ret) + "\0").encode("utf-8")
 
 async def abort_request() -> None:
-await server.abort(request_id)
+await engine.abort(request_id)
 
 if stream:
 background_tasks = BackgroundTasks()
@@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
 async for request_output in results_generator:
 if await request.is_disconnected():
 # Abort the request if the client disconnects.
-await server.abort(request_id)
+await engine.abort(request_id)
 return Response(status_code=499)
 final_output = request_output
 
@@ -75,11 +75,11 @@ if __name__ == "__main__":
 parser = argparse.ArgumentParser()
 parser.add_argument("--host", type=str, default="localhost")
 parser.add_argument("--port", type=int, default=8000)
-parser = AsyncServerArgs.add_cli_args(parser)
+parser = AsyncEngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 
-server_args = AsyncServerArgs.from_cli_args(args)
+engine_args = AsyncEngineArgs.from_cli_args(args)
-server = AsyncLLMEngine.from_server_args(server_args)
+engine = AsyncLLMEngine.from_engine_args(engine_args)
 
 uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -1,12 +1,12 @@
 from typing import List, Optional, Union
 
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter
 
 
@@ -21,7 +21,7 @@ class LLM:
 
 NOTE: This class is intended to be used for offline inference. For online
 serving, use the `AsyncLLMEngine` class instead.
-NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+NOTE: For the comprehensive list of arguments, see `EngineArgs`.
 
 Args:
 model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
 ) -> None:
 if "disable_log_stats" not in kwargs:
 kwargs["disable_log_stats"] = True
-server_args = ServerArgs(
+engine_args = EngineArgs(
 model=model,
 tensor_parallel_size=tensor_parallel_size,
 dtype=dtype,
 seed=seed,
 **kwargs,
 )
-self.llm_server = LLMEngine.from_server_args(server_args)
+self.llm_engine = LLMEngine.from_engine_args(engine_args)
 self.request_counter = Counter()
 
 def get_tokenizer(
 self,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-return self.llm_server.tokenizer
+return self.llm_engine.tokenizer
 
 def generate(
 self,
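Not part of the diff: the LLM wrapper forwards its keyword arguments into the renamed EngineArgs, so engine options can be set directly at construction time; the values below are illustrative.

from cacheflow import LLM

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1,
          dtype="half", seed=0)
# The tokenizer is owned by the wrapped llm.llm_engine after the rename.
tokenizer = llm.get_tokenizer()
print(tokenizer("Hello, my name is"))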
@@ -99,7 +99,7 @@ class LLM:
 # Use default sampling params.
 sampling_params = SamplingParams()
 
-# Add requests to the server.
+# Add requests to the engine.
 if prompts is not None:
 num_requests = len(prompts)
 else:
@@ -111,7 +111,7 @@ class LLM:
 else:
 token_ids = prompt_token_ids[i]
 self._add_request(prompt, sampling_params, token_ids)
-return self._run_server(use_tqdm)
+return self._run_engine(use_tqdm)
 
 def _add_request(
 self,
@@ -120,18 +120,18 @@ class LLM:
 prompt_token_ids: Optional[List[int]],
 ) -> None:
 request_id = str(next(self.request_counter))
-self.llm_server.add_request(request_id, prompt, sampling_params,
+self.llm_engine.add_request(request_id, prompt, sampling_params,
 prompt_token_ids)
 
-def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
 # Initialize tqdm.
 if use_tqdm:
-num_requests = self.llm_server.get_num_unfinished_requests()
+num_requests = self.llm_engine.get_num_unfinished_requests()
 pbar = tqdm(total=num_requests, desc="Processed prompts")
-# Run the server.
+# Run the engine.
 outputs: List[RequestOutput] = []
-while self.llm_server.has_unfinished_requests():
+while self.llm_engine.has_unfinished_requests():
-step_outputs = self.llm_server.step()
+step_outputs = self.llm_engine.step()
 for output in step_outputs:
 if output.finished():
 outputs.append(output)
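Not part of the diff: the throughput benchmark earlier in this commit drives these two underscore-prefixed internals directly (hence its FIXME); a stripped-down sketch, with the model name and prompt illustrative.

from cacheflow import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
# _add_request and _run_engine are internal; llm.generate() is the public path.
llm._add_request("Hello, my name is", SamplingParams(n=1, temperature=0.0),
                 prompt_token_ids=None)
outputs = llm._run_engine(use_tqdm=True)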
@@ -10,29 +10,20 @@ import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn
 
-from cacheflow.outputs import RequestOutput
+from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.server.arg_utils import AsyncServerArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
-from cacheflow.server.async_llm_server import AsyncLLMEngine
+from cacheflow.engine.tokenizer_utils import get_tokenizer
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.entrypoints.openai.protocol import (
+CompletionRequest, CompletionResponse, CompletionResponseChoice,
+CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from cacheflow.logger import init_logger
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import random_uuid
-from cacheflow.entrypoints.openai.protocol import (
-CompletionRequest,
-CompletionResponse,
-CompletionResponseChoice,
-CompletionResponseStreamChoice,
-CompletionStreamResponse,
-ErrorResponse,
-LogProbs,
-ModelCard,
-ModelList,
-ModelPermission,
-UsageInfo,
-)
 
 TIMEOUT_KEEP_ALIVE = 5 # seconds
 
@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
 for the API specification. This API mimics the OpenAI Completion API.
 
 NOTE: Currently we do not support the following features:
-- echo (since the cacheflow server does not currently support
+- echo (since the cacheflow engine does not currently support
 getting the logprobs of prompt tokens)
 - suffix (the language models we currently support do not support
 suffix)
-- logit_bias (to be supported in cacheflow server)
+- logit_bias (to be supported in cacheflow engine)
 """
 request = CompletionRequest(**await raw_request.json())
 logger.info(f"Received completion request: {request}")
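Not part of the diff: a rough client-side sketch against the OpenAI-compatible frontend described above. The /v1/completions route, port, and response shape follow the OpenAI Completion API that this endpoint mimics and are assumptions here; the unsupported fields listed in the docstring (echo, suffix, logit_bias) must simply be left out of the request.

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",   # assumed route and port
    json={
        "model": "facebook/opt-125m",          # should match the served model name
        "prompt": "San Francisco is a",
        "n": 1,
        "best_of": 1,
        "stream": False,
    },
)
print(response.json())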
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
 return error_check_ret
 
 if request.echo:
-# We do not support echo since the cacheflow server does not
+# We do not support echo since the cacheflow engine does not
 # currently support getting the logprobs of prompt tokens.
 return create_error_response(HTTPStatus.BAD_REQUEST,
 "echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
 "suffix is not currently supported")
 
 if request.logit_bias is not None:
-# TODO: support logit_bias in cacheflow server.
+# TODO: support logit_bias in cacheflow engine.
 return create_error_response(HTTPStatus.BAD_REQUEST,
 "logit_bias is not currently supported")
 
@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
 except ValueError as e:
 return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-result_generator = server.generate(prompt, sampling_params,
+result_generator = engine.generate(prompt, sampling_params,
 request_id)
 
 # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
 not request.use_beam_search)
 
 async def abort_request() -> None:
-await server.abort(request_id)
+await engine.abort(request_id)
 
 def create_stream_response_json(index: int,
 text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
 help="The model name used in the API. If not specified, "
 "the model name will be the same as the "
 "huggingface name.")
-parser = AsyncServerArgs.add_cli_args(parser)
+parser = AsyncEngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 
 app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
 
 served_model = args.served_model_name or args.model
 
-server_args = AsyncServerArgs.from_cli_args(args)
+engine_args = AsyncEngineArgs.from_cli_args(args)
-server = AsyncLLMEngine.from_server_args(server_args)
+engine = AsyncLLMEngine.from_engine_args(engine_args)
 
 # A separate tokenizer to map token IDs to strings.
 tokenizer = get_tokenizer(args.model)
@@ -1,12 +1,12 @@
 import argparse
 
-from cacheflow import ServerArgs, LLMEngine, SamplingParams
+from cacheflow import EngineArgs, LLMEngine, SamplingParams
 
 
 def main(args: argparse.Namespace):
-# Parse the CLI argument and initialize the server.
+# Parse the CLI argument and initialize the engine.
-server_args = ServerArgs.from_cli_args(args)
+engine_args = EngineArgs.from_cli_args(args)
-server = LLMEngine.from_server_args(server_args)
+engine = LLMEngine.from_engine_args(engine_args)
 
 # Test the following prompts.
 test_prompts = [
@@ -19,27 +19,27 @@ def main(args: argparse.Namespace):
 SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
 ]
 
-# Run the server by calling `server.step()` manually.
+# Run the engine by calling `engine.step()` manually.
 request_id = 0
 while True:
 # To test iteration-level scheduling, we add one request at each step.
 if test_prompts:
 prompt, sampling_params = test_prompts.pop(0)
-server.add_request(str(request_id), prompt, sampling_params)
+engine.add_request(str(request_id), prompt, sampling_params)
 request_id += 1
 
-request_outputs = server.step()
+request_outputs = engine.step()
 for request_output in request_outputs:
 if request_output.finished():
 print(request_output)
 
-if not (server.has_unfinished_requests() or test_prompts):
+if not (engine.has_unfinished_requests() or test_prompts):
 break
 
 
 if __name__ == '__main__':
 parser = argparse.ArgumentParser(
-description='Demo on using the LLMEngine class synchronously')
+description='Demo on using the LLMEngine class directly')
-parser = ServerArgs.add_cli_args(parser)
+parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 main(args)