From 1a956e136beae057746af6257ffa8da601730f10 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 5 Jun 2023 23:44:50 +0800 Subject: [PATCH] Fix various issues of async servers (#135) --- benchmarks/benchmark_async_llm_server.py | 58 ++++++++ cacheflow/config.py | 6 +- cacheflow/core/block_manager.py | 9 +- cacheflow/core/scheduler.py | 14 +- .../entrypoints/openai/openai_frontend.py | 30 +++- .../entrypoints/simple_fastapi_frontend.py | 25 ++-- cacheflow/sequence.py | 10 +- cacheflow/server/arg_utils.py | 131 ++++++++++-------- cacheflow/server/async_llm_server.py | 93 ++++++++++--- cacheflow/server/llm_server.py | 16 +-- cacheflow/server/ray_utils.py | 18 +-- 11 files changed, 289 insertions(+), 121 deletions(-) create mode 100644 benchmarks/benchmark_async_llm_server.py diff --git a/benchmarks/benchmark_async_llm_server.py b/benchmarks/benchmark_async_llm_server.py new file mode 100644 index 0000000000000..4c6ed709c46c4 --- /dev/null +++ b/benchmarks/benchmark_async_llm_server.py @@ -0,0 +1,58 @@ +import argparse +import json +import threading +import time + +import requests + + +def main(args: argparse.Namespace): + prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words" + for i in range(args.n_threads)] + + headers = {"User-Agent": "CacheFlow Benchmark Client"} + ploads = [{ + "prompt": p, + "max_tokens": args.max_tokens, + "temperature": 0.0, + "ignore_eos": True, + } for p in prompts] + + def send_request(results, i): + response = requests.post(args.api_url, headers=headers, + json=ploads[i], stream=True) + results[i] = response + + # use args.n_threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_threads + for i in range(args.n_threads): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + threads.append(t) + + for t in threads: + t.join() + + print(f"Time (POST): {time.time() - tik} s") + n_words = 0 + + for i, response in enumerate(results): + k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) + response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0] + n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + + time_seconds = time.time() - tik + print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, " + f"throughput: {n_words / time_seconds} words/s.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate") + parser.add_argument("--max-tokens", type=int, default=128) + parser.add_argument("--n-threads", type=int, default=128) + args = parser.parse_args() + + main(args) diff --git a/cacheflow/config.py b/cacheflow/config.py index 157d22d7b304a..cf779723a9696 100644 --- a/cacheflow/config.py +++ b/cacheflow/config.py @@ -116,15 +116,15 @@ class ParallelConfig: self, pipeline_parallel_size: int, tensor_parallel_size: int, - use_ray: bool, + worker_use_ray: bool, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size - self.use_ray = use_ray + self.worker_use_ray = worker_use_ray self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: - self.use_ray = True + self.worker_use_ray = True self._verify_args() def _verify_args(self) -> None: diff --git a/cacheflow/core/block_manager.py b/cacheflow/core/block_manager.py index 07129b65b226c..93939f5ce2638 100644 --- a/cacheflow/core/block_manager.py +++ 
b/cacheflow/core/block_manager.py @@ -148,7 +148,7 @@ class BlockSpaceManager: # the sequences in the same group. blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): - if SequenceStatus.is_finished(seq.status): + if seq.is_finished(): continue block_table = self.block_tables[seq.seq_id] for block in block_table: @@ -169,7 +169,7 @@ class BlockSpaceManager: # CPU block -> GPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(): - if SequenceStatus.is_finished(seq.status): + if seq.is_finished(): continue new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] @@ -200,7 +200,7 @@ class BlockSpaceManager: # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(): - if SequenceStatus.is_finished(seq.status): + if seq.is_finished(): continue new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] @@ -231,6 +231,9 @@ class BlockSpaceManager: self.cpu_allocator.free(block) def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return block_table = self.block_tables[seq.seq_id] self._free_block_table(block_table) del self.block_tables[seq.seq_id] diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py index 0423ddae1a3bb..9ff5db9d5b199 100644 --- a/cacheflow/core/scheduler.py +++ b/cacheflow/core/scheduler.py @@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup, logger = init_logger(__name__) -_LOGGING_INTERVAL_SEC = 10 +_LOGGING_INTERVAL_SEC = 5 class PreemptionMode(enum.Enum): @@ -84,6 +84,18 @@ class Scheduler: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) + def abort_seq_group(self, request_id: str) -> None: + for state_queue in [self.waiting, self.running, self.swapped]: + for seq_group in state_queue: + if seq_group.request_id == request_id: + # Remove the sequence group from the state queue. 
+ state_queue.remove(seq_group) + for seq in seq_group.seqs: + if seq.is_finished(): + continue + self.free_seq(seq, SequenceStatus.FINISHED_ABORTED) + return + def has_unfinished_seqs(self) -> bool: return self.waiting or self.running or self.swapped diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py index 4d32390bade1b..8f00db863e9bf 100644 --- a/cacheflow/entrypoints/openai/openai_frontend.py +++ b/cacheflow/entrypoints/openai/openai_frontend.py @@ -7,13 +7,14 @@ import time from typing import AsyncGenerator, Dict, List, Optional import fastapi +from fastapi import BackgroundTasks, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from cacheflow.outputs import RequestOutput -from cacheflow.server.arg_utils import ServerArgs +from cacheflow.server.arg_utils import AsyncServerArgs from cacheflow.server.async_llm_server import AsyncLLMServer from cacheflow.server.tokenizer_utils import get_tokenizer from cacheflow.logger import init_logger @@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import ( UsageInfo, ) +TIMEOUT_KEEP_ALIVE = 5 # seconds logger = init_logger(__name__) served_model = None @@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int], @app.post("/v1/completions") -async def create_completion(request: CompletionRequest): +async def create_completion(raw_request: Request): + request = CompletionRequest(**await raw_request.json()) logger.info(f"Received completion request: {request}") error_check_ret = await check_model(request) @@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest): return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) result_generator = server.generate(prompt, sampling_params, - request_id=request_id) + request_id) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use beam search. @@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest): (request.best_of is None or request.n == request.best_of) and not request.use_beam_search) + async def abort_request() -> None: + await server.abort(request_id) + def create_stream_response_json(index: int, text: str, logprobs: Optional[LogProbs] = None, @@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest): # Streaming response if stream: + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. + background_tasks.add_task(abort_request) return StreamingResponse(completion_stream_generator(), - media_type="text/event-stream") + media_type="text/event-stream", + background=background_tasks) # Non-streaming response final_res: RequestOutput = None async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await server.abort(request_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + "Client disconnected") final_res = res assert final_res is not None choices = [] @@ -276,7 +291,7 @@ if __name__ == "__main__": help="The model name used in the API. 
If not specified, " "the model name will be the same as the " "huggingface name.") - parser = ServerArgs.add_cli_args(parser) + parser = AsyncServerArgs.add_cli_args(parser) args = parser.parse_args() app.add_middleware( @@ -291,10 +306,11 @@ if __name__ == "__main__": served_model = args.served_model_name or args.model - server_args = ServerArgs.from_cli_args(args) + server_args = AsyncServerArgs.from_cli_args(args) server = AsyncLLMServer.from_server_args(server_args) # A separate tokenizer to map token IDs to strings. tokenizer = get_tokenizer(args.model) - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + uvicorn.run(app, host=args.host, port=args.port, log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py index e7e1357f38495..1fce4cf3bf8e1 100644 --- a/cacheflow/entrypoints/simple_fastapi_frontend.py +++ b/cacheflow/entrypoints/simple_fastapi_frontend.py @@ -2,15 +2,16 @@ import argparse import json from typing import AsyncGenerator -from fastapi import FastAPI, Request +from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import StreamingResponse import uvicorn from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import ServerArgs +from cacheflow.server.arg_utils import AsyncServerArgs from cacheflow.server.async_llm_server import AsyncLLMServer -from cacheflow.server.ray_utils import initialize_cluster +from cacheflow.utils import random_uuid +TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds app = FastAPI() @@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse: request_dict = await request.json() prompt = request_dict.pop("prompt") sampling_params = SamplingParams(**request_dict) - results_generator = server.generate(prompt, sampling_params) + request_id = random_uuid() + results_generator = server.generate(prompt, sampling_params, request_id) async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: @@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse: } yield (json.dumps(ret) + "\0").encode("utf-8") - return StreamingResponse(stream_results()) + async def abort_request() -> None: + await server.abort(request_id) + + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. 
+ background_tasks.add_task(abort_request) + return StreamingResponse(stream_results(), background=background_tasks) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8001) - parser = ServerArgs.add_cli_args(parser) + parser = AsyncServerArgs.add_cli_args(parser) args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) + server_args = AsyncServerArgs.from_cli_args(args) server = AsyncLLMServer.from_server_args(server_args) - uvicorn.run(app, host=args.host, port=args.port, log_level="info") + uvicorn.run(app, host=args.host, port=args.port, log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index db86460987b91..8e6e729267802 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum): SWAPPED = enum.auto() FINISHED_STOPPED = enum.auto() FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() @staticmethod def is_finished(status: "SequenceStatus") -> bool: return status in [ SequenceStatus.FINISHED_STOPPED, SequenceStatus.FINISHED_LENGTH_CAPPED, + SequenceStatus.FINISHED_ABORTED, ] @staticmethod @@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum): finish_reason = "stop" elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: finish_reason = "length" + elif status == SequenceStatus.FINISHED_ABORTED: + finish_reason = "abort" else: finish_reason = None return finish_reason + class SequenceData: def __init__( @@ -137,6 +142,9 @@ class Sequence: def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.status) + def fork(self, child_seq: 'Sequence') -> None: child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks) child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) @@ -182,7 +190,7 @@ class SequenceGroup: raise ValueError(f'Sequence {seq_id} not found.') def is_finished(self) -> bool: - return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs) + return all(seq.is_finished() for seq in self.seqs) def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " diff --git a/cacheflow/server/arg_utils.py b/cacheflow/server/arg_utils.py index a4b898dd214c2..63f32c80fab97 100644 --- a/cacheflow/server/arg_utils.py +++ b/cacheflow/server/arg_utils.py @@ -15,7 +15,7 @@ class ServerArgs: use_dummy_weights: bool = False dtype: str = "default" seed: int = 0 - use_ray: bool = False + worker_use_ray: bool = False pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 block_size: int = 16 @@ -32,7 +32,63 @@ class ServerArgs: def add_cli_args( parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: - return _add_server_arguments(parser) + """Shared CLI arguments for CacheFlow servers.""" + # Model arguments + parser.add_argument('--model', type=str, default='facebook/opt-125m', + help='name or path of the huggingface model to use') + parser.add_argument('--download-dir', type=str, + default=ServerArgs.download_dir, + help='directory to download and load the weights, ' + 'default to the default cache dir of ' + 'huggingface') + parser.add_argument('--use-np-weights', action='store_true', + help='save a numpy copy of model weights for ' + 'faster loading. 
This can increase the disk ' + 'usage by up to 2x.') + parser.add_argument('--use-dummy-weights', action='store_true', + help='use dummy values for model weights') + # TODO(woosuk): Support FP32. + parser.add_argument('--dtype', type=str, default=ServerArgs.dtype, + choices=['default', 'half', 'bfloat16'], + help='data type for model weights and activations. ' + 'The "default" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + # Parallel arguments + parser.add_argument('--worker-use-ray', action='store_true', + help='use Ray for distributed serving, will be ' + 'automatically set when using more than 1 GPU') + parser.add_argument('--pipeline-parallel-size', '-pp', type=int, + default=ServerArgs.pipeline_parallel_size, + help='number of pipeline stages') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, + default=ServerArgs.tensor_parallel_size, + help='number of tensor parallel replicas') + # KV cache arguments + parser.add_argument('--block-size', type=int, + default=ServerArgs.block_size, + choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], + help='token block size') + # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). + parser.add_argument('--seed', type=int, default=ServerArgs.seed, + help='random seed') + parser.add_argument('--swap-space', type=int, + default=ServerArgs.swap_space, + help='CPU swap space size (GiB) per GPU') + parser.add_argument('--gpu-memory-utilization', type=float, + default=ServerArgs.gpu_memory_utilization, + help='the percentage of GPU memory to be used for' + 'the model executor') + parser.add_argument('--max-num-batched-tokens', type=int, + default=ServerArgs.max_num_batched_tokens, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--max-num-seqs', type=int, + default=ServerArgs.max_num_seqs, + help='maximum number of sequences per iteration') + parser.add_argument('--disable-log-stats', action='store_true', + help='disable logging statistics') + return parser @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs": @@ -53,65 +109,22 @@ class ServerArgs: self.swap_space) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, - self.use_ray) + self.worker_use_ray) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs) return model_config, cache_config, parallel_config, scheduler_config -def _add_server_arguments( - parser: argparse.ArgumentParser, -)-> argparse.ArgumentParser: - """Shared CLI arguments for CacheFlow servers.""" - # Model arguments - parser.add_argument('--model', type=str, default='facebook/opt-125m', - help='name or path of the huggingface model to use') - parser.add_argument('--download-dir', type=str, - default=ServerArgs.download_dir, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') - parser.add_argument('--use-np-weights', action='store_true', - help='save a numpy copy of model weights for faster ' - 'loading. This can increase the disk usage by up ' - 'to 2x.') - parser.add_argument('--use-dummy-weights', action='store_true', - help='use dummy values for model weights') - # TODO(woosuk): Support FP32. - parser.add_argument('--dtype', type=str, default=ServerArgs.dtype, - choices=['default', 'half', 'bfloat16'], - help=('data type for model weights and activations. 
' - 'The "default" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.')) - # Parallel arguments - parser.add_argument('--use-ray', action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') - parser.add_argument('--pipeline-parallel-size', '-pp', type=int, - default=ServerArgs.pipeline_parallel_size, - help='number of pipeline stages') - parser.add_argument('--tensor-parallel-size', '-tp', type=int, - default=ServerArgs.tensor_parallel_size, - help='number of tensor parallel replicas') - # KV cache arguments - parser.add_argument('--block-size', type=int, default=ServerArgs.block_size, - choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], - help='token block size') - # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). - parser.add_argument('--seed', type=int, default=ServerArgs.seed, - help='random seed') - parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space, - help='CPU swap space size (GiB) per GPU') - parser.add_argument('--gpu-memory-utilization', type=float, - default=ServerArgs.gpu_memory_utilization, - help='the percentage of GPU memory to be used for the ' - 'model executor') - parser.add_argument('--max-num-batched-tokens', type=int, - default=ServerArgs.max_num_batched_tokens, - help='maximum number of batched tokens per iteration') - parser.add_argument('--max-num-seqs', type=int, - default=ServerArgs.max_num_seqs, - help='maximum number of sequences per iteration') - parser.add_argument('--disable-log-stats', action='store_true', - help='disable logging statistics') - return parser +@dataclass +class AsyncServerArgs(ServerArgs): + server_use_ray: bool = False + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser, + ) -> argparse.ArgumentParser: + parser = ServerArgs.add_cli_args(parser) + parser.add_argument('--server-use-ray', action='store_true', + help='use Ray to start the LLM server in a ' + 'separate process as the web server process.') + return parser diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py index 8755b023b897d..409af2f240eda 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/server/async_llm_server.py @@ -2,37 +2,52 @@ import asyncio import time from typing import Dict, Optional -import ray - +from cacheflow.logger import init_logger from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import ServerArgs +from cacheflow.server.arg_utils import AsyncServerArgs from cacheflow.server.llm_server import LLMServer -from cacheflow.server.ray_utils import initialize_cluster -from cacheflow.utils import random_uuid +from cacheflow.server.ray_utils import ray, initialize_cluster + +logger = init_logger(__name__) TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds class AsyncLLMServer: - def __init__(self, server_use_ray: bool, *args, **kwargs) -> None: - if server_use_ray: - remote_server_class = ray.remote(num_cpus=0)(LLMServer) + def __init__(self, worker_use_ray: bool, server_use_ray: bool, + *args, **kwargs) -> None: + self.worker_use_ray = worker_use_ray + self.server_use_ray = server_use_ray + if not self.server_use_ray: + server_class = LLMServer + elif self.worker_use_ray: + server_class = ray.remote(num_cpus=0)(LLMServer).remote else: - remote_server_class = ray.remote(num_gpus=1)(LLMServer) - self.server = remote_server_class.remote(*args, **kwargs) - + server_class = 
ray.remote(num_gpus=1)(LLMServer).remote + self.server = server_class(*args, **kwargs) # Request id -> request output. self.request_outputs: Dict[str, RequestOutput] = {} # Request id -> event to notify that there is new output. self.request_events: Dict[str, asyncio.Event] = {} self.is_server_running = False + self.kicking_request_id: Optional[str] = None - async def server_step(self): + async def server_step(self, kicking_request_id: Optional[str] = None): self.is_server_running = True - request_outputs = await self.server.step.remote() + self.kicking_request_id = kicking_request_id + if self.server_use_ray: + request_outputs = await self.server.step.remote() + else: + # Yield to the event loop to allow other coroutines to run + # while is_server_running is True. This let the server to add new + # requests into the queue. + await asyncio.sleep(0) + request_outputs = self.server.step() self.is_server_running = False + self.kicking_request_id = None + # Notify the waiting coroutines that there are new outputs ready. for request_output in request_outputs: request_id = request_output.request_id @@ -40,20 +55,26 @@ class AsyncLLMServer: self.request_events[request_id].set() async def generate(self, prompt: str, sampling_params: SamplingParams, - request_id: Optional[str] = None) -> RequestOutput: + request_id: str) -> RequestOutput: # Preprocess the request. arrival_time = time.time() # Create an event to notify us that there is new output from the # cacheflow server. - if request_id is None: - request_id = random_uuid() request_event = asyncio.Event() self.request_events[request_id] = request_event + logger.info(f"Received request {request_id}: " + f"prompt: {prompt!r}, " + f"sampling params: {sampling_params}.") + # Add the request into the cacheflow server's waiting queue. - await self.server.add_request.remote( - request_id, prompt, sampling_params, arrival_time=arrival_time) + if self.server_use_ray: + await self.server.add_request.remote( + request_id, prompt, sampling_params, arrival_time=arrival_time) + else: + self.server.add_request( + request_id, prompt, sampling_params, arrival_time=arrival_time) # The cacheflow server does not have a background loop that keeps # processing incoming requests. Therefore, we need to keep kicking @@ -61,7 +82,7 @@ class AsyncLLMServer: while True: # Kick the server if the server is not running. if not self.is_server_running: - await self.server_step() + await self.server_step(request_id) # Wait for new output. The group_event will be set in server_step # when there is new output available for the sequence group. @@ -80,6 +101,8 @@ class AsyncLLMServer: # Once finished, release the resources of the sequence group. if request_output.finished(): + logger.info(f"Finished request {request_id}.") + del self.request_outputs[request_id] del self.request_events[request_id] # Kick the server if the server is not running. This is to @@ -89,15 +112,41 @@ class AsyncLLMServer: await self.server_step() break + async def abort(self, request_id: str) -> None: + if request_id not in self.request_events: + # The request has already finished or been aborted. 
+ return + + logger.info(f"Aborted request {request_id}.") + + if self.server_use_ray: + await self.server.abort_request.remote(request_id) + else: + self.server.abort_request(request_id) + + if request_id in self.request_events: + del self.request_events[request_id] + if request_id in self.request_outputs: + del self.request_outputs[request_id] + + # To prevent deadlock when a request is aborted while the server is + # running. + if self.kicking_request_id == request_id: + self.is_server_running = False + self.kicking_request_id = None + @classmethod - def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer": + def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer": # Create the server configs. server_configs = server_args.create_server_configs() parallel_config = server_configs[2] # Initialize the cluster. - distributed_init_method, devices = initialize_cluster(parallel_config) + distributed_init_method, devices = initialize_cluster( + parallel_config, server_args.server_use_ray) # Create the LLM server. - server = cls(server_args.use_ray, *server_configs, + server = cls(server_args.worker_use_ray, + server_args.server_use_ray, + *server_configs, distributed_init_method, devices, log_stats=not server_args.disable_log_stats) return server diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py index 0032a768c14a9..54ab622359b78 100644 --- a/cacheflow/server/llm_server.py +++ b/cacheflow/server/llm_server.py @@ -1,11 +1,6 @@ import time from typing import Any, List, Optional -try: - import ray -except ImportError: - ray = None - from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from cacheflow.core.scheduler import Scheduler @@ -13,7 +8,7 @@ from cacheflow.logger import init_logger from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams from cacheflow.server.arg_utils import ServerArgs -from cacheflow.server.ray_utils import initialize_cluster +from cacheflow.server.ray_utils import ray, initialize_cluster from cacheflow.server.tokenizer_utils import (get_tokenizer, detokenize_incrementally) from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus @@ -62,7 +57,7 @@ class LLMServer: assert len(stage_devices) == 1, "Only support one stage for now." for rank, node_resource, _ in stage_devices[0]: worker_cls = Worker - if self.parallel_config.use_ray: + if self.parallel_config.worker_use_ray: worker_cls = ray.remote( num_cpus=0, num_gpus=1, @@ -152,6 +147,9 @@ class LLMServer: # Add the sequence group to the scheduler. 
self.scheduler.add_seq_group(seq_group) + def abort_request(self, request_id: str) -> None: + self.scheduler.abort_seq_group(request_id) + def get_num_unfinished_requests(self) -> int: return self.scheduler.get_num_unfinished_seq_groups() @@ -243,13 +241,13 @@ class LLMServer: all_outputs = [] for worker in self.workers: executor = getattr(worker, method) - if self.parallel_config.use_ray: + if self.parallel_config.worker_use_ray: executor = executor.remote output = executor(*args, **kwargs) all_outputs.append(output) - if self.parallel_config.use_ray: + if self.parallel_config.worker_use_ray: all_outputs = ray.get(all_outputs) if get_all_outputs: diff --git a/cacheflow/server/ray_utils.py b/cacheflow/server/ray_utils.py index 4577fc8dc70ac..4d533bddee0b0 100644 --- a/cacheflow/server/ray_utils.py +++ b/cacheflow/server/ray_utils.py @@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), devi def initialize_cluster( parallel_config: ParallelConfig, + server_use_ray: bool = False, address: Optional[str] = None, ) -> Tuple[str, List[List[DeviceID]]]: - if not parallel_config.use_ray: + if parallel_config.worker_use_ray or server_use_ray: + if ray is None: + raise ImportError( + "Ray is not installed. Please install Ray to use distributed " + "serving.") + # Connect to a ray cluster. + ray.init(address=address) + + if not parallel_config.worker_use_ray: # Initialize cluster locally. port = random.randint(10000, 20000) # We need to setup the distributed init method to make sure @@ -24,13 +33,6 @@ def initialize_cluster( all_stage_devices = [[(0, None, 0)]] return distributed_init_method, all_stage_devices - if ray is None: - raise ImportError( - "Ray is not installed. Please install Ray to use distributed " - "serving.") - # Connect to a ray cluster. - ray.init(address=address) - # Assume we have a uniform cluster that each node has the same number of # GPUs for now. valid_node_resources = []
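
The abort-on-disconnect path added above can be exercised end to end with a plain HTTP client: open a streaming request against the simple frontend, read one delimited chunk, then drop the connection so the StreamingResponse's background task calls server.abort(request_id), which flows through LLMServer.abort_request into Scheduler.abort_seq_group. The sketch below is illustrative and not part of the patch; it assumes the simple frontend is running at its default address (http://localhost:8001/generate) and reuses the payload fields from the benchmark script above.

import json

import requests

API_URL = "http://localhost:8001/generate"  # assumed default of the simple frontend

payload = {
    "prompt": "Tell me a very long story",
    "max_tokens": 2048,      # long enough that we disconnect before it finishes
    "temperature": 0.0,
    "ignore_eos": True,
}

# Open a streaming request and read only the first "\0"-delimited chunk.
response = requests.post(API_URL, json=payload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if not chunk:
        continue
    partial = json.loads(chunk.decode("utf-8"))["text"][0]
    print("First partial output:", partial)
    break

# Dropping the connection is what triggers the new cleanup: the background
# task awaits server.abort(request_id), which frees the request's sequence
# group and its KV-cache blocks instead of generating the remaining tokens.
response.close()

For sustained load rather than a single aborted request, the new benchmarks/benchmark_async_llm_server.py script drives the same endpoint; its --api-url flag defaults to this address.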