[mypy] Enable following imports for some directories (#6681)

This commit is contained in:
Cyrus Leung 2024-07-31 10:38:03 +08:00 committed by GitHub
parent c32ab8be1a
commit da1f7cc12a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 185 additions and 143 deletions

View File

@ -32,22 +32,17 @@ jobs:
pip install types-setuptools pip install types-setuptools
- name: Mypy - name: Mypy
run: | run: |
mypy tests --config-file pyproject.toml mypy tests --follow-imports skip
mypy vllm/*.py --config-file pyproject.toml mypy vllm/attention --follow-imports skip
mypy vllm/attention --config-file pyproject.toml mypy vllm/core --follow-imports skip
mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --follow-imports skip
mypy vllm/distributed --config-file pyproject.toml mypy vllm/engine --follow-imports skip
mypy vllm/engine --config-file pyproject.toml mypy vllm/entrypoints --follow-imports skip
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --follow-imports skip
mypy vllm/executor --config-file pyproject.toml mypy vllm/lora --follow-imports skip
mypy vllm/inputs --config-file pyproject.toml mypy vllm/model_executor --follow-imports skip
mypy vllm/logging --config-file pyproject.toml mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/lora --config-file pyproject.toml mypy vllm/spec_decode --follow-imports skip
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/worker --follow-imports skip
mypy vllm/multimodal --config-file pyproject.toml mypy
mypy vllm/platforms --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml

View File

@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'
# Run mypy # Run mypy
echo 'vLLM mypy:' echo 'vLLM mypy:'
mypy tests --config-file pyproject.toml mypy tests --follow-imports skip
mypy vllm/*.py --config-file pyproject.toml mypy vllm/attention --follow-imports skip
mypy vllm/attention --config-file pyproject.toml mypy vllm/core --follow-imports skip
mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --follow-imports skip
mypy vllm/distributed --config-file pyproject.toml mypy vllm/engine --follow-imports skip
mypy vllm/engine --config-file pyproject.toml mypy vllm/entrypoints --follow-imports skip
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --follow-imports skip
mypy vllm/executor --config-file pyproject.toml mypy vllm/lora --follow-imports skip
mypy vllm/logging --config-file pyproject.toml mypy vllm/model_executor --follow-imports skip
mypy vllm/lora --config-file pyproject.toml mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/spec_decode --follow-imports skip
mypy vllm/multimodal --config-file pyproject.toml mypy vllm/worker --follow-imports skip
mypy vllm/prompt_adapter --config-file pyproject.toml mypy
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
# If git diff returns a file that is in the skip list, the file may be checked anyway: # If git diff returns a file that is in the skip list, the file may be checked anyway:

View File

@ -48,9 +48,23 @@ python_version = "3.8"
ignore_missing_imports = true ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
follow_imports = "skip" follow_imports = "silent"
files = "vllm" # After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from format.sh and mypy.yaml
files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
"vllm/platforms",
"vllm/server",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [ exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/", "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",

View File

@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype], out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

View File

@ -25,27 +25,33 @@ class ipex_ops:
x2 = x2.reshape(num, d) x2 = x2.reshape(num, d)
return x1, x2 return x1, x2
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.silu_mul(x1, x2, out) ipex.llm.functional.silu_mul(x1, x2, out)
@staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "none") ipex.llm.functional.gelu_mul(x1, x2, out, "none")
@staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh") ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
@staticmethod
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x)) out.copy_(torch.nn.functional.gelu(x))
@staticmethod
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x)) out.copy_(torch.nn.functional.gelu(x))
# TODO add implementation of gelu_quick here # TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
@staticmethod
def paged_attention_v1( def paged_attention_v1(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
@ -78,12 +84,21 @@ class ipex_ops:
).view(num_kv_heads, ).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten() 1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace # todo: ipex will refactor namespace
torch.xpu.paged_attention_v1(out, query.contiguous(), torch.xpu.paged_attention_v1( # type: ignore
key_cache.view_as(value_cache), out,
value_cache, head_mapping, scale, query.contiguous(),
block_tables, context_lens, block_size, key_cache.view_as(value_cache),
max_context_len, alibi_slopes) value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def paged_attention_v2( def paged_attention_v2(
out: torch.Tensor, out: torch.Tensor,
exp_sum: torch.Tensor, exp_sum: torch.Tensor,
@ -119,13 +134,24 @@ class ipex_ops:
).view(num_kv_heads, ).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten() 1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace # todo: ipex will refactor namespace
torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out, torch.xpu.paged_attention_v2( # type: ignore
query.contiguous(), out,
key_cache.view_as(value_cache), exp_sum,
value_cache, head_mapping, block_tables, max_logits,
context_lens, scale, block_size, tmp_out,
max_context_len, alibi_slopes) query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def rotary_embedding( def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len] positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
@ -158,6 +184,7 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions) rotary_dim, is_neox, positions)
@staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int, key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool, cos_sin_cache: torch.Tensor, is_neox: bool,
@ -189,17 +216,20 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions) rotary_dim, is_neox, positions)
@staticmethod
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None: epsilon: float) -> None:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon) tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp) out.copy_(tmp)
@staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True) epsilon, True)
input.copy_(tmp) input.copy_(tmp)
@staticmethod
def varlen_attention( def varlen_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
@ -222,6 +252,7 @@ class ipex_ops:
softmax_scale, zero_tensors, softmax_scale, zero_tensors,
is_causal, return_softmax, gen_) is_causal, return_softmax, gen_)
@staticmethod
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
@ -240,8 +271,13 @@ class ipex_ops:
def copy_blocks(key_caches: List[torch.Tensor], def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor], value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None: block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks(key_caches, value_caches, block_mapping) torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None: block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping) torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore

View File

@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
super().__init__(capacity) super().__init__(capacity)
self.deactivate_fn = deactivate_fn self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T): def _on_remove(self, key: Hashable, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key) logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key) self.deactivate_fn(key)
return super()._on_remove(key, value) return super()._on_remove(key, value)
@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
@property @property
@abstractmethod @abstractmethod
def adapter_slots(self): def adapter_slots(self) -> int:
... raise NotImplementedError
@property @property
@abstractmethod @abstractmethod
def capacity(self): def capacity(self) -> int:
... raise NotImplementedError
@abstractmethod @abstractmethod
def activate_adapter(self, adapter_id: int) -> bool: def activate_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool: def deactivate_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def add_adapter(self, adapter: Any) -> bool: def add_adapter(self, adapter: Any) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None: def set_adapter_mapping(self, mapping: Any) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_adapter(self, adapter_id: int) -> bool: def remove_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_all_adapters(self): def remove_all_adapters(self) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]: def get_adapter(self, adapter_id: int) -> Optional[Any]:
... raise NotImplementedError
@abstractmethod @abstractmethod
def list_adapters(self) -> Dict[int, Any]: def list_adapters(self) -> Dict[int, Any]:
... raise NotImplementedError
@abstractmethod @abstractmethod
def pin_adapter(self, adapter_id: int) -> bool: def pin_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError

View File

@ -1,19 +1,19 @@
from abc import abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class AdapterRequest: class AdapterRequest(ABC):
""" """
Base class for adapter requests. Base class for adapter requests.
""" """
@property @property
@abstractmethod @abstractmethod
def adapter_id(self): def adapter_id(self) -> int:
... raise NotImplementedError
def __post_init__(self): def __post_init__(self) -> None:
if self.adapter_id < 1: if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}") raise ValueError(f"id must be > 0, got {self.adapter_id}")

View File

@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
@property @property
@abstractmethod @abstractmethod
def is_enabled(self) -> bool: def is_enabled(self) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def set_active_adapters(self, requests: Set[Any], def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None: mapping: Optional[Any]) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def add_adapter(self, adapter_request: Any) -> bool: def add_adapter(self, adapter_request: Any) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_adapter(self, adapter_id: int) -> bool: def remove_adapter(self, adapter_id: int) -> bool:
... raise NotImplementedError
@abstractmethod @abstractmethod
def remove_all_adapters(self): def remove_all_adapters(self) -> None:
... raise NotImplementedError
@abstractmethod @abstractmethod
def list_adapters(self) -> Set[int]: def list_adapters(self) -> Set[int]:
... raise NotImplementedError

View File

@ -724,7 +724,7 @@ class ParallelConfig:
backend) backend)
self._verify_args() self._verify_args()
self.rank = 0 self.rank: int = 0
@property @property
def use_ray(self) -> bool: def use_ray(self) -> bool:
@ -850,6 +850,7 @@ class SchedulerConfig:
class DeviceConfig: class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "auto") -> None: def __init__(self, device: str = "auto") -> None:
if device == "auto": if device == "auto":

View File

@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@ -40,7 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer) init_tracer)
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
BaseTokenizerGroup,
get_tokenizer_group) get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
@ -477,13 +476,12 @@ class LLMEngine:
return self.tokenizer return self.tokenizer
def get_tokenizer( def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request) return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self, def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
sequence: Sequence) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer( return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request) sequence.lora_request)

View File

@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -30,6 +29,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
@ -49,8 +49,6 @@ class LoRAModulePath:
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest] EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict): class TextTokensPrompt(TypedDict):
prompt: str prompt: str

View File

@ -4,9 +4,10 @@ import asyncio
import os import os
import signal import signal
import sys import sys
from typing import Optional from typing import List, Optional
from openai import OpenAI from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
def chat(system_prompt: Optional[str], model_name: str, def chat(system_prompt: Optional[str], model_name: str,
client: OpenAI) -> None: client: OpenAI) -> None:
conversation = [] conversation: List[ChatCompletionMessageParam] = []
if system_prompt is not None: if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt}) conversation.append({"role": "system", "content": system_prompt})
print("Please enter a message for the chat model:") print("Please enter a message for the chat model:")
while True: while True:
input_message = input("> ") input_message = input("> ")
message = {"role": "user", "content": input_message} conversation.append({"role": "user", "content": input_message})
conversation.append(message)
chat_completion = client.chat.completions.create(model=model_name, chat_completion = client.chat.completions.create(model=model_name,
messages=conversation) messages=conversation)
@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
response_message = chat_completion.choices[0].message response_message = chat_completion.choices[0].message
output = response_message.content output = response_message.content
conversation.append(response_message) conversation.append(response_message) # type: ignore
print(output) print(output)

View File

@ -37,6 +37,8 @@ class Detokenizer:
The prompt logprobs with the decoded tokens. The prompt logprobs with the decoded tokens.
""" """
prms = seq_group.sampling_params prms = seq_group.sampling_params
assert prms is not None
# We can pick any sequence for the prompt. # We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values())) seq = next(iter(seq_group.seqs_dict.values()))
# Only prompt, without the generated token. # Only prompt, without the generated token.

View File

@ -2,10 +2,9 @@ from typing import Optional, Type
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup) from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( from .tokenizer_group import TokenizerGroup
TokenizerGroup)
if ray: if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] __all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]

View File

@ -1,11 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class BaseTokenizerGroup(ABC): class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters.""" """A group of tokenizers that can be used for LoRA adapters."""
@ -47,17 +49,17 @@ class BaseTokenizerGroup(ABC):
@abstractmethod @abstractmethod
def get_lora_tokenizer( def get_lora_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request.""" """Get a tokenizer for a LoRA request."""
pass pass
@abstractmethod @abstractmethod
async def get_lora_tokenizer_async( async def get_lora_tokenizer_async(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request.""" """Get a tokenizer for a LoRA request."""
pass pass

View File

@ -6,18 +6,16 @@ try:
from ray.exceptions import ActorDiedError from ray.exceptions import ActorDiedError
except ImportError: except ImportError:
# For older versions of Ray # For older versions of Ray
from ray.exceptions import RayActorError as ActorDiedError from ray.exceptions import RayActorError as ActorDiedError # type: ignore
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray from vllm.executor.ray_utils import ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup) from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( from .tokenizer_group import TokenizerGroup
TokenizerGroup)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
**self._tokenizer_config, ) **self._tokenizer_config, )
self._ray_tokenizer_group_cls = ray.remote( self._ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options) self._worker_cls).options(**ray_actor_options) # type: ignore
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)] self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
self._idle_actors: Optional[asyncio.Queue] = None self._idle_actors: Optional[asyncio.Queue] = None
@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return len(self.tokenizer_actors) return len(self.tokenizer_actors)
def ping(self): def ping(self):
return ray.get( return ray.get([
[actor.ping.remote() for actor in self.tokenizer_actors]) actor.ping.remote() # type: ignore
for actor in self.tokenizer_actors
])
def _ensure_queue_initialized(self): def _ensure_queue_initialized(self):
if self._idle_actors is None: if self._idle_actors is None:
@ -208,15 +208,15 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return self._local_tokenizer_group.get_max_input_len(lora_request) return self._local_tokenizer_group.get_max_input_len(lora_request)
def get_lora_tokenizer( def get_lora_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
return self._local_tokenizer_group.get_lora_tokenizer(lora_request) return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
async def get_lora_tokenizer_async( async def get_lora_tokenizer_async(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
return await self._local_tokenizer_group.get_lora_tokenizer_async( return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request) lora_request)

View File

@ -1,16 +1,14 @@
from typing import List, Optional from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async, get_lora_tokenizer_async,
get_tokenizer) get_tokenizer)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.utils import LRUCache from vllm.utils import LRUCache
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
class TokenizerGroup(BaseTokenizerGroup): class TokenizerGroup(BaseTokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters.""" """A group of tokenizers that can be used for LoRA adapters."""
@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self.enable_lora = enable_lora self.enable_lora = enable_lora
self.max_input_length = max_input_length self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
self.lora_tokenizers = LRUCache[PreTrainedTokenizer]( self.lora_tokenizers = LRUCache[AnyTokenizer](
capacity=max_num_seqs) if enable_lora else None capacity=max_num_seqs if enable_lora else 0)
@classmethod @classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
return self.max_input_length return self.max_input_length
def _raise_if_input_too_long(self, def _raise_if_input_too_long(self,
encoded_tokens: List[str], encoded_tokens: List[int],
lora_request: Optional[LoRARequest] = None): lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens) input_length = len(encoded_tokens)
if lora_request: if lora_request:
@ -72,9 +70,9 @@ class TokenizerGroup(BaseTokenizerGroup):
return ret return ret
def get_lora_tokenizer( def get_lora_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
if not lora_request or not self.enable_lora: if not lora_request or not self.enable_lora:
return self.tokenizer return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers: if lora_request.lora_int_id not in self.lora_tokenizers:
@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer return tokenizer
else: else:
return self.lora_tokenizers.get(lora_request.lora_int_id) return self.lora_tokenizers[lora_request.lora_int_id]
async def get_lora_tokenizer_async( async def get_lora_tokenizer_async(
self, self,
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer": ) -> AnyTokenizer:
if not lora_request or not self.enable_lora: if not lora_request or not self.enable_lora:
return self.tokenizer return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers: if lora_request.lora_int_id not in self.lora_tokenizers:
@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer return tokenizer
else: else:
return self.lora_tokenizers.get(lora_request.lora_int_id) return self.lora_tokenizers[lora_request.lora_int_id]

View File

@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
def __getitem__(self, key: Hashable) -> Optional[T]: def __getitem__(self, key: Hashable) -> T:
return self.get(key) value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None: def __setitem__(self, key: Hashable, value: T) -> None:
self.put(key, value) self.put(key, value)
@ -109,8 +111,9 @@ class LRUCache(Generic[T]):
def get(self, def get(self,
key: Hashable, key: Hashable,
default_value: Optional[T] = None) -> Optional[T]: default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
if key in self.cache: if key in self.cache:
value: Optional[T] = self.cache[key] value = self.cache[key]
self.cache.move_to_end(key) self.cache.move_to_end(key)
else: else:
value = default_value value = default_value
@ -590,8 +593,8 @@ class CudaMemoryProfiler:
torch.cuda.reset_peak_memory_stats(self.device) torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device) mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu(): elif is_xpu():
torch.xpu.reset_peak_memory_stats(self.device) torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
mem = torch.xpu.max_memory_allocated(self.device) mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
return mem return mem
def __enter__(self): def __enter__(self):