Fix various issues of async servers (#135)

Zhuohan Li authored on 2023-06-05 23:44:50 +08:00, committed by GitHub
parent 8274ca23ac
commit 1a956e136b
11 changed files with 289 additions and 121 deletions

View File

@@ -0,0 +1,58 @@
import argparse
import json
import threading
import time

import requests


def main(args: argparse.Namespace):
    prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
               for i in range(args.n_threads)]

    headers = {"User-Agent": "CacheFlow Benchmark Client"}
    ploads = [{
        "prompt": p,
        "max_tokens": args.max_tokens,
        "temperature": 0.0,
        "ignore_eos": True,
    } for p in prompts]

    def send_request(results, i):
        response = requests.post(args.api_url, headers=headers,
                                 json=ploads[i], stream=True)
        results[i] = response

    # use args.n_threads to prompt the backend
    tik = time.time()
    threads = []
    results = [None] * args.n_threads
    for i in range(args.n_threads):
        t = threading.Thread(target=send_request, args=(results, i))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    print(f"Time (POST): {time.time() - tik} s")

    n_words = 0
    for i, response in enumerate(results):
        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"))
        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))

    time_seconds = time.time() - tik
    print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, "
          f"throughput: {n_words / time_seconds} words/s.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--api-url", type=str,
                        default="http://localhost:8001/generate")
    parser.add_argument("--max-tokens", type=int, default=128)
    parser.add_argument("--n-threads", type=int, default=128)
    args = parser.parse_args()
    main(args)
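
For context, a single-request version of the same client is sketched below. It is not part of this commit; the "\0"-delimited JSON framing and the "text" field it parses mirror the streaming API server later in this diff, while the URL, prompt, and parameter values are placeholders.

import json

import requests

# Minimal single-request streaming client (a sketch, not part of this commit).
def stream_one(api_url: str = "http://localhost:8001/generate") -> None:
    pload = {
        "prompt": "Tell me a story",
        "max_tokens": 64,
        "temperature": 0.0,
        "ignore_eos": True,
    }
    response = requests.post(api_url, json=pload, stream=True)
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                     delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode("utf-8"))
        # Each chunk carries the full text generated so far for sequence 0.
        print(data["text"][0])


if __name__ == "__main__":
    stream_one()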

View File

@@ -116,15 +116,15 @@ class ParallelConfig:
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray
         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()

     def _verify_args(self) -> None:
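
The renamed flag keeps its old behavior; a small sketch (not part of the diff) of how it resolves:

from cacheflow.config import ParallelConfig

# Per the hunk above: with more than one worker (world_size > 1),
# worker_use_ray is forced on even when the caller passes False.
config = ParallelConfig(pipeline_parallel_size=1,
                        tensor_parallel_size=2,
                        worker_use_ray=False)
assert config.world_size == 2
assert config.worker_use_ray  # auto-enabled because world_size > 1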

View File

@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -231,6 +231,9 @@ class BlockSpaceManager:
             self.cpu_allocator.free(block)

     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]

View File

@@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,

 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5


 class PreemptionMode(enum.Enum):
@@ -84,6 +84,18 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped

View File

@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)

 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],


 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):

     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -291,10 +306,11 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
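
A quick sanity check of the endpoint touched above, as a hedged client sketch that is not part of this commit; the field names mirror the ones referenced in this diff, and the host, port, and model name are placeholders that must match the running server:

import json

import requests

# Non-streaming request to the OpenAI-compatible /v1/completions endpoint.
# "model" must equal the served model name (args.served_model_name or
# args.model); the URL and values here are placeholders.
payload = {
    "model": "facebook/opt-125m",
    "prompt": "San Francisco is a",
    "max_tokens": 32,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(json.dumps(resp.json(), indent=2))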

View File

@@ -2,15 +2,16 @@ import argparse
 import json
 from typing import AsyncGenerator

-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn

 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

 app = FastAPI()
@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
@@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]

     @staticmethod
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason


 class SequenceData:

     def __init__(
@@ -137,6 +142,9 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ class SequenceGroup:
             raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "

View File

@@ -15,7 +15,7 @@ class ServerArgs:
     use_dummy_weights: bool = False
     dtype: str = "default"
     seed: int = 0
-    use_ray: bool = False
+    worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
@@ -32,7 +32,63 @@ class ServerArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        return _add_server_arguments(parser)
+        """Shared CLI arguments for CacheFlow servers."""
+        # Model arguments
+        parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                            help='name or path of the huggingface model to use')
+        parser.add_argument('--download-dir', type=str,
+                            default=ServerArgs.download_dir,
+                            help='directory to download and load the weights, '
+                                 'default to the default cache dir of '
+                                 'huggingface')
+        parser.add_argument('--use-np-weights', action='store_true',
+                            help='save a numpy copy of model weights for '
+                                 'faster loading. This can increase the disk '
+                                 'usage by up to 2x.')
+        parser.add_argument('--use-dummy-weights', action='store_true',
+                            help='use dummy values for model weights')
+        # TODO(woosuk): Support FP32.
+        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                            choices=['default', 'half', 'bfloat16'],
+                            help='data type for model weights and activations. '
+                                 'The "default" option will use FP16 precision '
+                                 'for FP32 and FP16 models, and BF16 precision '
+                                 'for BF16 models.')
+        # Parallel arguments
+        parser.add_argument('--worker-use-ray', action='store_true',
+                            help='use Ray for distributed serving, will be '
+                                 'automatically set when using more than 1 GPU')
+        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                            default=ServerArgs.pipeline_parallel_size,
+                            help='number of pipeline stages')
+        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                            default=ServerArgs.tensor_parallel_size,
+                            help='number of tensor parallel replicas')
+        # KV cache arguments
+        parser.add_argument('--block-size', type=int,
+                            default=ServerArgs.block_size,
+                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            help='token block size')
+        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                            help='random seed')
+        parser.add_argument('--swap-space', type=int,
+                            default=ServerArgs.swap_space,
+                            help='CPU swap space size (GiB) per GPU')
+        parser.add_argument('--gpu-memory-utilization', type=float,
+                            default=ServerArgs.gpu_memory_utilization,
+                            help='the percentage of GPU memory to be used for'
+                                 'the model executor')
+        parser.add_argument('--max-num-batched-tokens', type=int,
+                            default=ServerArgs.max_num_batched_tokens,
+                            help='maximum number of batched tokens per '
+                                 'iteration')
+        parser.add_argument('--max-num-seqs', type=int,
+                            default=ServerArgs.max_num_seqs,
+                            help='maximum number of sequences per iteration')
+        parser.add_argument('--disable-log-stats', action='store_true',
+                            help='disable logging statistics')
+        return parser

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
@@ -53,65 +109,22 @@ class ServerArgs:
                                    self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
-                                         self.use_ray)
+                                         self.worker_use_ray)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs)
        return model_config, cache_config, parallel_config, scheduler_config


-def _add_server_arguments(
-    parser: argparse.ArgumentParser,
-)-> argparse.ArgumentParser:
-    """Shared CLI arguments for CacheFlow servers."""
-    # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m',
-                        help='name or path of the huggingface model to use')
-    parser.add_argument('--download-dir', type=str,
-                        default=ServerArgs.download_dir,
-                        help='directory to download and load the weights, '
-                             'default to the default cache dir of huggingface')
-    parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster '
-                             'loading. This can increase the disk usage by up '
-                             'to 2x.')
-    parser.add_argument('--use-dummy-weights', action='store_true',
-                        help='use dummy values for model weights')
-    # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                        choices=['default', 'half', 'bfloat16'],
-                        help=('data type for model weights and activations. '
-                              'The "default" option will use FP16 precision '
-                              'for FP32 and FP16 models, and BF16 precision '
-                              'for BF16 models.'))
-    # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be '
-                             'automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                        default=ServerArgs.pipeline_parallel_size,
-                        help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                        default=ServerArgs.tensor_parallel_size,
-                        help='number of tensor parallel replicas')
-    # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
-                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
-                        help='token block size')
-    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
-                        help='random seed')
-    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
-                        help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float,
-                        default=ServerArgs.gpu_memory_utilization,
-                        help='the percentage of GPU memory to be used for the '
-                             'model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int,
-                        default=ServerArgs.max_num_batched_tokens,
-                        help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int,
-                        default=ServerArgs.max_num_seqs,
-                        help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true',
-                        help='disable logging statistics')
-    return parser
+@dataclass
+class AsyncServerArgs(ServerArgs):
+    server_use_ray: bool = False
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        parser = ServerArgs.add_cli_args(parser)
+        parser.add_argument('--server-use-ray', action='store_true',
+                            help='use Ray to start the LLM server in a '
+                                 'separate process as the web server process.')
+        return parser
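
How the two frontends in this diff consume the new class, condensed into a standalone sketch (the argument values are examples only, and the inherited `from_cli_args` handling of the extra field is assumed):

import argparse

from cacheflow.server.arg_utils import AsyncServerArgs

# Attach the shared ServerArgs flags plus --server-use-ray to an existing
# parser, then rebuild the dataclass from the parsed namespace, as the API
# servers in this commit do.
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args(["--tensor-parallel-size", "2", "--server-use-ray"])
server_args = AsyncServerArgs.from_cli_args(args)
print(server_args.worker_use_ray, server_args.server_use_ray)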

View File

@@ -2,37 +2,52 @@ import asyncio
 import time
 from typing import Dict, Optional

-import ray
-
+from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.utils import random_uuid
+from cacheflow.server.ray_utils import ray, initialize_cluster
+
+logger = init_logger(__name__)

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


 class AsyncLLMServer:

-    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
-        if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
+    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+                 *args, **kwargs) -> None:
+        self.worker_use_ray = worker_use_ray
+        self.server_use_ray = server_use_ray
+        if not self.server_use_ray:
+            server_class = LLMServer
+        elif self.worker_use_ray:
+            server_class = ray.remote(num_cpus=0)(LLMServer).remote
         else:
-            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
-        self.server = remote_server_class.remote(*args, **kwargs)
+            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+        self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
+        self.kicking_request_id: Optional[str] = None

-    async def server_step(self):
+    async def server_step(self, kicking_request_id: Optional[str] = None):
         self.is_server_running = True
-        request_outputs = await self.server.step.remote()
+        self.kicking_request_id = kicking_request_id
+        if self.server_use_ray:
+            request_outputs = await self.server.step.remote()
+        else:
+            # Yield to the event loop to allow other coroutines to run
+            # while is_server_running is True. This let the server to add new
+            # requests into the queue.
+            await asyncio.sleep(0)
+            request_outputs = self.server.step()
         self.is_server_running = False
+        self.kicking_request_id = None
         # Notify the waiting coroutines that there are new outputs ready.
         for request_output in request_outputs:
             request_id = request_output.request_id
@@ -40,20 +55,26 @@ class AsyncLLMServer:
             self.request_events[request_id].set()

     async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: Optional[str] = None) -> RequestOutput:
+                       request_id: str) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        if request_id is None:
-            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt!r}, "
+                    f"sampling params: {sampling_params}.")
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_request.remote(
-            request_id, prompt, sampling_params, arrival_time=arrival_time)
+        if self.server_use_ray:
+            await self.server.add_request.remote(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)
+        else:
+            self.server.add_request(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
@@ -61,7 +82,7 @@ class AsyncLLMServer:
         while True:
             # Kick the server if the server is not running.
             if not self.is_server_running:
-                await self.server_step()
+                await self.server_step(request_id)

             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.
@@ -80,6 +101,8 @@ class AsyncLLMServer:

             # Once finished, release the resources of the sequence group.
             if request_output.finished():
+                logger.info(f"Finished request {request_id}.")
+
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
@@ -89,15 +112,41 @@ class AsyncLLMServer:
                     await self.server_step()
                 break

+    async def abort(self, request_id: str) -> None:
+        if request_id not in self.request_events:
+            # The request has already finished or been aborted.
+            return
+
+        logger.info(f"Aborted request {request_id}.")
+
+        if self.server_use_ray:
+            await self.server.abort_request.remote(request_id)
+        else:
+            self.server.abort_request(request_id)
+
+        if request_id in self.request_events:
+            del self.request_events[request_id]
+        if request_id in self.request_outputs:
+            del self.request_outputs[request_id]
+
+        # To prevent deadlock when a request is aborted while the server is
+        # running.
+        if self.kicking_request_id == request_id:
+            self.is_server_running = False
+            self.kicking_request_id = None
+
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, devices = initialize_cluster(
+            parallel_config, server_args.server_use_ray)
         # Create the LLM server.
-        server = cls(server_args.use_ray, *server_configs,
+        server = cls(server_args.worker_use_ray,
+                     server_args.server_use_ray,
+                     *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
         return server
View File

@@ -1,11 +1,6 @@
 import time
 from typing import Any, List, Optional

-try:
-    import ray
-except ImportError:
-    ray = None
-
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.server.ray_utils import ray, initialize_cluster
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -62,7 +57,7 @@ class LLMServer:
         assert len(stage_devices) == 1, "Only support one stage for now."
         for rank, node_resource, _ in stage_devices[0]:
             worker_cls = Worker
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 worker_cls = ray.remote(
                     num_cpus=0,
                     num_gpus=1,
@@ -152,6 +147,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def abort_request(self, request_id: str) -> None:
+        self.scheduler.abort_seq_group(request_id)
+
     def get_num_unfinished_requests(self) -> int:
         return self.scheduler.get_num_unfinished_seq_groups()
@@ -243,13 +241,13 @@ class LLMServer:
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

-        if self.parallel_config.use_ray:
+        if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)

         if get_all_outputs:
View File

@@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id

 def initialize_cluster(
     parallel_config: ParallelConfig,
+    server_use_ray: bool = False,
     address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
-    if not parallel_config.use_ray:
+    if parallel_config.worker_use_ray or server_use_ray:
+        if ray is None:
+            raise ImportError(
+                "Ray is not installed. Please install Ray to use distributed "
+                "serving.")
+        # Connect to a ray cluster.
+        ray.init(address=address)
+
+    if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
         port = random.randint(10000, 20000)
         # We need to setup the distributed init method to make sure
@@ -24,13 +33,6 @@ def initialize_cluster(
         all_stage_devices = [[(0, None, 0)]]
         return distributed_init_method, all_stage_devices

-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use distributed "
-            "serving.")
-    # Connect to a ray cluster.
-    ray.init(address=address)
-
     # Assume we have a uniform cluster that each node has the same number of
     # GPUs for now.
     valid_node_resources = []
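
A sketch (not part of the diff) of the updated contract: Ray is initialized up front whenever either the workers or the server itself run under Ray, while a purely local setup skips it entirely.

from cacheflow.config import ParallelConfig
from cacheflow.server.ray_utils import initialize_cluster

# Purely local setup: neither the workers nor the server use Ray, so
# initialize_cluster only picks a random local port for the distributed
# init method and returns a single-device stage.
parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1,
                                 worker_use_ray=False)
distributed_init_method, all_stage_devices = initialize_cluster(
    parallel_config, server_use_ray=False)
print(distributed_init_method, all_stage_devices)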