[Docs] Fix warnings in mkdocs build (continued) (#24791)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Hyogeun Oh (오효근) authored on 2025-09-13 16:13:44 +09:00; committed by GitHub.
parent 5febdc8750
commit 9a8966bcc2
27 changed files with 102 additions and 110 deletions
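
Most of the warnings this commit addresses appear to fall into three recurring patterns: docstring continuation lines whose indentation the Google-style parser cannot attach to the right Args entry, documented parameters (such as `**kwargs`) that carry no type annotation, and docstring entries for parameters that are not in the function signature. A minimal sketch of the first two fixes, assuming the docs are built with mkdocs and a Google-style docstring parser such as griffe (the class and method names below are illustrative only, not taken from the diff):

from typing import Any


class Example:
    """Hypothetical class showing the docstring style the fixes target."""

    def step(self, is_dummy: bool = False, **kwargs: Any) -> None:
        """Run one step.

        Args:
            is_dummy (bool): If `True`, this is a dummy step and the load
                metrics recorded in this forward pass will not count.
                Defaults to `False`.
            **kwargs: Additional keyword arguments.
        """

Keeping every continuation line at one consistent indent and annotating `**kwargs` as `Any` is the shape of most hunks below; the remaining hunks add return-type annotations or drop Args entries for parameters the signatures do not actually accept.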


@@ -337,11 +337,11 @@ class EplbState:
Args:
model (MixtureOfExperts): The MoE model.
is_dummy (bool): If `True`, this is a dummy step and the load
metrics recorded in this forward pass will not count. Defaults
to `False`.
metrics recorded in this forward pass will not count. Defaults
to `False`.
is_profile (bool): If `True`, perform a dummy rearrangement
with maximum communication cost. This is used in `profile_run`
to reserve enough memory for the communication buffer.
with maximum communication cost. This is used in `profile_run`
to reserve enough memory for the communication buffer.
log_stats (bool): If `True`, log the expert load metrics.
# Stats


@@ -102,14 +102,14 @@ def rebalance_experts_hierarchical(
num_groups: int,
num_nodes: int,
num_gpus: int,
):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
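
For orientation, a hedged usage sketch of this function: the argument values are illustrative, the positional order follows the docstring above, the import path is assumed, and the names given to the three returned tensors are placeholders.

import torch

# Module path assumed; the function itself is the one documented above.
from vllm.distributed.eplb.rebalance_algo import rebalance_experts_hierarchical

# 2 MoE layers with 16 logical experts each, replicated to 32 physical
# experts spread across 2 nodes x 4 GPUs, using 4 expert groups.
weight = torch.randint(1, 100, (2, 16))  # [num_moe_layers, num_logical_experts]
phy2log, phy_rank, log_cnt = rebalance_experts_hierarchical(
    weight,
    32,  # num_physical_experts after replication
    4,   # num_groups
    2,   # num_nodes
    8,   # num_gpus, a multiple of num_nodes
)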


@@ -149,7 +149,7 @@ class KVConnectorBase_V1(ABC):
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
**kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
@@ -182,7 +182,8 @@ class KVConnectorBase_V1(ABC):
@abstractmethod
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to


@@ -30,7 +30,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
**kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
@@ -61,7 +61,8 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self._lmcache_engine.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""
Start saving the a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to


@@ -91,7 +91,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
**kwargs: Any) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
@@ -212,7 +212,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
@@ -278,7 +279,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
def get_finished(
self, finished_req_ids: set[str],
**kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
**kwargs: Any) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.


@@ -218,8 +218,9 @@ class TensorMemoryPool:
return addr
def load_tensor(self, addr: int, dtype: torch.dtype,
shape: tuple[int, ...], device) -> torch.Tensor:
def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int,
...],
device: torch.device) -> torch.Tensor:
"""Loads a tensor from pinned host memory to the specified device.
Args:


@@ -3,7 +3,7 @@
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
@@ -90,7 +90,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
**kwargs: Any) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
@@ -191,7 +191,8 @@ class SharedStorageConnector(KVConnectorBase_V1):
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
attn_metadata: "AttentionMetadata",
**kwargs: Any) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.


@@ -251,8 +251,8 @@ class PyNcclPipe(KVPipeBase):
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Args:
tensor: The received tensor, or `None` if no tensor is received.
Returns:
The received tensor, or `None` if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)


@@ -823,7 +823,7 @@ class SupportsEagle3(Protocol):
Args:
layers: Tuple of layer indices that should output auxiliary
hidden states.
hidden states.
"""
...


@@ -1520,15 +1520,9 @@ class BaseKeyeModule(nn.Module):
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
otherwise it will be `(seq_len,)`.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
"""
if intermediate_tensors is not None:
inputs_embeds = None


@@ -58,17 +58,18 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor:
return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0)
def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int],
torch.Tensor]):
def get_num_patches(grid_thw: torch.Tensor,
num_frames: Union[list[int], torch.Tensor]) -> list[int]:
"""
Return num_patches per video.
Args:
t: tensor with shape [N, ...] where each item is a list/tensor
cu_seqlens: list indicating the boundaries of groups
grid_thw: Tensor with shape [N, 3] containing temporal, height, width
dimensions
num_frames: List or tensor indicating the number of frames per video
Returns:
list of ints representing the sum of products for each group
List of ints representing the number of patches for each video
Examples:
>>> # Suppose there are 2 videos with a total of 3 grids


@@ -732,7 +732,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
positions: Position indices for the input tokens.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
Info:
[LlavaImageInputs][]


@@ -535,8 +535,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each grid patch for each input image.
image_sizes: The original `(height, width)` for each input image.
positions: Position indices for the input tokens.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
Info:
[LlavaNextImageInputs][]


@@ -578,7 +578,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
positions: Position indices for the input tokens.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
Info:
[Mistral3ImagePixelInputs][]


@@ -387,11 +387,10 @@ class Llama4VisionEncoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
`(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation. This is useful if you
want more control over how to convert `input_ids` indices into
hidden_states: Input tensor of shape
(batch_size, sequence_length, hidden_size).
Hidden states from the model embeddings, representing
the input tokens.
associated vectors than the model's internal embedding
lookup matrix.
"""


@@ -70,11 +70,15 @@ def multihead_attention(
v: torch.Tensor,
q_cu_seqlens: Optional[torch.Tensor] = None,
k_cu_seqlens: Optional[torch.Tensor] = None,
):
) -> torch.Tensor:
"""Multi-head attention using flash attention 2.
Args:
q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
The first element should be 0 and the last element should be q.shape[0].
@@ -123,8 +127,14 @@ def sdpa_attention(
"""SDPA attention.
Args:
q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
q_cu_seqlens: Optional cumulative sequence lengths of q.
k_cu_seqlens: Optional cumulative sequence lengths of k.
"""
seq_length = q.shape[0]
attention_mask = torch.zeros([1, seq_length, seq_length],
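
The packed layout described in the two attention docstrings above can be sketched as follows; the shapes are illustrative, and the final call is left commented out because it depends on importing the module this hunk belongs to.

import torch

# Two sequences of lengths 3 and 5 packed into a single
# (tot_seqlens, num_heads, head_dim) tensor.
num_heads, head_dim = 8, 64
q = torch.randn(3 + 5, num_heads, head_dim)
k = torch.randn(3 + 5, num_heads, head_dim)
v = torch.randn(3 + 5, num_heads, head_dim)

# Cumulative sequence lengths: first element 0, last element q.shape[0].
q_cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
k_cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)

# out = multihead_attention(q, k, v, q_cu_seqlens, k_cu_seqlens)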
@@ -387,7 +397,7 @@ class MLP2(nn.Module):
def __init__(self,
dims: list[int],
activation,
bias=True,
bias: bool = True,
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()


@@ -374,8 +374,8 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module):
Typically used as a very first layer in a model.
Args:
input_size: int
layer input size.
config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig)
object containing model parameters.
"""
def __init__(self, config: Phi4MultimodalAudioConfig):


@@ -1372,15 +1372,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
otherwise it will be `(seq_len,)`.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
"""
if intermediate_tensors is not None:


@@ -390,12 +390,9 @@ class Siglip2EncoderLayer(nn.Module):
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all
attention layers. See `attentions` under
returned tensors for more detail.
hidden_states: Input tensor of shape (batch, seq_len, embed_dim).
cu_seqlens: Cumulative sequence lengths tensor.
position_embeddings: Position embeddings tensor.
"""
residual = hidden_states
@@ -534,19 +531,11 @@ class Siglip2Encoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
`(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation. This is useful if
you want more control over how to convert `input_ids` indices
into associated vectors than the model's internal embedding
lookup matrix.
grid_thws (`torch.LongTensor`):
grid shape (num_patches, 3)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See
`hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
inputs_embeds: Input tensor of shape
(batch_size, sequence_length, hidden_size).
Embedded representation of the input tokens.
grid_thws: Grid tensor of shape (num_patches, 3)
containing grid dimensions.
Whether or not to return a [`~utils.ModelOutput`] instead of
a plain tuple.
"""


@@ -597,10 +597,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
with the `input_ids`.
Args:
audio_features: A batch of audio input chunks [B, N, 80, M].
audio_lens: Length of audio frames for each audio chunk [B].
audio_token_len: Length of audio tokens for each audio chunk [B'].
Note: batch dim is different from batch dim in audio chunks.
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Position indices for the input tokens.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
"""


@@ -909,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
prefix: Optional prefix for parameter names
Raises:
AssertionError: If prefix caching is enabled
(not supported by Mamba)
AssertionError: If prefix caching is enabled
(not supported by Mamba)
"""
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config


@@ -679,20 +679,21 @@ def get_hf_file_to_dict(file_name: str,
@cache
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
def get_pooling_config(model: str,
revision: Optional[str] = 'main') -> Optional[dict]:
"""
This function gets the pooling and normalize
config from the model - only applies to
sentence-transformers models.
Args:
model (str): The name of the Hugging Face model.
revision (str, optional): The specific version
of the model to use. Defaults to 'main'.
model: The name of the Hugging Face model.
revision: The specific version of the model to use.
Defaults to 'main'.
Returns:
dict: A dictionary containing the pooling
type and whether normalization is used.
A dictionary containing the pooling type and whether
normalization is used, or None if no pooling configuration is found.
"""
modules_file_name = "modules.json"
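
A hedged usage sketch: the import path is assumed from the surrounding hunk, the checkpoint name is just an example of a sentence-transformers model, and the keys of the returned dict are deliberately not spelled out because the hunk does not show them.

# Import path assumed.
from vllm.transformers_utils.config import get_pooling_config

pooling = get_pooling_config("sentence-transformers/all-MiniLM-L6-v2")
if pooling is None:
    print("no sentence-transformers pooling configuration found")
else:
    print(pooling)  # pooling type and whether normalization is applied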


@@ -74,10 +74,10 @@ class JAISConfig(PretrainedConfig):
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models).
scale_attn_by_inverse_layer_idx (`bool`, *optional*,
defaults to `False`):
Whether to additionally scale attention weights by
`1 / layer_idx + 1`.
scale_attn_by_inverse_layer_idx
(`bool`, *optional*, defaults to `False`):
Whether to additionally scale attention weights
by `1 / layer_idx + 1`.
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
Whether to scale keys (K) prior to computing attention
(dot-product)


@@ -37,10 +37,6 @@ class UltravoxConfig(transformers.PretrainedConfig):
The initialization value for the layer normalization.
projector_act (`str`, *optional*, defaults to `"swiglu"`):
The activation function used by the multimodal projector.
text_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
projector_ln_mid (`bool`, *optional*, defaults to `False`):
Whether to apply layer normalization at the middle of the
projector or at the end. Versions v0.4.1 and below


@@ -25,6 +25,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import math
from typing import Any
import torch
import torchvision.transforms as T
@@ -178,17 +179,15 @@ class DeepseekVLV2Processor(ProcessorMixin):
prompt: str,
images: list[Image.Image],
inference_mode: bool = True,
**kwargs,
**kwargs: Any,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (list[dict]): conversations with a list of messages;
images (list[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
**kwargs: Additional keyword arguments.
Returns:
outputs (BaseProcessorOutput): the output of the processor,
@@ -259,7 +258,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
text: str,
images: list[Image.Image],
inference_mode: bool = True,
**kwargs,
**kwargs: Any,
):
"""


@@ -33,7 +33,6 @@ def list_safetensors(path: str = "") -> list[str]:
Args:
path: The object storage path to list from.
allow_pattern: A list of patterns of which files to pull.
Returns:
list[str]: List of full object storage paths allowed by the pattern
@@ -54,8 +53,7 @@ class ObjectStorageModel:
dir: The temporary created directory.
Methods:
pull_files(): Pull model from object storage to the temporary
directory.
pull_files(): Pull model from object storage to the temporary directory.
"""
def __init__(self) -> None:


@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
from typing import Optional
from typing import Any, Optional
from vllm.utils import PlaceholderModule
@@ -26,7 +26,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
]
def glob(s3=None,
def glob(s3: Optional[Any] = None,
path: str = "",
allow_pattern: Optional[list[str]] = None) -> list[str]:
"""
@@ -51,7 +51,7 @@ def glob(s3=None,
def list_files(
s3,
s3: Any,
path: str,
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None