Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Quality] Add CI for formatting (#343)
parent e41f06702c
commit 42e0c1df78
.github/workflows/pylint.yml (new file, 31 lines)
@@ -0,0 +1,31 @@
+name: pylint
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  pylint:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint==2.8.2
+    - name: Analysing the code with pylint
+      run: |
+        pylint vllm
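The same check can be reproduced locally before pushing. A minimal sketch (assuming pylint 2.8.2 is installed, as pinned in the workflow above; the helper name run_pylint is hypothetical and not part of this commit):

import subprocess
import sys


def run_pylint() -> int:
    """Run the same lint command as the CI step above."""
    return subprocess.run(["pylint", "vllm"]).returncode


if __name__ == "__main__":
    sys.exit(run_pylint())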
.github/workflows/yapf.yml (new file, 31 lines)
@@ -0,0 +1,31 @@
+name: yapf
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+jobs:
+  yapf:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install yapf==0.32.0
+        pip install toml==0.10.2
+    - name: Running yapf
+      run: |
+        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
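The yapf step only reports a diff; it never rewrites files in CI, and the step fails when any file would change. A minimal local sketch (assuming yapf 0.32.0 and toml are installed, as pinned above; the helper name run_yapf_check is hypothetical):

import subprocess
import sys


def run_yapf_check() -> int:
    """Report, but do not apply, the reformatting yapf would make."""
    cmd = [
        "yapf", "--diff", "--recursive", "vllm",
        "--exclude", "vllm/model_executor/parallel_utils/**",
    ]
    # A non-zero exit code means at least one file would be reformatted.
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(run_yapf_check())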
@@ -2,6 +2,7 @@ import asyncio
 import time
 from typing import Dict, List, Optional
 
+from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray

@@ -206,6 +207,13 @@ class AsyncLLMEngine:
         self.is_engine_running = False
         self.kicking_request_id = None
 
+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_model_config.remote()
+        else:
+            return self.engine.get_model_config()
+
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
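The new accessor lets callers read the model configuration without touching engine.engine directly; the .remote() branch above shows that self.engine can be a Ray actor handle when engine_use_ray is set. A minimal usage sketch (the model name and the direct AsyncEngineArgs construction are assumptions for illustration):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    # "facebook/opt-125m" is only an example model name.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    model_config = await engine.get_model_config()
    print(model_config.hf_config)


if __name__ == "__main__":
    asyncio.run(main())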
@@ -210,6 +210,10 @@ class LLMEngine:
         """
         self.scheduler.abort_seq_group(request_id)
 
+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
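On the synchronous LLMEngine the new method is a plain attribute read, which the async wrapper above delegates to. A hedged sketch of calling it through the offline LLM entry point (the vllm.LLM import, the llm_engine attribute name, and the model name are assumptions, not shown in this diff):

from vllm import LLM

# Example model name only.
llm = LLM(model="facebook/opt-125m")
# Assumption: the LLM wrapper stores its engine as llm_engine.
model_config = llm.llm_engine.get_model_config()
print(model_config.hf_config)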
@@ -2,16 +2,19 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
 
 import argparse
+import asyncio
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional, Union, Any
+from typing import AsyncGenerator, Dict, List, Optional
 
 import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastchat.conversation import (Conversation, SeparatorStyle,
+                                   get_conv_template)
 import uvicorn
 
 from vllm.engine.arg_utils import AsyncEngineArgs

@@ -19,11 +22,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse,
-    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
-    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
-    ModelCard, ModelList, ModelPermission, UsageInfo)
-from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams

@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, engine):
-    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
-        context_len = engine.engine.model_config.hf_config.max_sequence_length
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
-    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
-        context_len = engine.engine.model_config.hf_config.max_position_embeddings
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
+async def check_length(request, prompt, model_config):
+    if hasattr(model_config.hf_config, "max_sequence_length"):
+        context_len = model_config.hf_config.max_sequence_length
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
+    elif hasattr(model_config.hf_config, "max_position_embeddings"):
+        context_len = model_config.hf_config.max_position_embeddings
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
     else:
         context_len = 2048
 
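check_length now takes the ModelConfig itself rather than reaching through engine.engine, and it resolves the usable context length by probing the Hugging Face config. The lookup order can be summarized as below (a standalone sketch; the helper name context_length_from is hypothetical, and the duplicated seq_length branch from the original is collapsed):

def context_length_from(hf_config, default: int = 2048) -> int:
    """Pick the maximum context length advertised by a Hugging Face config."""
    for attr in ("max_sequence_length", "seq_length",
                 "max_position_embeddings"):
        if hasattr(hf_config, attr):
            return getattr(hf_config, attr)
    return default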
@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine)
+    error_check_ret = await check_length(request, prompt, engine_model_config)
     if error_check_ret is not None:
         return error_check_ret
 

@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)
 
     async def abort_request() -> None:
         await engine.abort(request_id)
 
-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
             delta=DeltaMessage(content=text),

@@ -238,10 +241,11 @@ async def create_chat_completion(raw_request: Request):
                 delta=DeltaMessage(role="assistant"),
                 finish_reason=None,
             )
-            chunk = ChatCompletionStreamResponse(
-                id=request_id, choices=[choice_data], model=model_name
-            )
-            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.json(exclude_unset=True, ensure_ascii=False)
+            yield f"data: {data}\n\n"
 
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n

@@ -295,8 +299,8 @@ async def create_chat_completion(raw_request: Request):
         choices.append(choice_data)
 
     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,

@@ -314,9 +318,11 @@ async def create_chat_completion(raw_request: Request):
         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
         response_json = response.json(ensure_ascii=False)
+
         async def fake_stream_generator() -> AsyncGenerator[str, None]:
             yield f"data: {response_json}\n\n"
             yield "data: [DONE]\n\n"
+
         return StreamingResponse(fake_stream_generator(),
                                  media_type="text/event-stream")
 

@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "please provide at least one prompt")
         if len(request.prompt) > 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST,
-                                         "multiple prompts in a batch is not "
-                                         "currently supported")
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "multiple prompts in a batch is not currently supported")
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
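The streaming hunks above are formatting-only, but they document the wire format: each event is a "data: <json>" line and the stream ends with "data: [DONE]". A hedged consumer sketch (the URL, port, model name, and use of the requests package are assumptions not shown in this diff):

import json

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed address and route
    json={
        "model": "facebook/opt-125m",  # example model name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for raw in resp.iter_lines():
    if not raw.startswith(b"data: "):
        continue
    payload = raw[len(b"data: "):].decode("utf-8")
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    delta = chunk["choices"][0].get("delta", {})
    print(delta.get("content", ""), end="", flush=True)
print()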
@@ -571,6 +577,7 @@ if __name__ == "__main__":
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine_model_config = asyncio.run(engine.get_model_config())
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
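The model config is resolved once at startup and cached as engine_model_config; asyncio.run() is safe at this point because uvicorn has not started its event loop yet, and the chat handler above passes the cached value to check_length instead of reaching through engine.engine. A minimal sketch of the same startup pattern with a hypothetical stand-in engine:

import asyncio


class FakeEngine:
    """Hypothetical stand-in for AsyncLLMEngine, for illustration only."""

    async def get_model_config(self) -> dict:
        return {"max_position_embeddings": 2048}


# At process startup there is no running event loop, so a one-off
# asyncio.run() call is fine; inside a running loop one would await
# engine.get_model_config() instead.
engine = FakeEngine()
engine_model_config = asyncio.run(engine.get_model_config())
print(engine_model_config)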
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The CacheFlow team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #