Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 00:15:51 +08:00)
[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)
This commit is contained in:
parent 34128a697e
commit 0ae11f78ab

.github/workflows/mypy.yaml (vendored): 29 changes
@@ -32,19 +32,20 @@ jobs:
         pip install types-setuptools
     - name: Mypy
       run: |
-        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/attention --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
         mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
-
-        mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
-        mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
-        # TODO(sang): Follow up
-        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/distributed --config-file pyproject.toml
+        mypy vllm/entrypoints --config-file pyproject.toml
+        mypy vllm/executor --config-file pyproject.toml
+        mypy vllm/usage --config-file pyproject.toml
+        mypy vllm/*.py --config-file pyproject.toml
+        mypy vllm/transformers_utils --config-file pyproject.toml
+        mypy vllm/engine --config-file pyproject.toml
+        mypy vllm/worker --config-file pyproject.toml
+        mypy vllm/spec_decode --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
+        mypy vllm/model_executor/*.py --config-file pyproject.toml
+        # TODO(sang): Fix nested dir
+        # mypy vllm/lora/*.py --config-file pyproject.toml
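Note: the switch from `mypy vllm/<pkg>/*.py --follow-imports=skip` to `mypy vllm/<pkg>` points mypy at whole packages, so nested subdirectories are checked and imports are followed normally instead of being skipped. A minimal sketch of the difference using mypy's public Python API (the `demo` package path is hypothetical, not part of the vLLM tree):

from mypy import api

# Old style: check one file and skip following its imports (shallow but lenient).
out, err, status = api.run(
    ["demo/module.py", "--follow-imports=skip", "--config-file", "pyproject.toml"])

# New style: point mypy at the package directory so nested sub-packages are
# type-checked as well, with imports followed normally.
out, err, status = api.run(["demo", "--config-file", "pyproject.toml"])
print(status)  # 0 when no type errors are reported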
format.sh: 26 changes
@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'
 
 # Run mypy
 echo 'vLLM mypy:'
-mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/attention --config-file pyproject.toml
 mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
-
-# TODO(sang): Follow up
-mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed --config-file pyproject.toml
+mypy vllm/entrypoints --config-file pyproject.toml
+mypy vllm/executor --config-file pyproject.toml
+mypy vllm/usage --config-file pyproject.toml
+mypy vllm/*.py --config-file pyproject.toml
+mypy vllm/transformers_utils --config-file pyproject.toml
+mypy vllm/engine --config-file pyproject.toml
+mypy vllm/worker --config-file pyproject.toml
+mypy vllm/spec_decode --config-file pyproject.toml
+mypy vllm/model_executor/*.py --config-file pyproject.toml
+# mypy vllm/lora/*.py --config-file pyproject.toml
 
 
 CODESPELL_EXCLUDES=(
@@ -46,15 +46,17 @@ ignore = [
 python_version = "3.8"
 
 ignore_missing_imports = true
 check_untyped_defs = true
+follow_imports = "skip"
 
 files = "vllm"
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 exclude = [
     "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
     # Ignore triton kernels in ops.
     'vllm/attention/ops/.*\.py$'
 ]
 
 
 [tool.codespell]
 ignore-words-list = "dout, te, indicies"
 skip = "./tests/prompts,./benchmarks/sonnet.txt"
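Note: `follow_imports = "skip"` now lives once in the mypy config rather than being repeated on every command line, and `check_untyped_defs = true` stays on so the bodies of unannotated functions are still checked. A hedged illustration of what that flag changes (a standalone snippet, not vLLM code):

def unannotated_helper():
    # Only checked when check_untyped_defs = true; by default mypy skips
    # the bodies of functions that carry no annotations at all.
    return "a" + 1  # error: Unsupported operand types for + ("str" and "int")


def annotated_helper() -> int:
    # Functions with annotations are always checked, independent of the flag.
    return 1 + 1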
@@ -116,7 +116,7 @@ class AttentionImpl(ABC):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+        attn_metadata: AttentionMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         raise NotImplementedError
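Note: the abstract interface now annotates the parameter with the bare `AttentionMetadata` rather than the subscripted generic, so each backend can accept its own per-stage metadata without mypy flagging incompatible overrides; mypy treats a bare generic in an annotation as if it were parameterized with `Any`. A minimal sketch of the idea (the `Meta`/`Impl` names are illustrative, not vLLM's):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class Meta(Generic[T]):
    # Stand-in for a generic metadata class carrying per-stage details.
    per_stage: T


class Impl(ABC):
    @abstractmethod
    def forward(self, meta: Meta) -> None:
        # "Meta" with no parameter is implicitly Meta[Any], so concrete
        # backends using different per-stage payloads still type-check.
        ...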
@@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
+            assert prefill_meta.prompt_lens is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
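Note: the added `assert ... is not None` lines in these attention backends exist to narrow `Optional` attributes for mypy before they are used; at runtime they are cheap sanity checks. A small illustration of the pattern (the names here are hypothetical):

from typing import List, Optional


class PrefillMeta:
    def __init__(self, prompt_lens: Optional[List[int]]) -> None:
        # Optional because the field is only populated for prompt runs.
        self.prompt_lens: Optional[List[int]] = prompt_lens


def longest_prompt(meta: PrefillMeta) -> int:
    # Without the assert, mypy reports an incompatible "Optional[List[int]]"
    # argument; the assert narrows the attribute to List[int] in this scope.
    assert meta.prompt_lens is not None
    return max(meta.prompt_lens)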
@@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: TorchSDPAMetadata,
+        attn_metadata: TorchSDPAMetadata,  # type: ignore
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
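Note: the `# type: ignore` is needed because the override narrows the parameter from the base class's `AttentionMetadata` to the backend-specific `TorchSDPAMetadata`, which mypy reports as a Liskov substitution violation. A hedged sketch of the error being silenced (illustrative class names only):

class BaseMeta:
    pass


class SDPAMeta(BaseMeta):
    is_prompt: bool = True


class BaseImpl:
    def forward(self, meta: BaseMeta) -> None:
        ...


class SDPAImpl(BaseImpl):
    # mypy: Argument 1 of "forward" is incompatible with supertype "BaseImpl"
    # (violates the Liskov substitution principle). The narrowing is intentional
    # here, so it is silenced with a targeted ignore.
    def forward(self, meta: SDPAMeta) -> None:  # type: ignore[override]
        ...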
@@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                                                 kv_scale)
 
         if attn_metadata.is_prompt:
+            assert attn_metadata.prompt_lens is not None
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
@@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl):
         value: shape = [num_prefill_tokens, num_kv_heads, head_size]
         attn_metadata: Metadata for attention.
         """
+        assert attn_metadata.prompt_lens is not None
         original_query = query
         if self.num_kv_heads != self.num_heads:
             # GQA/MQA requires the shape [B, M, G, H, K].
@@ -104,6 +104,7 @@ class BlockTable:
             token_ids (List[int]): The sequence of token IDs to be appended.
         """
         assert self._is_allocated
+        assert self._blocks is not None
 
         self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                     num_lookahead_slots)
@@ -99,7 +99,7 @@ class CopyOnWriteTracker:
         refcounter: RefCounter,
         allocator: BlockAllocator,
     ):
-        self._copy_on_writes = defaultdict(list)
+        self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
         self._refcounter = refcounter
         self._allocator = allocator
 
@@ -138,6 +138,8 @@ class CopyOnWriteTracker:
                 prev_block=block.prev_block).block_id
 
             # Track src/dst copy.
+            assert src_block_id is not None
+            assert block_id is not None
             self._copy_on_writes[src_block_id].append(block_id)
 
         return block_id
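Note: several of these fixes annotate containers at the point where they are created (`defaultdict(list)`, empty lists, empty dict comprehensions), because mypy cannot infer element types from an empty container on its own. A small sketch of the pattern (the `BlockId` alias is illustrative):

from collections import defaultdict
from typing import DefaultDict, Dict, List

BlockId = int  # illustrative alias for an integer block id

# Without the annotation mypy asks for one (or falls back to Any values);
# with it, every append below is fully checked.
copy_on_writes: DefaultDict[BlockId, List[BlockId]] = defaultdict(list)
copy_on_writes[3].append(7)

# The same idea applies to a plain empty list or a dict comprehension target.
all_block_ids: List[BlockId] = []
parent_child: Dict[int, List[BlockId]] = {seq_id: [] for seq_id in (1, 2, 3)}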
@@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
         recurse(block.prev_block, lst)
         lst.append(block)
 
-    all_blocks = []
+    all_blocks: List[Block] = []
     recurse(last_block, all_blocks)
     return all_blocks
@@ -52,8 +52,7 @@ class Block(ABC):
 class BlockAllocator(ABC):
 
     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block],
-                         device: Device) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
         pass
 
     @abstractmethod
@@ -98,8 +97,7 @@ class BlockAllocator(ABC):
 class DeviceAwareBlockAllocator(BlockAllocator):
 
     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block],
-                         device: Device) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
         pass
 
     @abstractmethod
@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -18,7 +18,7 @@ except ImportError:
 
 logger = init_logger(__name__)
 
-_CA_HANDLE = None
+_CA_HANDLE: Optional["CustomAllreduce"] = None
 _IS_CAPTURING = False
 _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
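Note: the handle is a module-level global that starts as `None` and is filled in later, so it is annotated as `Optional["CustomAllreduce"]`, with the class name in quotes because the class is defined further down in the same module. A minimal sketch of the pattern (illustrative names, not the vLLM module):

from typing import Optional

# Forward reference in quotes: Handle is defined later in this module.
_HANDLE: Optional["Handle"] = None


class Handle:
    def all_reduce(self, x: float) -> float:
        return x


def get_handle() -> Optional["Handle"]:
    # Callers must handle the None case (or assert), which mypy now enforces.
    return _HANDLE


def init_handle() -> None:
    global _HANDLE
    _HANDLE = Handle()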
@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
             "Cannot test GPU P2P because not all GPUs are visible to the "
             "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
             " is set.")
-        return False
+        return
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
     if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
     ca_handle = get_handle()
     # when custom allreduce is disabled, this will be None
     if ca_handle is None:
-        return
+        return None
     if is_capturing():
         if torch.cuda.is_current_stream_capturing():
             if ca_handle.should_custom_ar(input):
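Note: the two `return` fixes are mirror images of each other. A function annotated `-> None` may not return a value, so `return False` becomes a bare `return`; the function annotated `-> Optional[torch.Tensor]` is made to return `None` explicitly, which mypy accepts either way but which documents the "disabled" path. A short illustration with hypothetical functions:

from typing import Optional


def init_feature() -> None:
    if not feature_supported():
        # "return False" here would trigger: No return value expected [return-value]
        return
    register_feature()


def maybe_compute(x: int) -> Optional[int]:
    if x < 0:
        # Returning None explicitly documents this path; a bare return also type-checks.
        return None
    return x * 2


def feature_supported() -> bool:
    return False


def register_feature() -> None:
    pass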
@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
         if ca_handle.should_custom_ar(input):
             return ca_handle.all_reduce_unreg(input)
 
+    return None
+
 
 @contextmanager
 def _nvml():
@@ -224,14 +226,14 @@ class CustomAllreduce:
         return self._gather_ipc_meta(shard_data)
 
     def _gather_ipc_meta(self, shard_data):
-        all_data = [None] * self.world_size
+        all_data: List[Optional[Any]] = [None] * self.world_size
         dist.all_gather_object(all_data, shard_data)
 
         handles = []
         offsets = []
         for i in range(len(all_data)):
-            handles.append(all_data[i][0])
-            offsets.append(all_data[i][1])
+            handles.append(all_data[i][0])  # type: ignore
+            offsets.append(all_data[i][1])  # type: ignore
         return handles, offsets
 
     def register_buffer(self, inp: torch.Tensor):
@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
     ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
 ]
 
+ncclDataType_t = ctypes.c_int
+
 # enums
-class ncclDataType_t(ctypes.c_int):
+class ncclDataTypeEnum:
     ncclInt8 = 0
     ncclChar = 0
     ncclUint8 = 1
@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
     ncclNumTypes = 10
 
     @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+    def from_torch(cls, dtype: torch.dtype) -> int:
         if dtype == torch.int8:
             return cls.ncclInt8
         if dtype == torch.uint8:
@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
         raise ValueError(f"Unsupported dtype: {dtype}")
 
 
-class ncclRedOp_t(ctypes.c_int):
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
     ncclSum = 0
     ncclProd = 1
     ncclMax = 2
@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
     ncclNumOps = 5
 
     @classmethod
-    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+    def from_torch(cls, op: ReduceOp) -> int:
         if op == ReduceOp.SUM:
             return cls.ncclSum
         if op == ReduceOp.PRODUCT:
@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
 _c_ncclAllReduce = nccl.ncclAllReduce
 _c_ncclAllReduce.restype = ctypes.c_int
 _c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
-    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
+    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
 ]
 
 # equivalent to c declaration:
@@ -251,8 +255,8 @@ class NCCLCommunicator:
         result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                   ctypes.c_void_p(tensor.data_ptr()),
                                   tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
+                                  ncclDataTypeEnum.from_torch(tensor.dtype),
+                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
                                   ctypes.c_void_p(stream.cuda_stream))
         assert result == 0
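Note: the ctypes wrapper now keeps `ncclDataType_t` and `ncclRedOp_t` as plain `ctypes.c_int` aliases for use in `argtypes`, while the named constants move into separate `ncclDataTypeEnum` / `ncclRedOpTypeEnum` helpers whose `from_torch` returns a plain `int`; the earlier pattern of subclassing `ctypes.c_int` mixed C argument types with Python-side constants, which was hard to type. A condensed sketch of the same split (not the full NCCL binding):

import ctypes

# The C-level argtype stays a bare c_int alias...
ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    # ...while the named values live in a plain class and are returned as int.
    ncclInt8 = 0
    ncclFloat32 = 7

    @classmethod
    def from_name(cls, name: str) -> int:
        # Illustrative lookup; the real code maps torch dtypes instead.
        return {"int8": cls.ncclInt8, "float32": cls.ncclFloat32}[name]


# A ctypes function prototype can then declare the argument as a c_int:
# some_c_fn.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t]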
@@ -30,6 +30,7 @@ def is_initialized() -> bool:
 def set_pynccl_stream(stream: torch.cuda.Stream):
     """Set the cuda stream for communication"""
     try:
+        assert comm is not None
         comm.stream = stream
         yield
     finally:
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
     """All-reduces the input tensor across the process group."""
     assert input_.is_cuda, f"{input_} should be a cuda tensor"
+    assert comm is not None
     comm.all_reduce(input_, op)
 
 
@@ -62,8 +64,9 @@ def destroy_process_group() -> None:
 
 def get_world_size() -> int:
     """Returns the world size."""
+    assert comm is not None
     return comm.world_size
 
 
-def get_nccl_backend():
+def get_nccl_backend() -> Optional["NCCLCommunicator"]:
     return comm
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List
 
 from transformers import PreTrainedTokenizer
 
@@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
 
 
 class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",
     ):
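Note: `seq_counter` was previously typed as `Iterable[int]`, but the output processors call `next(...)` on it directly, which only an iterator-like object supports, so the annotation is tightened to vLLM's `Counter` utility. Roughly, `Counter` is a tiny next-able id generator along these lines (a sketch, not the exact vllm.utils implementation):

class Counter:
    """Monotonically increasing id generator that supports next()."""

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        value = self.counter
        self.counter += 1
        return value

    def reset(self) -> None:
        self.counter = 0


seq_counter = Counter()
new_seq_id = next(seq_counter)  # 0, then 1, 2, ... on later calls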
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, List
+from typing import Callable, List
 
 from transformers import PreTrainedTokenizer
 
@@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
 
 logger = init_logger(__name__)
 
@@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         self,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
     ):
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
 
 logger = init_logger(__name__)
 
@@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
         scheduler: Scheduler,
-        seq_counter: Iterable[int],
+        seq_counter: Counter,
         stop_checker: StopChecker,
     ):
         self.scheduler_config = scheduler_config
@@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
-        parent_child_dict = {
+        parent_child_dict: Dict[int, List[SequenceOutput]] = {
             parent_seq.seq_id: []
             for parent_seq in parent_seqs
         }
@@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 continue
             # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
-                new_child_seq_id = next(self.seq_counter)
+                new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
                                       child_sample.logprobs)
@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+    output_by_sequence_group: List[List[SamplerOutput]] = [
+        [] for _ in range(num_seq_groups)
+    ]
     for step in sampler_outputs:
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)
@@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionResponse,
                                               CompletionRequest, ErrorResponse)
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
-openai_serving_chat: OpenAIServingChat = None
-openai_serving_completion: OpenAIServingCompletion = None
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
 logger = init_logger(__name__)
 
 
@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return StreamingResponse(content=generator,
                                  media_type="text/event-stream")
     else:
+        assert isinstance(generator, ChatCompletionResponse)
         return JSONResponse(content=generator.model_dump())
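Note: two related fixes here. First, `x: SomeType = None` is rejected by mypy because `None` is not a `SomeType`; declaring the global with a bare annotation (`openai_serving_chat: OpenAIServingChat`) says "this will be assigned during startup" without lying about its type. Second, `assert isinstance(...)` narrows a union-typed result before calling a type-specific method. A compact sketch with illustrative names:

from typing import Union


class ChatResponse:
    def model_dump(self) -> dict:
        return {"object": "chat.completion"}


class ErrorResponse:
    pass


# Declared now, assigned later; mypy treats the variable as ChatResponse
# instead of forcing an Optional[...] = None default.
serving_chat: ChatResponse


def handle(result: Union[ChatResponse, ErrorResponse]) -> dict:
    # isinstance() narrows the union so .model_dump() is known to exist.
    assert isinstance(result, ChatResponse)
    return result.model_dump()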
@@ -4,7 +4,8 @@ import time
 from typing import Dict, List, Literal, Optional, Union
 
 import torch
-from pydantic import BaseModel, Field, conint, model_validator
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Annotated
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
     allow_fine_tuning: bool = False
     organization: str = "*"
     group: Optional[str] = None
-    is_blocking: str = False
+    is_blocking: bool = False
 
 
 class ModelCard(BaseModel):
@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
 
 class ResponseFormat(BaseModel):
     # type must be "json_object" or "text"
-    type: str = Literal["text", "json_object"]
+    type: Literal["text", "json_object"]
 
 
 class ChatCompletionRequest(BaseModel):
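Note: the protocol models move from `conint(ge=1)` (which mypy cannot treat as a type) to `Annotated[int, Field(ge=1)]`, and from the ill-typed `type: str = Literal[...]` to a proper `Literal` field; `is_blocking` gets the `bool` annotation its default already implied. A hedged pydantic-style sketch of the corrected field declarations (a standalone model, not vLLM's):

from typing import Literal, Optional

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class DemoRequest(BaseModel):
    # Literal used as the annotation itself, not assigned as a default to a str field.
    response_type: Literal["text", "json_object"] = "text"

    # Constrained int expressed as Annotated metadata: validated (ge=1) at runtime,
    # and still a plain Optional[int] as far as mypy is concerned.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # A boolean default gets a boolean annotation.
    is_blocking: bool = False


DemoRequest(truncate_prompt_tokens=4)      # ok
# DemoRequest(truncate_prompt_tokens=0)    # would raise a validation error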
@@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
                 for token_id, bias in self.logit_bias.items():
                     # Clamp the bias between -100 and 100 per OpenAI API spec
                     bias = min(100, max(-100, bias))
@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
     max_tokens: Optional[int] = 16
-    n: Optional[int] = 1
+    n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
     min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
-    truncate_prompt_tokens: Optional[conint(ge=1)] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -289,6 +291,7 @@ class CompletionRequest(BaseModel):
             def logit_bias_logits_processor(
                     token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
+                assert self.logit_bias is not None
                 for token_id, bias in self.logit_bias.items():
                     # Clamp the bias between -100 and 100 per OpenAI API spec
                     bias = min(100, max(-100, bias))
@@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True
 
         # Send response for each token for each request.n (index)
+        assert request.n is not None
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
         finish_reason_sent = [False] * request.n
         try:
             async for res in result_generator:
-                res: RequestOutput
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
@@ -185,6 +185,7 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name: str,
         num_prompts: int,
     ) -> AsyncGenerator[str, None]:
+        assert request.n is not None
         previous_texts = [""] * request.n * num_prompts
         previous_num_tokens = [0] * request.n * num_prompts
         has_echoed = [False] * request.n * num_prompts
@@ -202,6 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     # TODO(simon): optimize the performance by avoiding full
                     # text O(n^2) sending.
 
+                    assert request.max_tokens is not None
                     if request.echo and request.max_tokens == 0:
                         # only return the prompt
                         delta_text = res.prompt
@@ -279,7 +281,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
     ) -> CompletionResponse:
-        choices = []
+        choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
         for final_res in final_res_batch:
@@ -289,6 +291,7 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_text = final_res.prompt
 
             for output in final_res.outputs:
+                assert request.max_tokens is not None
                 if request.echo and request.max_tokens == 0:
                     token_ids = prompt_token_ids
                     top_logprobs = prompt_logprobs
@@ -4,7 +4,9 @@ from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Tuple, Union
 
-from pydantic import conint
+from pydantic import Field
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing_extensions import Annotated
 
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -45,7 +47,8 @@ class OpenAIServing:
         ]
 
         self.max_model_len = 0
-        self.tokenizer = None
+        # Lazy initialized
+        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
         try:
             event_loop = asyncio.get_running_loop()
@@ -92,7 +95,7 @@ class OpenAIServing:
     def _create_logprobs(
             self,
             token_ids: List[int],
-            top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
+            top_logprobs: List[Optional[Dict[int, Logprob]]],
             num_output_top_logprobs: Optional[int] = None,
             initial_text_offset: int = 0,
     ) -> LogProbs:
@@ -108,6 +111,7 @@ class OpenAIServing:
                 token = self.tokenizer.decode(token_id)
                 logprobs.tokens.append(token)
                 logprobs.token_logprobs.append(None)
+                assert logprobs.top_logprobs is not None
                 logprobs.top_logprobs.append(None)
             else:
                 token_logprob = step_top_logprobs[token_id].logprob
@@ -116,6 +120,7 @@ class OpenAIServing:
                 logprobs.token_logprobs.append(token_logprob)
 
                 if num_output_top_logprobs:
+                    assert logprobs.top_logprobs is not None
                     logprobs.top_logprobs.append({
                         # Convert float("-inf") to the
                         # JSON-serializable float that OpenAI uses
@@ -155,9 +160,9 @@ class OpenAIServing:
 
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
-            return
+            return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
-            return
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
@@ -165,7 +170,7 @@ class OpenAIServing:
 
     def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
-            return
+            return None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
@@ -177,7 +182,7 @@ class OpenAIServing:
         request: Union[ChatCompletionRequest, CompletionRequest],
         prompt: Optional[str] = None,
        prompt_ids: Optional[List[int]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
@@ -33,7 +33,7 @@ class LoRALayerWeights:
     def optimize(self) -> "LoRALayerWeights":
         """Optimize the LoRA by merging the scaling into lora_b."""
         if self.scaling == 1:
-            return
+            return self
         self.lora_b *= self.scaling
         self.scaling = 1
         return self
@@ -29,8 +29,8 @@ def _multi_split_sample(
     sampled_tokens_size: Tuple[int, int],
     sampled_logprobs_size: Tuple[int, int],
     sample_indices: torch.Tensor,
-    logprobs: torch.Tensor,
     *,
+    logprobs: Optional[torch.Tensor] = None,
     modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
 ):
@@ -167,6 +167,7 @@ def sample(
         sampled_logprobs_size = (0, 0)
         logprobs = probs
 
+    assert logprobs is not None
     if _save_modified_probs:
         sampled_modified_probs_size = sampled_tokens_size
     else:
@@ -108,7 +108,8 @@ class RotaryEmbedding(nn.Module):
         query_pass = query[..., self.rotary_dim:]
         key_pass = key[..., self.rotary_dim:]
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
                 f"got {alibi_scaling_type}")
         if (alibi_scaling_factor is not None
                 and not isinstance(alibi_scaling_factor, float)
-                or alibi_scaling_factor <= 1.0):
+                or (alibi_scaling_factor is not None
+                    and alibi_scaling_factor <= 1.0)):
             raise ValueError(
                 f"`alibi_scaling`'s factor field must be a float > 1.0,"
                 f"got {alibi_scaling_factor}")
         if (alibi_dynamic_scaling is not None
                 and not isinstance(alibi_dynamic_scaling, int)
-                or alibi_dynamic_scaling <= 1):
+                or (alibi_dynamic_scaling is not None
+                    and alibi_dynamic_scaling <= 1)):
             raise ValueError(
                 f"`alibi_scaling`'s `train_seq_len` field must be an"
                 f"integer > 1, got {alibi_dynamic_scaling}")
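Note: the JAISConfig change is a correctness fix that mypy surfaced: because `and` binds tighter than `or`, the old condition could evaluate `alibi_scaling_factor <= 1.0` even when the value was `None` (a `TypeError` at runtime and an Optional-comparison error for mypy). Wrapping the second alternative in its own `is not None` guard avoids that. A stripped-down illustration:

from typing import Optional


def validate(factor: Optional[float]) -> None:
    # Old shape of the condition: (A and B) or C, where C runs even for None.
    #   if (factor is not None and not isinstance(factor, float)
    #           or factor <= 1.0):          # TypeError when factor is None
    # Fixed shape: C is guarded by its own None check.
    if (factor is not None and not isinstance(factor, float)
            or (factor is not None and factor <= 1.0)):
        raise ValueError(f"factor must be a float > 1.0, got {factor}")


validate(None)   # ok: nothing to validate
validate(2.0)    # ok
# validate(0.5)  # would raise ValueError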
@@ -11,7 +11,7 @@ if ray:
     from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
         RayTokenizerGroupPool)
 else:
-    RayTokenizerGroupPool = None
+    RayTokenizerGroupPool = None  # type: ignore
 
 
 def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
@@ -89,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         This is blocking.
         """
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
 
         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
@@ -120,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         This is non-blocking.
         """
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
 
         actor = await self._idle_actors.get()
         try:
@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
         token = self.sp_model.IdToPiece(index)
         return token
 
-    def convert_tokens_to_string(self, tokens):
+    def convert_tokens_to_string(self, tokens: List[str]):
         """Converts a sequence of tokens (string) in a single string."""
-        current_sub_tokens = []
+        current_sub_tokens: List[str] = []
         out_string = ""
         prev_is_special = False
         for i, token in enumerate(tokens):