[Quality] Add CI for formatting (#343)

Zhuohan Li 2023-07-03 14:50:56 -07:00 committed by GitHub
parent e41f06702c
commit 42e0c1df78
6 changed files with 113 additions and 31 deletions

.github/workflows/pylint.yml (new file)

@@ -0,0 +1,31 @@
name: pylint

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pylint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint==2.8.2
    - name: Analysing the code with pylint
      run: |
        pylint vllm

.github/workflows/yapf.yml (new file)

@@ -0,0 +1,31 @@
name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install yapf==0.32.0
        pip install toml==0.10.2
    - name: Running yapf
      run: |
        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
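
The two workflows above can be reproduced locally before pushing. The helper below is a minimal sketch and not part of this commit; it assumes the pinned tools (pylint 2.8.2, yapf 0.32.0, toml 0.10.2) are already installed and that it is run from the repository root.

# local_lint.py -- hypothetical convenience script, not part of this commit.
# Runs the same two checks as the pylint and yapf workflows above.
import subprocess
import sys

CHECKS = [
    # Mirrors .github/workflows/pylint.yml (pylint==2.8.2 assumed installed).
    ["pylint", "vllm"],
    # Mirrors .github/workflows/yapf.yml (yapf==0.32.0, toml==0.10.2 assumed installed).
    ["yapf", "--diff", "--recursive", "vllm",
     "--exclude", "vllm/model_executor/parallel_utils/**"],
]

if __name__ == "__main__":
    failed = False
    for cmd in CHECKS:
        print("Running:", " ".join(cmd))
        if subprocess.run(cmd).returncode != 0:
            failed = True
    sys.exit(1 if failed else 0)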

vllm/engine/async_llm_engine.py

@@ -2,6 +2,7 @@ import asyncio
 import time
 from typing import Dict, List, Optional

+from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray
@@ -206,6 +207,13 @@ class AsyncLLMEngine:
         self.is_engine_running = False
         self.kicking_request_id = None

+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_model_config.remote()
+        else:
+            return self.engine.get_model_config()
+
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
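
The new accessor works both with and without Ray, so callers no longer have to reach into `engine.engine.model_config`. A minimal usage sketch (the model name is an arbitrary example and the setup around it is assumed, not part of this diff):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Any valid engine configuration works here; the model is just an example.
engine_args = AsyncEngineArgs(model="facebook/opt-125m")
engine = AsyncLLMEngine.from_engine_args(engine_args)

# Outside a running event loop the coroutine can be driven with asyncio.run,
# exactly as the OpenAI API server does later in this commit.
model_config = asyncio.run(engine.get_model_config())
print(model_config.hf_config)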

vllm/engine/llm_engine.py

@@ -210,6 +210,10 @@ class LLMEngine:
         """
         self.scheduler.abort_seq_group(request_id)

+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
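
On the synchronous engine the accessor is a plain method. The helper below is a hypothetical sketch that mirrors the check_length logic later in this commit; it shows the kind of lookup the accessor enables without touching the engine's internals.

from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine


def max_context_length(engine: LLMEngine, default: int = 2048) -> int:
    """Best-effort context length, mirroring check_length in the API server."""
    model_config: ModelConfig = engine.get_model_config()
    hf_config = model_config.hf_config
    for attr in ("max_sequence_length", "seq_length", "max_position_embeddings"):
        if hasattr(hf_config, attr):
            return getattr(hf_config, attr)
    return default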

vllm/entrypoints/openai/api_server.py

@@ -2,16 +2,19 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

 import argparse
+import asyncio
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional, Union, Any
+from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastchat.conversation import (Conversation, SeparatorStyle,
+                                   get_conv_template)
 import uvicorn

 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -19,11 +22,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse,
-    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
-    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
-    ModelCard, ModelList, ModelPermission, UsageInfo)
-from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
     return prompt


-async def check_length(request, prompt, engine):
-    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
-        context_len = engine.engine.model_config.hf_config.max_sequence_length
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
-    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
-        context_len = engine.engine.model_config.hf_config.max_position_embeddings
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
+async def check_length(request, prompt, model_config):
+    if hasattr(model_config.hf_config, "max_sequence_length"):
+        context_len = model_config.hf_config.max_sequence_length
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
+    elif hasattr(model_config.hf_config, "max_position_embeddings"):
+        context_len = model_config.hf_config.max_position_embeddings
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
     else:
         context_len = 2048
@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")

     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine)
+    error_check_ret = await check_length(request, prompt, engine_model_config)
     if error_check_ret is not None:
         return error_check_ret

@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)

     async def abort_request() -> None:
         await engine.abort(request_id)

-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
             delta=DeltaMessage(content=text),
@@ -238,10 +241,11 @@ async def create_chat_completion(raw_request: Request):
                 delta=DeltaMessage(role="assistant"),
                 finish_reason=None,
             )
-            chunk = ChatCompletionStreamResponse(
-                id=request_id, choices=[choice_data], model=model_name
-            )
-            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.json(exclude_unset=True, ensure_ascii=False)
+            yield f"data: {data}\n\n"

         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
@@ -295,8 +299,8 @@ async def create_chat_completion(raw_request: Request):
         choices.append(choice_data)

     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -314,9 +318,11 @@ async def create_chat_completion(raw_request: Request):

         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
         response_json = response.json(ensure_ascii=False)
+
         async def fake_stream_generator() -> AsyncGenerator[str, None]:
             yield f"data: {response_json}\n\n"
             yield "data: [DONE]\n\n"
+
         return StreamingResponse(fake_stream_generator(),
                                  media_type="text/event-stream")
@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
         if len(request.prompt) > 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST,
-                                         "multiple prompts in a batch is not "
-                                         "currently supported")
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "multiple prompts in a batch is not currently supported")
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
@@ -571,6 +577,7 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine_model_config = asyncio.run(engine.get_model_config())

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
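
After the reformat, each streamed event is still a single "data: ..." line followed by a final "data: [DONE]". The client below is a minimal sketch for consuming that stream; the URL, model name, and payload fields are illustrative assumptions and are not defined in this diff.

import json

import requests

# Assumed local server address and model name for illustration only.
URL = "http://localhost:8000/v1/chat/completions"
payload = {
    "model": "facebook/opt-125m",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}

with requests.post(URL, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each chunk carries a DeltaMessage under choices[0]["delta"]; the first
        # chunk only sets the role, so "content" may be absent.
        print(chunk["choices"][0]["delta"].get("content", ""), end="")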

vllm/model_executor/models/bloom.py

@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The CacheFlow team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #