Fix various issues of async servers (#135)
commit 1a956e136b (parent 8274ca23ac)
benchmarks/benchmark_async_llm_server.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import argparse
+import json
+import threading
+import time
+
+import requests
+
+
+def main(args: argparse.Namespace):
+    prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
+               for i in range(args.n_threads)]
+
+    headers = {"User-Agent": "CacheFlow Benchmark Client"}
+    ploads = [{
+        "prompt": p,
+        "max_tokens": args.max_tokens,
+        "temperature": 0.0,
+        "ignore_eos": True,
+    } for p in prompts]
+
+    def send_request(results, i):
+        response = requests.post(args.api_url, headers=headers,
+                                 json=ploads[i], stream=True)
+        results[i] = response
+
+    # use args.n_threads to prompt the backend
+    tik = time.time()
+    threads = []
+    results = [None] * args.n_threads
+    for i in range(args.n_threads):
+        t = threading.Thread(target=send_request, args=(results, i))
+        t.start()
+        threads.append(t)
+
+    for t in threads:
+        t.join()
+
+    print(f"Time (POST): {time.time() - tik} s")
+    n_words = 0
+
+    for i, response in enumerate(results):
+        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
+        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
+        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))
+
+    time_seconds = time.time() - tik
+    print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, "
+          f"throughput: {n_words / time_seconds} words/s.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--max-tokens", type=int, default=128)
+    parser.add_argument("--n-threads", type=int, default=128)
+    args = parser.parse_args()
+
+    main(args)
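As a usage sketch (not part of this commit): the benchmark assumes the simple async server from this PR is running on port 8001 and streams null-byte-delimited JSON chunks whose "text" field holds the generation so far. A minimal single-request client along those lines, with a placeholder prompt and URL:

    import json

    import requests

    # One streaming request against the simple server (assumed to be listening
    # on localhost:8001); prints the final generated text.
    payload = {"prompt": "Tell me a story", "max_tokens": 32, "temperature": 0.0}
    response = requests.post("http://localhost:8001/generate",
                             json=payload, stream=True)
    # Chunks are JSON blobs separated by b"\0"; the last non-empty chunk carries
    # the most complete output.
    chunks = [c for c in response.iter_lines(chunk_size=8192, delimiter=b"\0") if c]
    print(json.loads(chunks[-1].decode("utf-8"))["text"][0])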
@@ -116,15 +116,15 @@ class ParallelConfig:
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray

         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()

     def _verify_args(self) -> None:
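A quick illustration of the renamed flag (sketch only; the values are arbitrary and assume the cacheflow package from this commit is importable): with more than one GPU in play, the constructor switches the flag on by itself.

    from cacheflow.config import ParallelConfig

    config = ParallelConfig(pipeline_parallel_size=1,
                            tensor_parallel_size=2,
                            worker_use_ray=False)
    # world_size = 1 * 2 > 1, so Ray-backed workers are forced on.
    assert config.worker_use_ray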
@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -231,6 +231,9 @@ class BlockSpaceManager:
                 self.cpu_allocator.free(block)

     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
@@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,

 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5


 class PreemptionMode(enum.Enum):
@@ -84,6 +84,18 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)

+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped

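One detail worth noting in abort_seq_group above: the sequence group is removed from a state queue while that queue is being iterated, which is safe only because the method returns immediately afterwards. A standalone toy sketch of the same scan-and-remove pattern (plain dicts and lists, not the project's classes):

    def abort(request_id: str, *queues: list) -> None:
        # Scan each queue; remove the matching group and stop right away so the
        # mutated list is never iterated further.
        for queue in queues:
            for group in queue:
                if group["request_id"] == request_id:
                    queue.remove(group)
                    return

    waiting, running, swapped = [{"request_id": "a"}], [{"request_id": "b"}], []
    abort("b", waiting, running, swapped)
    assert running == [] and waiting == [{"request_id": "a"}]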
@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )

+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],


 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):

     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
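The streaming branch above leans on Starlette's background hook: a BackgroundTasks object attached to a StreamingResponse runs after the response is over, including when the client drops the connection mid-stream, which is how abort_request gets invoked without the handler ever awaiting it. A self-contained toy app showing the same wiring (names such as cancel_work and /stream are illustrative, not from this PR):

    import asyncio

    from fastapi import BackgroundTasks, FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    async def cancel_work(job_id: str) -> None:
        # Stand-in for server.abort(request_id); runs once the stream is closed.
        print(f"cleaning up job {job_id}")

    @app.get("/stream")
    async def stream():
        job_id = "demo"

        async def body():
            for i in range(10):
                yield f"chunk {i}\n".encode()
                await asyncio.sleep(0.5)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(cancel_work, job_id)
        return StreamingResponse(body(), media_type="text/plain",
                                 background=background_tasks)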
@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -291,10 +306,11 @@ if __name__ == "__main__":

     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -2,15 +2,16 @@ import argparse
 import json
 from typing import AsyncGenerator

-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn

 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.utils import random_uuid

+TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()

@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)

     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
@@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")

-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()

     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]

     @staticmethod
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason


 class SequenceData:

     def __init__(
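A tiny check of the new status value (sketch; assumes the cacheflow package from this commit is importable). An aborted sequence now counts as finished, and the branch above reports it with the finish reason "abort":

    from cacheflow.sequence import SequenceStatus

    assert SequenceStatus.is_finished(SequenceStatus.FINISHED_ABORTED)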
@@ -137,6 +142,9 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob

+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ class SequenceGroup:
         raise ValueError(f'Sequence {seq_id} not found.')

     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
@@ -15,7 +15,7 @@ class ServerArgs:
     use_dummy_weights: bool = False
     dtype: str = "default"
     seed: int = 0
-    use_ray: bool = False
+    worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
@@ -32,7 +32,63 @@ class ServerArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        return _add_server_arguments(parser)
+        """Shared CLI arguments for CacheFlow servers."""
+        # Model arguments
+        parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                            help='name or path of the huggingface model to use')
+        parser.add_argument('--download-dir', type=str,
+                            default=ServerArgs.download_dir,
+                            help='directory to download and load the weights, '
+                                 'default to the default cache dir of '
+                                 'huggingface')
+        parser.add_argument('--use-np-weights', action='store_true',
+                            help='save a numpy copy of model weights for '
+                                 'faster loading. This can increase the disk '
+                                 'usage by up to 2x.')
+        parser.add_argument('--use-dummy-weights', action='store_true',
+                            help='use dummy values for model weights')
+        # TODO(woosuk): Support FP32.
+        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                            choices=['default', 'half', 'bfloat16'],
+                            help='data type for model weights and activations. '
+                                 'The "default" option will use FP16 precision '
+                                 'for FP32 and FP16 models, and BF16 precision '
+                                 'for BF16 models.')
+        # Parallel arguments
+        parser.add_argument('--worker-use-ray', action='store_true',
+                            help='use Ray for distributed serving, will be '
+                                 'automatically set when using more than 1 GPU')
+        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                            default=ServerArgs.pipeline_parallel_size,
+                            help='number of pipeline stages')
+        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                            default=ServerArgs.tensor_parallel_size,
+                            help='number of tensor parallel replicas')
+        # KV cache arguments
+        parser.add_argument('--block-size', type=int,
+                            default=ServerArgs.block_size,
+                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            help='token block size')
+        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                            help='random seed')
+        parser.add_argument('--swap-space', type=int,
+                            default=ServerArgs.swap_space,
+                            help='CPU swap space size (GiB) per GPU')
+        parser.add_argument('--gpu-memory-utilization', type=float,
+                            default=ServerArgs.gpu_memory_utilization,
+                            help='the percentage of GPU memory to be used for'
+                                 'the model executor')
+        parser.add_argument('--max-num-batched-tokens', type=int,
+                            default=ServerArgs.max_num_batched_tokens,
+                            help='maximum number of batched tokens per '
+                                 'iteration')
+        parser.add_argument('--max-num-seqs', type=int,
+                            default=ServerArgs.max_num_seqs,
+                            help='maximum number of sequences per iteration')
+        parser.add_argument('--disable-log-stats', action='store_true',
+                            help='disable logging statistics')
+        return parser

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
@@ -53,65 +109,22 @@ class ServerArgs:
                                    self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
-                                         self.use_ray)
+                                         self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs)
         return model_config, cache_config, parallel_config, scheduler_config


-def _add_server_arguments(
-    parser: argparse.ArgumentParser,
-)-> argparse.ArgumentParser:
-    """Shared CLI arguments for CacheFlow servers."""
-    # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m',
-                        help='name or path of the huggingface model to use')
-    parser.add_argument('--download-dir', type=str,
-                        default=ServerArgs.download_dir,
-                        help='directory to download and load the weights, '
-                             'default to the default cache dir of huggingface')
-    parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster '
-                             'loading. This can increase the disk usage by up '
-                             'to 2x.')
-    parser.add_argument('--use-dummy-weights', action='store_true',
-                        help='use dummy values for model weights')
-    # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                        choices=['default', 'half', 'bfloat16'],
-                        help=('data type for model weights and activations. '
-                              'The "default" option will use FP16 precision '
-                              'for FP32 and FP16 models, and BF16 precision '
-                              'for BF16 models.'))
-    # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be '
-                             'automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                        default=ServerArgs.pipeline_parallel_size,
-                        help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                        default=ServerArgs.tensor_parallel_size,
-                        help='number of tensor parallel replicas')
-    # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
-                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
-                        help='token block size')
-    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
-                        help='random seed')
-    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
-                        help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float,
-                        default=ServerArgs.gpu_memory_utilization,
-                        help='the percentage of GPU memory to be used for the '
-                             'model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int,
-                        default=ServerArgs.max_num_batched_tokens,
-                        help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int,
-                        default=ServerArgs.max_num_seqs,
-                        help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true',
-                        help='disable logging statistics')
-    return parser
+@dataclass
+class AsyncServerArgs(ServerArgs):
+    server_use_ray: bool = False
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        parser = ServerArgs.add_cli_args(parser)
+        parser.add_argument('--server-use-ray', action='store_true',
+                            help='use Ray to start the LLM server in a '
+                                 'separate process as the web server process.')
+        return parser
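Both entrypoints touched by this PR wire the new dataclass the same way; condensed into one sketch (the flag values come from whatever the user passes on the command line):

    import argparse

    from cacheflow.server.arg_utils import AsyncServerArgs
    from cacheflow.server.async_llm_server import AsyncLLMServer

    parser = argparse.ArgumentParser()
    # Adds the shared flags plus the new --server-use-ray switch.
    parser = AsyncServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = AsyncServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)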
@@ -2,37 +2,52 @@ import asyncio
 import time
 from typing import Dict, Optional

-import ray
+from cacheflow.logger import init_logger

 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.utils import random_uuid
+from cacheflow.server.ray_utils import ray, initialize_cluster
+
+logger = init_logger(__name__)

 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


 class AsyncLLMServer:

-    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
-        if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
+    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+                 *args, **kwargs) -> None:
+        self.worker_use_ray = worker_use_ray
+        self.server_use_ray = server_use_ray
+        if not self.server_use_ray:
+            server_class = LLMServer
+        elif self.worker_use_ray:
+            server_class = ray.remote(num_cpus=0)(LLMServer).remote
         else:
-            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
-        self.server = remote_server_class.remote(*args, **kwargs)
+            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+        self.server = server_class(*args, **kwargs)

         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
+        self.kicking_request_id: Optional[str] = None

-    async def server_step(self):
+    async def server_step(self, kicking_request_id: Optional[str] = None):
         self.is_server_running = True
-        request_outputs = await self.server.step.remote()
+        self.kicking_request_id = kicking_request_id
+        if self.server_use_ray:
+            request_outputs = await self.server.step.remote()
+        else:
+            # Yield to the event loop to allow other coroutines to run
+            # while is_server_running is True. This let the server to add new
+            # requests into the queue.
+            await asyncio.sleep(0)
+            request_outputs = self.server.step()
         self.is_server_running = False
+        self.kicking_request_id = None

         # Notify the waiting coroutines that there are new outputs ready.
         for request_output in request_outputs:
             request_id = request_output.request_id
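The await asyncio.sleep(0) in the non-Ray branch is a cooperative yield: self.server.step() is an ordinary blocking call, so the coroutine hands control to the event loop once before running it, giving other tasks (such as a handler adding a new request) a chance to run. A standalone illustration of that ordering, unrelated to the project's classes:

    import asyncio

    order = []

    async def busy_step():
        order.append("step scheduled")
        await asyncio.sleep(0)  # yield once before the blocking work
        order.append("step ran")

    async def enqueue_request():
        order.append("request added")

    async def main():
        # Both tasks start together; the yield lets the enqueue slip in first.
        await asyncio.gather(busy_step(), enqueue_request())
        print(order)  # ['step scheduled', 'request added', 'step ran']

    asyncio.run(main())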
@@ -40,20 +55,26 @@ class AsyncLLMServer:
             self.request_events[request_id].set()

     async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: Optional[str] = None) -> RequestOutput:
+                       request_id: str) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        if request_id is None:
-            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt!r}, "
+                    f"sampling params: {sampling_params}.")
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_request.remote(
-            request_id, prompt, sampling_params, arrival_time=arrival_time)
+        if self.server_use_ray:
+            await self.server.add_request.remote(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)
+        else:
+            self.server.add_request(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)

         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
@@ -61,7 +82,7 @@ class AsyncLLMServer:
         while True:
             # Kick the server if the server is not running.
             if not self.is_server_running:
-                await self.server_step()
+                await self.server_step(request_id)

             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.
@@ -80,6 +101,8 @@ class AsyncLLMServer:

             # Once finished, release the resources of the sequence group.
             if request_output.finished():
+                logger.info(f"Finished request {request_id}.")
+
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
@@ -89,15 +112,41 @@ class AsyncLLMServer:
                 await self.server_step()
             break

+    async def abort(self, request_id: str) -> None:
+        if request_id not in self.request_events:
+            # The request has already finished or been aborted.
+            return
+
+        logger.info(f"Aborted request {request_id}.")
+
+        if self.server_use_ray:
+            await self.server.abort_request.remote(request_id)
+        else:
+            self.server.abort_request(request_id)
+
+        if request_id in self.request_events:
+            del self.request_events[request_id]
+        if request_id in self.request_outputs:
+            del self.request_outputs[request_id]
+
+        # To prevent deadlock when a request is aborted while the server is
+        # running.
+        if self.kicking_request_id == request_id:
+            self.is_server_running = False
+            self.kicking_request_id = None
+
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, devices = initialize_cluster(
+            parallel_config, server_args.server_use_ray)
         # Create the LLM server.
-        server = cls(server_args.use_ray, *server_configs,
+        server = cls(server_args.worker_use_ray,
+                     server_args.server_use_ray,
+                     *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
         return server
@@ -1,11 +1,6 @@
 import time
 from typing import Any, List, Optional

-try:
-    import ray
-except ImportError:
-    ray = None
-
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.server.ray_utils import ray, initialize_cluster
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -62,7 +57,7 @@ class LLMServer:
         assert len(stage_devices) == 1, "Only support one stage for now."
         for rank, node_resource, _ in stage_devices[0]:
             worker_cls = Worker
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 worker_cls = ray.remote(
                     num_cpus=0,
                     num_gpus=1,
@@ -152,6 +147,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)

+    def abort_request(self, request_id: str) -> None:
+        self.scheduler.abort_seq_group(request_id)
+
     def get_num_unfinished_requests(self) -> int:
         return self.scheduler.get_num_unfinished_seq_groups()

@@ -243,13 +241,13 @@ class LLMServer:
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 executor = executor.remote

             output = executor(*args, **kwargs)
             all_outputs.append(output)

-        if self.parallel_config.use_ray:
+        if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)

         if get_all_outputs:
@@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), devi

 def initialize_cluster(
     parallel_config: ParallelConfig,
+    server_use_ray: bool = False,
     address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
-    if not parallel_config.use_ray:
+    if parallel_config.worker_use_ray or server_use_ray:
+        if ray is None:
+            raise ImportError(
+                "Ray is not installed. Please install Ray to use distributed "
+                "serving.")
+        # Connect to a ray cluster.
+        ray.init(address=address)
+
+    if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
         port = random.randint(10000, 20000)
         # We need to setup the distributed init method to make sure
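For completeness, a sketch of the call site added earlier in this diff (AsyncLLMServer.from_server_args), assuming Ray is installed and the cacheflow package from this commit is importable: even with a single non-Ray worker, server_use_ray=True makes initialize_cluster call ray.init() so the server itself can run as a Ray actor.

    from cacheflow.config import ParallelConfig
    from cacheflow.server.ray_utils import initialize_cluster

    parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                     tensor_parallel_size=1,
                                     worker_use_ray=False)
    distributed_init_method, all_stage_devices = initialize_cluster(
        parallel_config, server_use_ray=True)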
@@ -24,13 +33,6 @@ def initialize_cluster(
         all_stage_devices = [[(0, None, 0)]]
         return distributed_init_method, all_stage_devices

-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use distributed "
-            "serving.")
-    # Connect to a ray cluster.
-    ray.init(address=address)
-
     # Assume we have a uniform cluster that each node has the same number of
     # GPUs for now.
     valid_node_resources = []