[Docs] Fix warnings in mkdocs build (continued) (#24740)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Hyogeun Oh (오효근) 2025-09-12 22:43:15 +09:00 committed by GitHub
parent bcb06d7baf
commit 41f17bf290
10 changed files with 121 additions and 176 deletions

View File

@@ -144,8 +144,8 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""Quantize a Tensor with torchao quantization specified by torchao_config
Args:
- `param`: weight parameter of the linear module
- `torchao_config`: type of quantization and their arguments we want to
+ param: weight parameter of the linear module
+ torchao_config: type of quantization and their arguments we want to
use to quantize the Tensor
"""
from torchao.core.config import AOBaseConfig
@@ -172,8 +172,8 @@ class TorchAOLinearMethod(LinearMethodBase):
"""Linear method for torchao.
Args:
- torchao_config: The torchao quantization config, a string
- that encodes the type of quantization and all relevant arguments.
+ quant_config: The torchao quantization config, a string that encodes
+ the type of quantization and all relevant arguments.
"""
def __init__(self, quant_config: TorchAOConfig):
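For context on the config object these docstrings describe: a minimal sketch of how a torchao config drives quantization, assuming a recent torchao release that exports `quantize_` and `int8_weight_only` from `torchao.quantization` (the layer and sizes are made up):

    import torch
    from torchao.quantization import int8_weight_only, quantize_

    # Stand-in for a linear module whose weight parameter gets quantized.
    linear = torch.nn.Linear(128, 256)

    # quantize_ mutates the module in place, replacing the weight with a
    # quantized tensor subclass described by the config object.
    quantize_(linear, int8_weight_only())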

View File

@@ -423,7 +423,7 @@ def w8a8_block_int8_matmul(
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be
2-dim, e.g., [128, 128].
- output_dytpe: The dtype of the returned tensor.
+ output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
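The corrected docstring is easier to pin down with a reference implementation. Below is a plain-PyTorch dequantize-then-matmul sketch of the assumed semantics, not the Triton kernel itself; the scale layouts (`As` per token per k-block, `Bs` per n-block per k-block) are assumptions read off the docstring:

    import torch

    def ref_block_int8_matmul(A, B, As, Bs, block_size,
                              output_dtype=torch.float16):
        # Assumed layouts: A [M, K] int8 with As [M, K // blk_k] scales;
        # B [N, K] int8 with Bs [N // blk_n, K // blk_k] scales.
        blk_n, blk_k = block_size
        A_f = A.float() * As.repeat_interleave(blk_k, dim=1)
        B_f = (B.float() * Bs.repeat_interleave(blk_n, dim=0)
                             .repeat_interleave(blk_k, dim=1))
        return (A_f @ B_f.t()).to(output_dtype)  # [M, N] in output_dtype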

View File

@@ -135,8 +135,8 @@ def triton_mrope(
"""Qwen2VL mrope kernel.
Args:
- query: [num_tokens, num_heads * head_size]
- key: [num_tokens, num_kv_heads * head_size]
+ q: [num_tokens, num_heads * head_size]
+ k: [num_tokens, num_kv_heads * head_size]
cos: [3, num_tokens, head_size //2 ]
(T/H/W positions with multimodal inputs)
sin: [3, num_tokens, head_size //2 ]
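A tensor-shape sketch of the renamed arguments, with made-up sizes; the leading dimension of `cos`/`sin` carries the three T/H/W rotary streams described above:

    import torch

    num_tokens, num_heads, num_kv_heads, head_size = 16, 8, 2, 64
    q = torch.randn(num_tokens, num_heads * head_size)
    k = torch.randn(num_tokens, num_kv_heads * head_size)
    cos = torch.randn(3, num_tokens, head_size // 2)  # T/H/W streams
    sin = torch.randn(3, num_tokens, head_size // 2)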

View File

@@ -171,22 +171,23 @@ class TensorizerConfig(MutableMapping):
_is_sharded: bool = field(init=False, default=False)
_fields: ClassVar[tuple[str, ...]]
_keys: ClassVar[frozenset[str]]
- """
- Args for the TensorizerConfig class. These are used to configure the
- behavior of model serialization and deserialization using Tensorizer.
+ """Configuration class for Tensorizer settings.
- Args:
+ These settings configure the behavior of model serialization and
+ deserialization using Tensorizer.
+ Attributes:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
- `lora_dir` is passed to this object's initializer, this is a required
- argument.
+ `lora_dir` is passed to this object's initializer, this is
+ a required argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
- configs and tokenizer files. Can be passed instead of `tensorizer_uri`
- where the `model.tensors` file will be assumed to be in this
- directory.
+ configs and tokenizer files. Can be passed instead of
+ `tensorizer_uri` where the `model.tensors` file will be assumed
+ to be in this directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
@@ -194,9 +195,9 @@ class TensorizerConfig(MutableMapping):
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
- verify_hash: If True, the hashes of each tensor will be verified against
- the hashes stored in the metadata. A `HashMismatchError` will be
- raised if any of the hashes do not match.
+ verify_hash: If True, the hashes of each tensor will be verified
+ against the hashes stored in the metadata. A `HashMismatchError`
+ will be raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
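A minimal construction sketch for the attributes documented above; the import path reflects vLLM's layout around the time of this commit and should be treated as an assumption, and the URI is hypothetical:

    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

    config = TensorizerConfig(
        tensorizer_uri="s3://my-bucket/model.tensors",  # hypothetical URI
        verify_hash=True,  # raise HashMismatchError on any corrupted tensor
        num_readers=None,  # None sizes the reader pool dynamically
    )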

View File

@@ -143,16 +143,8 @@ class AriaProjector(nn.Module):
projects ViT's outputs into MoE's inputs.
Args:
- patch_to_query_dict (dict): Maps patch numbers to their corresponding
- query numbers,
- e.g., {1225: 128, 4900: 256}. This allows for different query sizes
- based on image resolution.
- embed_dim (int): Embedding dimension.
- num_heads (int): Number of attention heads.
- kv_dim (int): Dimension of key and value.
- ff_dim (int): Hidden dimension of the feed-forward network.
- output_dim (int): Output dimension.
- norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
+ config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
+ containing projector configuration parameters.
Outputs:
A tensor with the shape of (batch_size, query_number, output_dim)
@@ -282,8 +274,8 @@ class AriaTextMoELayer(nn.Module):
Forward pass of the MoE Layer.
Args:
- hidden_states (torch.Tensor): Input tensor of shape (batch_size,
- sequence_length, hidden_size).
+ hidden_states: Input tensor of shape
+ (batch_size, sequence_length, hidden_size).
Returns:
torch.Tensor: Output tensor after passing through the MoE layer.
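A shape-contract sketch for the forward pass documented above, with made-up dimensions; the layer maps `(batch_size, sequence_length, hidden_size)` to a tensor of the same shape:

    import torch

    batch_size, seq_len, hidden_size = 2, 16, 1024  # made-up sizes
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    # For an AriaTextMoELayer instance `moe_layer` (construction omitted):
    # out = moe_layer(hidden_states)
    # assert out.shape == hidden_states.shape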

View File

@@ -401,8 +401,7 @@ class BartEncoderLayer(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Args:
- hidden_states
- torch.Tensor of *encoder* input embeddings.
+ hidden_states: torch.Tensor of *encoder* input embeddings.
Returns:
Encoder layer output torch.Tensor
"""
@@ -490,10 +489,8 @@
) -> torch.Tensor:
r"""
Args:
- decoder_hidden_states
- torch.Tensor of *decoder* input embeddings.
- encoder_hidden_states
- torch.Tensor of *encoder* input embeddings.
+ decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
+ encoder_hidden_states: torch.Tensor of *encoder* input embeddings.
Returns:
Decoder layer output torch.Tensor
"""
@@ -584,12 +581,10 @@ class BartEncoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
- input_ids
- Indices of *encoder* input sequence tokens in the vocabulary.
- Padding will be ignored by default should you
- provide it.
- positions
- Positions of *encoder* input sequence tokens.
+ input_ids: Indices of *encoder* input sequence tokens in the
+ vocabulary.
+ Padding will be ignored by default should you provide it.
+ positions: Positions of *encoder* input sequence tokens.
Returns:
Decoder output torch.Tensor
"""
@@ -663,14 +658,11 @@ class BartDecoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
- decoder_input_ids
- Indices of *decoder* input sequence tokens in the vocabulary.
- Padding will be ignored by default should you
- provide it.
- decoder_positions
- Positions of *decoder* input sequence tokens.
- encoder_hidden_states:
- Tensor of encoder output embeddings
+ decoder_input_ids: Indices of *decoder* input sequence tokens
+ in the vocabulary.
+ Padding will be ignored by default should you provide it.
+ decoder_positions: Positions of *decoder* input sequence tokens.
+ encoder_hidden_states: Tensor of encoder output embeddings.
Returns:
Decoder output torch.Tensor
"""
@@ -732,16 +724,13 @@ class BartModel(nn.Module, SupportsQuant):
encoder_positions: torch.Tensor) -> torch.Tensor:
r"""
Args:
- input_ids
- Indices of *decoder* input sequence tokens in the vocabulary.
- Padding will be ignored by default should you
- provide it.
- positions
- Positions of *decoder* input sequence tokens.
- encoder_input_ids
- Indices of *encoder* input sequence tokens in the vocabulary.
- encoder_positions:
- Positions of *encoder* input sequence tokens.
+ input_ids: Indices of *decoder* input sequence tokens
+ in the vocabulary.
+ Padding will be ignored by default should you provide it.
+ positions: Positions of *decoder* input sequence tokens.
+ encoder_input_ids: Indices of *encoder* input sequence tokens
+ in the vocabulary.
+ encoder_positions: Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""
@@ -848,14 +837,10 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
) -> torch.Tensor:
r"""
Args:
- input_ids
- torch.Tensor of *decoder* input token ids.
- positions
- torch.Tensor of *decoder* position indices.
- encoder_input_ids
- torch.Tensor of *encoder* input token ids.
- encoder_positions
- torch.Tensor of *encoder* position indices
+ input_ids: torch.Tensor of *decoder* input token ids.
+ positions: torch.Tensor of *decoder* position indices.
+ encoder_input_ids: torch.Tensor of *encoder* input token ids.
+ encoder_positions: torch.Tensor of *encoder* position indices.
Returns:
Output torch.Tensor
"""
@@ -912,8 +897,7 @@ class MBartEncoderLayer(BartEncoderLayer):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Args:
- hidden_states
- torch.Tensor of *encoder* input embeddings.
+ hidden_states: torch.Tensor of *encoder* input embeddings.
Returns:
Encoder layer output torch.Tensor
"""
@@ -1035,12 +1019,10 @@ class MBartEncoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
- input_ids
- Indices of *encoder* input sequence tokens in the vocabulary.
- Padding will be ignored by default should you
- provide it.
- positions
- Positions of *encoder* input sequence tokens.
+ input_ids: Indices of *encoder* input sequence tokens in the
+ vocabulary.
+ Padding will be ignored by default should you provide it.
+ positions: Positions of *encoder* input sequence tokens.
Returns:
Decoder output torch.Tensor
"""
@@ -1116,14 +1098,11 @@ class MBartDecoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
- decoder_input_ids
- Indices of *decoder* input sequence tokens in the vocabulary.
- Padding will be ignored by default should you
- provide it.
- decoder_positions
- Positions of *decoder* input sequence tokens.
- encoder_hidden_states:
- Tensor of encoder output embeddings
+ decoder_input_ids: Indices of *decoder* input sequence tokens
+ in the vocabulary.
+ Padding will be ignored by default should you provide it.
+ decoder_positions: Positions of *decoder* input sequence tokens.
+ encoder_hidden_states: Tensor of encoder output embeddings.
Returns:
Decoder output torch.Tensor
"""
@@ -1185,16 +1164,13 @@ class MBartModel(nn.Module, SupportsQuant):
encoder_positions: torch.Tensor) -> torch.Tensor:
r"""
Args:
- input_ids
- Indices of *decoder* input sequence tokens in the vocabulary.
- Padding will be ignored by default should you
- provide it.
- positions
- Positions of *decoder* input sequence tokens.
- encoder_input_ids
- Indices of *encoder* input sequence tokens in the vocabulary.
- encoder_positions:
- Positions of *encoder* input sequence tokens.
+ input_ids: Indices of *decoder* input sequence tokens
+ in the vocabulary.
+ Padding will be ignored by default should you provide it.
+ positions: Positions of *decoder* input sequence tokens.
+ encoder_input_ids: Indices of *encoder* input sequence tokens
+ in the vocabulary.
+ encoder_positions: Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""

View File

@@ -678,7 +678,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
- pixel_values: The pixels in each input image.
Info:
[Blip2ImageInputs][]

View File

@@ -79,10 +79,8 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
) -> torch.Tensor:
r"""
Args:
- input_ids
- torch.Tensor of *decoder* input token ids.
- positions
- torch.Tensor of *decoder* position indices.
+ input_ids: torch.Tensor of *decoder* input token ids.
+ positions: torch.Tensor of *decoder* position indices.
Returns:
Output torch.Tensor
"""
@@ -351,14 +349,10 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
r"""
Args:
- input_ids
- torch.Tensor of *decoder* input token ids.
- positions
- torch.Tensor of *decoder* position indices.
- encoder_input_ids
- torch.Tensor of *encoder* input token ids.
- encoder_positions
- torch.Tensor of *encoder* position indices
+ input_ids: torch.Tensor of *decoder* input token ids.
+ positions: torch.Tensor of *decoder* position indices.
+ encoder_input_ids: torch.Tensor of *encoder* input token ids.
+ encoder_positions: torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""

View File

@@ -631,16 +631,14 @@ class Florence2LanguageModel(nn.Module):
) -> torch.Tensor:
r"""
Args:
- input_ids
- Indices of *decoder* input sequence tokens in the vocabulary.
+ input_ids: Indices of *decoder* input sequence tokens
+ in the vocabulary.
Padding will be ignored by default should you
provide it.
- positions
- Positions of *decoder* input sequence tokens.
- encoder_input_ids
- Indices of *encoder* input sequence tokens in the vocabulary.
- encoder_positions:
- Positions of *encoder* input sequence tokens.
+ positions: Positions of *decoder* input sequence tokens.
+ encoder_input_ids: Indices of *encoder* input sequence tokens
+ in the vocabulary.
+ encoder_positions: Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""
@@ -699,14 +697,10 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
) -> torch.Tensor:
r"""
Args:
- input_ids
- torch.Tensor of *decoder* input token ids.
- positions
- torch.Tensor of *decoder* position indices.
- encoder_input_ids
- torch.Tensor of *encoder* input token ids.
- encoder_positions
- torch.Tensor of *encoder* position indices
+ input_ids: torch.Tensor of *decoder* input token ids.
+ positions: torch.Tensor of *decoder* position indices.
+ encoder_input_ids: torch.Tensor of *encoder* input token ids.
+ encoder_positions: torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""
@@ -1068,14 +1062,10 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
r"""
Args:
- input_ids
- torch.Tensor of *decoder* input token ids.
- positions
- torch.Tensor of *decoder* position indices.
- encoder_input_ids
- torch.Tensor of *encoder* input token ids.
- encoder_positions
- torch.Tensor of *encoder* position indices
+ input_ids: torch.Tensor of *decoder* input token ids.
+ positions: torch.Tensor of *decoder* position indices.
+ encoder_input_ids: torch.Tensor of *encoder* input token ids.
+ encoder_positions: torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""

View File

@@ -1599,17 +1599,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
**NOTE**: If mrope is enabled (default setting for GLM-4V
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
- pixel_values: Pixel values to be fed to a model.
- `None` if no images are passed.
- image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
- `None` if no images are passed.
- pixel_values_videos: Pixel values of videos to be fed to a model.
- `None` if no videos are passed.
- video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
- `None` if no videos are passed.
- second_per_grid_ts: Tensor `(num_videos)` of video time interval (
- in seconds) for each grid along the temporal dimension in the
- 3D position IDs. `None` if no videos are passed.
+ intermediate_tensors: Optional intermediate tensors for pipeline
+ parallelism.
+ inputs_embeds: Optional pre-computed input embeddings.
+ **kwargs: Additional keyword arguments.
"""
if intermediate_tensors is not None:
inputs_embeds = None
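A sketch of the mrope position layout the NOTE above describes, with a made-up sequence length; rows 0/1/2 are assumed to carry the temporal/height/width index streams:

    import torch

    seq_len = 32  # made-up length
    mrope_positions = torch.zeros(3, seq_len, dtype=torch.long)  # mrope on
    text_positions = torch.arange(seq_len)                       # mrope off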