[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)

SangBin Cho 2024-04-23 13:32:44 +09:00 committed by GitHub
parent 34128a697e
commit 0ae11f78ab
29 changed files with 126 additions and 88 deletions

View File

@ -32,19 +32,20 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
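
Both this CI hunk and the format.sh hunk below drop the per-file --follow-imports=skip flag and point mypy at whole directories; the flag moves into the [tool.mypy] section of pyproject.toml (see that hunk further down). The same checks can be reproduced from Python with mypy's programmatic API; a minimal sketch, assuming mypy is installed, using an illustrative subset of the directories above:

# Minimal sketch (not part of the commit): re-run the per-directory checks
# via mypy's programmatic API. Only --config-file is passed because
# follow_imports = "skip" now lives in pyproject.toml.
from mypy import api

TARGETS = ["vllm/attention", "vllm/distributed", "vllm/engine", "vllm/worker"]

def run_mypy() -> int:
    exit_code = 0
    for target in TARGETS:
        stdout, stderr, status = api.run([target, "--config-file", "pyproject.toml"])
        print(stdout, end="")
        print(stderr, end="")
        exit_code = max(exit_code, status)
    return exit_code

if __name__ == "__main__":
    raise SystemExit(run_mypy())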

View File

@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml
CODESPELL_EXCLUDES=(

View File

@ -46,15 +46,17 @@ ignore = [
python_version = "3.8"
ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "skip"
files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]
[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"

View File

@ -116,7 +116,7 @@ class AttentionImpl(ABC):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -248,6 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
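
The added assert is the Optional-narrowing idiom used throughout this commit: asserting that an Optional field is not None lets mypy narrow its type before it is dereferenced. A minimal sketch with illustrative names, not vLLM's real metadata classes:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PrefillMeta:  # illustrative stand-in for the prefill metadata above
    prompt_lens: Optional[List[int]] = None

def total_prompt_tokens(meta: PrefillMeta) -> int:
    # Without this assert, mypy rejects passing Optional[List[int]] to sum();
    # after it, meta.prompt_lens is narrowed to List[int].
    assert meta.prompt_lens is not None
    return sum(meta.prompt_lens)

print(total_prompt_tokens(PrefillMeta(prompt_lens=[3, 5])))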

View File

@ -106,7 +106,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata,
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@ -136,6 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
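
Here the metadata parameter is narrowed from the base class's plain AttentionMetadata (see the AttentionImpl hunk above) to the backend-specific TorchSDPAMetadata, and mypy's override check is silenced with # type: ignore. A minimal sketch of that pattern, with illustrative names:

class Metadata:
    pass

class SDPAMetadata(Metadata):
    is_prompt: bool = True

class Backend:
    def forward(self, x: float, meta: Metadata) -> float:
        raise NotImplementedError

class SDPABackend(Backend):
    # Narrowing a parameter type violates the Liskov rule mypy enforces on
    # overrides, so the override is silenced instead of widening the type.
    def forward(self, x: float, meta: SDPAMetadata) -> float:  # type: ignore
        return x if meta.is_prompt else -x

print(SDPABackend().forward(1.0, SDPAMetadata()))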

View File

@ -288,6 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].

View File

@ -104,6 +104,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert self._blocks is not None
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)

View File

@ -99,7 +99,7 @@ class CopyOnWriteTracker:
refcounter: RefCounter,
allocator: BlockAllocator,
):
self._copy_on_writes = defaultdict(list)
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator
@ -138,6 +138,8 @@ class CopyOnWriteTracker:
prev_block=block.prev_block).block_id
# Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id)
return block_id
@ -180,6 +182,6 @@ def get_all_blocks_recursively(last_block: Block) -> List[Block]:
recurse(block.prev_block, lst)
lst.append(block)
all_blocks = []
all_blocks: List[Block] = []
recurse(last_block, all_blocks)
return all_blocks
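
Several hunks in this commit annotate containers at the point of creation, since mypy cannot infer element types from an empty defaultdict(list) or []. A minimal sketch, with an illustrative BlockId alias:

from collections import defaultdict
from typing import Dict, List

BlockId = int  # illustrative alias for this sketch

# Annotating at creation gives mypy the key/value types, so later
# .append(int) calls type-check; a bare defaultdict(list) or [] makes
# mypy report "Need type annotation".
copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
copy_on_writes[3].append(7)

all_blocks: List[BlockId] = []
all_blocks.append(3)
print(copy_on_writes, all_blocks)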

View File

@ -52,8 +52,7 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
@ -98,8 +97,7 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(BlockAllocator):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod

View File

@ -1,6 +1,6 @@
import os
from contextlib import contextmanager
from typing import List, Optional
from typing import Any, List, Optional
import torch
import torch.distributed as dist
@ -18,7 +18,7 @@ except ImportError:
logger = init_logger(__name__)
_CA_HANDLE = None
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
@ -51,7 +51,7 @@ def init_custom_ar() -> None:
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return False
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ:
@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
return None
@contextmanager
def _nvml():
@ -224,14 +226,14 @@ class CustomAllreduce:
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0])
offsets.append(all_data[i][1])
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
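
This file shows two related fixes: the module-level handle that starts as None gains an Optional annotation, and a function annotated -> None no longer returns False. A minimal sketch of both, with illustrative names:

from typing import Optional

class Handle:  # illustrative stand-in for CustomAllreduce
    def all_reduce(self) -> None:
        pass

# Annotated as Optional so that assigning a Handle later type-checks;
# a bare `_HANDLE = None` is inferred as having type None.
_HANDLE: Optional[Handle] = None

def init_handle(supported: bool) -> None:
    global _HANDLE
    if not supported:
        # `return False` here would be rejected by mypy:
        # "No return value expected" for a function typed -> None.
        return
    _HANDLE = Handle()

init_handle(True)
print(_HANDLE is not None)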

View File

@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
ncclDataType_t = ctypes.c_int
# enums
class ncclDataType_t(ctypes.c_int):
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
raise ValueError(f"Unsupported dtype: {dtype}")
class ncclRedOp_t(ctypes.c_int):
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]
# equivalent to c declaration:
@ -251,8 +255,8 @@ class NCCLCommunicator:
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataType_t.from_torch(tensor.dtype),
ncclRedOp_t.from_torch(op), self.comm,
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
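
Subclassing ctypes.c_int to hold the NCCL enum values left mypy unable to type the converters, whose class attributes are plain ints despite the declared -> 'ncclDataType_t' return. The commit therefore keeps a c_int alias for the ctypes argtypes and moves the named values into an ordinary class whose converter returns int. A condensed sketch using only the values visible above:

import ctypes

# The ctypes signature only needs a C int; keeping the alias separate from
# the constants lets the converter return a plain int, which is what the
# callers hand to the ctypes function anyway.
ncclDataType_t = ctypes.c_int

class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1

    @classmethod
    def from_name(cls, name: str) -> int:  # the real code converts torch dtypes
        mapping = {"int8": cls.ncclInt8, "uint8": cls.ncclUint8}
        if name not in mapping:
            raise ValueError(f"Unsupported dtype: {name}")
        return mapping[name]

print(ncclDataTypeEnum.from_name("uint8"), ncclDataType_t(1).value)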

View File

@ -30,6 +30,7 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
@ -52,6 +53,7 @@ def init_process_group(world_size: int,
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
assert comm is not None
comm.all_reduce(input_, op)
@ -62,8 +64,9 @@ def destroy_process_group() -> None:
def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size
def get_nccl_backend():
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Callable, List
from transformers import PreTrainedTokenizer
@ -8,6 +8,7 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
@ -27,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):

View File

@ -1,4 +1,4 @@
from typing import Callable, Iterable, List
from typing import Callable, List
from transformers import PreTrainedTokenizer
@ -11,6 +11,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
@ -33,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Tuple, Union
from typing import Dict, List, Tuple, Union
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
@ -33,7 +34,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config
@ -69,7 +70,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
@ -92,7 +93,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
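
The seq_counter parameter changes from Iterable[int] to the concrete Counter type because typing.Iterable only promises __iter__, so next(self.seq_counter) does not type-check against it. A minimal sketch with an illustrative counter, similar in spirit to vllm.utils.Counter:

class SequenceCounter:  # illustrative stand-in for vllm.utils.Counter
    def __init__(self, start: int = 0) -> None:
        self._value = start

    def __next__(self) -> int:
        value = self._value
        self._value += 1
        return value

def fork_child(seq_counter: SequenceCounter) -> int:
    # With Iterable[int] here, mypy rejects the next() call, since an
    # Iterable is not guaranteed to have __next__; the concrete type is.
    new_child_seq_id: int = next(seq_counter)
    return new_child_seq_id

print(fork_child(SequenceCounter()))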

View File

@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group = [[] for _ in range(num_seq_groups)]
output_by_sequence_group: List[List[SamplerOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)

View File

@ -18,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@ -26,8 +27,8 @@ from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__)
@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
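
Two patterns appear in this hunk: the module-level serving objects are declared with an annotation only (they are assigned at startup), which keeps their type non-Optional, and the Union returned by the chat handler is narrowed with isinstance before model_dump() is called. A minimal sketch with illustrative names:

from typing import AsyncGenerator, Union

class ChatResponse:  # illustrative stand-in for ChatCompletionResponse
    def model_dump(self) -> dict:
        return {"object": "chat.completion"}

# Declared now, assigned at startup; mypy sees ChatResponse, not Optional.
chat_handler: ChatResponse

def render(result: Union[ChatResponse, AsyncGenerator[str, None]]) -> dict:
    if isinstance(result, ChatResponse):
        # isinstance() narrows the Union, so model_dump() is known to exist.
        return result.model_dump()
    return {"object": "stream"}

print(render(ChatResponse()))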

View File

@ -4,7 +4,8 @@ import time
from typing import Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, Field, conint, model_validator
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: str = False
is_blocking: bool = False
class ModelCard(BaseModel):
@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
class ResponseFormat(BaseModel):
# type must be "json_object" or "text"
type: str = Literal["text", "json_object"]
type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel):
@ -152,6 +153,7 @@ class ChatCompletionRequest(BaseModel):
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: Optional[int] = 1
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[conint(ge=1)] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
@ -289,6 +291,7 @@ class CompletionRequest(BaseModel):
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
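
conint(ge=1) produces a dynamically created type that mypy cannot express, which is why this file and serving_engine.py switch to Annotated[int, Field(ge=1)]: mypy sees a plain int while pydantic still enforces the ge=1 constraint at validation time. A minimal sketch, assuming a reasonably recent pydantic:

from typing import Optional
from pydantic import BaseModel, Field
from typing_extensions import Annotated

class TruncateParams(BaseModel):  # illustrative model, not vLLM's protocol class
    # mypy treats this as Optional[int]; pydantic still rejects values < 1.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

print(TruncateParams(truncate_prompt_tokens=4).truncate_prompt_tokens)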

View File

@ -115,12 +115,12 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = True
# Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
try:
async for res in result_generator:
res: RequestOutput
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).

View File

@ -185,6 +185,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name: str,
num_prompts: int,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
@ -202,6 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
@ -279,7 +281,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
) -> CompletionResponse:
choices = []
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
@ -289,6 +291,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = final_res.prompt
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs

View File

@ -4,7 +4,9 @@ from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union
from pydantic import conint
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@ -45,7 +47,8 @@ class OpenAIServing:
]
self.max_model_len = 0
self.tokenizer = None
# Lazy initialized
self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
try:
event_loop = asyncio.get_running_loop()
@ -92,7 +95,7 @@ class OpenAIServing:
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
@ -108,6 +111,7 @@ class OpenAIServing:
token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else:
token_logprob = step_top_logprobs[token_id].logprob
@ -116,6 +120,7 @@ class OpenAIServing:
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
@ -155,9 +160,9 @@ class OpenAIServing:
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
@ -165,7 +170,7 @@ class OpenAIServing:
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return
return None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
@ -177,7 +182,7 @@ class OpenAIServing:
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")

View File

@ -33,7 +33,7 @@ class LoRALayerWeights:
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
return
return self
self.lora_b *= self.scaling
self.scaling = 1
return self
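
This is one of the genuine bugs the stricter checks surfaced: optimize() is annotated to return the LoRALayerWeights instance, but the early exit used a bare return, handing callers None. A minimal sketch of the same shape:

class Weights:  # illustrative stand-in for LoRALayerWeights
    def __init__(self, scaling: float) -> None:
        self.scaling = scaling

    def optimize(self) -> "Weights":
        if self.scaling == 1:
            # A bare `return` here is flagged by mypy ("Return value
            # expected") and would hand callers None at runtime.
            return self
        self.scaling = 1
        return self

print(Weights(1.0).optimize().scaling)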

View File

@ -29,8 +29,8 @@ def _multi_split_sample(
sampled_tokens_size: Tuple[int, int],
sampled_logprobs_size: Tuple[int, int],
sample_indices: torch.Tensor,
logprobs: torch.Tensor,
*,
logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
):
@ -167,6 +167,7 @@ def sample(
sampled_logprobs_size = (0, 0)
logprobs = probs
assert logprobs is not None
if _save_modified_probs:
sampled_modified_probs_size = sampled_tokens_size
else:

View File

@ -108,7 +108,8 @@ class RotaryEmbedding(nn.Module):
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
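
Buffers created with register_buffer are not statically visible as module attributes, so the reassignment above is annotated to pin the attribute to torch.Tensor for mypy. A minimal sketch with an illustrative module (requires torch to run):

import torch
from torch import nn

class Rotary(nn.Module):  # illustrative module, not vLLM's RotaryEmbedding
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("cos_sin_cache", torch.zeros(4, 2))

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # The explicit annotation pins the buffer's type to torch.Tensor
        # for mypy; at runtime this is an ordinary buffer update.
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
        return self.cos_sin_cache[positions]

print(Rotary()(torch.tensor([0, 1])).shape)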

View File

@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
f"got {alibi_scaling_type}")
if (alibi_scaling_factor is not None
and not isinstance(alibi_scaling_factor, float)
or alibi_scaling_factor <= 1.0):
or (alibi_scaling_factor is not None
and alibi_scaling_factor <= 1.0)):
raise ValueError(
f"`alibi_scaling`'s factor field must be a float > 1.0,"
f"got {alibi_scaling_factor}")
if (alibi_dynamic_scaling is not None
and not isinstance(alibi_dynamic_scaling, int)
or alibi_dynamic_scaling <= 1):
or (alibi_dynamic_scaling is not None
and alibi_dynamic_scaling <= 1)):
raise ValueError(
f"`alibi_scaling`'s `train_seq_len` field must be an"
f"integer > 1, got {alibi_dynamic_scaling}")

View File

@ -11,7 +11,7 @@ if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
else:
RayTokenizerGroupPool = None
RayTokenizerGroupPool = None # type: ignore
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
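
When ray is unavailable, the class name is rebound to None as a fallback; mypy treats that as an incompatible assignment to a type, so it is silenced with # type: ignore. A minimal sketch of the optional-dependency pattern, with hypothetical module and class names:

try:
    from ray_pool import RayPool  # hypothetical optional dependency
except ImportError:
    # Rebinding a class name to None is an incompatible assignment for mypy,
    # hence the explicit silence, mirroring the hunk above.
    RayPool = None  # type: ignore

def make_pool():
    if RayPool is None:
        raise RuntimeError("the optional dependency is not installed")
    return RayPool()

print(RayPool is None)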

View File

@ -89,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
This is blocking.
"""
self._ensure_queue_initialized()
assert self._idle_actors is not None
if self._idle_actors.empty():
raise RuntimeError("No idle actors available.")
@ -120,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
This is non-blocking.
"""
self._ensure_queue_initialized()
assert self._idle_actors is not None
actor = await self._idle_actors.get()
try:

View File

@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
def convert_tokens_to_string(self, tokens: List[str]):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
current_sub_tokens: List[str] = []
out_string = ""
prev_is_special = False
for i, token in enumerate(tokens):