[Docs] Fix warnings in mkdocs build (continued) (#24740)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Authored by Hyogeun Oh (오효근) on 2025-09-12 22:43:15 +09:00; committed by GitHub
parent bcb06d7baf
commit 41f17bf290
10 changed files with 121 additions and 176 deletions

View File

@@ -144,9 +144,9 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""Quantize a Tensor with torchao quantization specified by torchao_config
Args:
`param`: weight parameter of the linear module
`torchao_config`: type of quantization and their arguments we want to
use to quantize the Tensor
param: weight parameter of the linear module
torchao_config: type of quantization and their arguments we want to
use to quantize the Tensor
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
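
As context for the function documented above, here is a minimal hedged sketch of quantizing a bare weight tensor through torchao's module-level quantize_ API, presumably by wrapping the parameter in a throwaway nn.Linear so quantize_ can rewrite its weight in place. Int8WeightOnlyConfig is an assumed example of an AOBaseConfig; available config classes vary by torchao version.

import torch
from torch import nn
from torchao.quantization import quantize_
from torchao.quantization import Int8WeightOnlyConfig  # assumed example config

# Wrap the bare parameter in a dummy Linear so the module-level API applies.
param = torch.randn(256, 512, dtype=torch.bfloat16)
dummy = nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy.weight = nn.Parameter(param, requires_grad=False)

quantize_(dummy, Int8WeightOnlyConfig())  # quantizes dummy.weight in place
quantized_param = dummy.weight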
@@ -172,8 +172,8 @@ class TorchAOLinearMethod(LinearMethodBase):
"""Linear method for torchao.
Args:
torchao_config: The torchao quantization config, a string
that encodes the type of quantization and all relevant arguments.
quant_config: The torchao quantization config, a string that encodes
the type of quantization and all relevant arguments.
"""
def __init__(self, quant_config: TorchAOConfig):

View File

@@ -423,7 +423,7 @@ def w8a8_block_int8_matmul(
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be
2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
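
As a reading aid for the signature above, a pure-PyTorch reference of the assumed block-wise semantics (not the actual Triton kernel): A is [M, K] int8 with per-token, per-K-block scales As of shape [M, K // block_k]; B is [N, K] int8 with per-[block_n, block_k]-tile scales Bs of shape [N // block_n, K // block_k]. Names and layout are assumptions for illustration.

import torch

def w8a8_block_int8_ref(A, B, As, Bs, block_size, output_dtype=torch.float16):
    # Reference only: accumulate per K block in fp32, applying both scales.
    block_n, block_k = block_size
    M, K = A.shape
    N = B.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    for kb in range(K // block_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        partial = A[:, ks].float() @ B[:, ks].float().t()   # [M, N]
        a_scale = As[:, kb].unsqueeze(1)                    # [M, 1] per-token scale
        b_scale = Bs[:, kb].repeat_interleave(block_n)[:N]  # [N] per-tile scale
        out += partial * a_scale * b_scale.unsqueeze(0)
    return out.to(output_dtype)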

View File

@@ -135,8 +135,8 @@ def triton_mrope(
"""Qwen2VL mrope kernel.
Args:
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
q: [num_tokens, num_heads * head_size]
k: [num_tokens, num_kv_heads * head_size]
cos: [3, num_tokens, head_size //2 ]
(T/H/W positions with multimodal inputs)
sin: [3, num_tokens, head_size //2 ]
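
To make the [3, num_tokens, head_size // 2] layout above concrete, a small illustrative sketch (not the Triton kernel) of how the three T/H/W rows of cos might be merged per an assumed mrope_section split before a standard RoPE apply; the section sizes are hypothetical.

import torch

mrope_section = [16, 24, 24]            # hypothetical T/H/W split of head_size // 2
num_tokens = 8
half = sum(mrope_section)               # head_size // 2
cos = torch.randn(3, num_tokens, half)  # [3, num_tokens, head_size // 2]
chunks = torch.split(cos, mrope_section, dim=-1)
# Row 0 supplies the T slice, row 1 the H slice, row 2 the W slice.
merged_cos = torch.cat([chunks[i][i] for i in range(3)], dim=-1)  # [num_tokens, half]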

View File

@@ -171,51 +171,52 @@ class TensorizerConfig(MutableMapping):
_is_sharded: bool = field(init=False, default=False)
_fields: ClassVar[tuple[str, ...]]
_keys: ClassVar[frozenset[str]]
"""
Args for the TensorizerConfig class. These are used to configure the
behavior of model serialization and deserialization using Tensorizer.
"""Configuration class for Tensorizer settings.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
`lora_dir` is passed to this object's initializer, this is a required
argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
configs and tokenizer files. Can be passed instead of `tensorizer_uri`
where the `model.tensors` file will be assumed to be in this
directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified against
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
lora_dir: Path to a directory containing LoRA adapter artifacts for
serialization or deserialization. When serializing LoRA adapters
this is the only necessary parameter to pass to this object's
initializer.
"""
These settings configure the behavior of model serialization and
deserialization using Tensorizer.
Attributes:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
`lora_dir` is passed to this object's initializer, this is
a required argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
configs and tokenizer files. Can be passed instead of
`tensorizer_uri` where the `model.tensors` file will be assumed
to be in this directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified
against the hashes stored in the metadata. A `HashMismatchError`
will be raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
lora_dir: Path to a directory containing LoRA adapter artifacts for
serialization or deserialization. When serializing LoRA adapters
this is the only necessary parameter to pass to this object's
initializer.
"""
def __post_init__(self):
# check if the configuration is for a sharded vLLM model
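
A hedged usage sketch for the fields documented above, roughly following the pattern referenced in examples/others/tensorize_vllm_model.py; the bucket URI and model name are placeholders, and the exact loader plumbing may differ across vLLM versions.

from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/opt-125m/model.tensors",  # local path or S3 URI
    num_readers=None,  # None dynamically sizes reader threads to resources
)
llm = LLM(
    model="facebook/opt-125m",
    load_format="tensorizer",
    model_loader_extra_config=config,
)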

View File

@@ -143,16 +143,8 @@ class AriaProjector(nn.Module):
projects ViT's outputs into MoE's inputs.
Args:
patch_to_query_dict (dict): Maps patch numbers to their corresponding
query numbers,
e.g., {1225: 128, 4900: 256}. This allows for different query sizes
based on image resolution.
embed_dim (int): Embedding dimension.
num_heads (int): Number of attention heads.
kv_dim (int): Dimension of key and value.
ff_dim (int): Hidden dimension of the feed-forward network.
output_dim (int): Output dimension.
norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
containing projector configuration parameters.
Outputs:
A tensor with the shape of (batch_size, query_number, output_dim)
@@ -282,8 +274,8 @@ class AriaTextMoELayer(nn.Module):
Forward pass of the MoE Layer.
Args:
hidden_states (torch.Tensor): Input tensor of shape (batch_size,
sequence_length, hidden_size).
hidden_states: Input tensor of shape
(batch_size, sequence_length, hidden_size).
Returns:
torch.Tensor: Output tensor after passing through the MoE layer.
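
A shape-only sketch of the MoE forward contract documented above (values illustrative; moe_layer stands in for an initialized AriaTextMoELayer):

import torch

batch_size, sequence_length, hidden_size = 2, 16, 1024
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
# out = moe_layer(hidden_states)  # expected: same (batch, seq, hidden) shape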

View File

@@ -401,8 +401,7 @@ class BartEncoderLayer(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
hidden_states: torch.Tensor of *encoder* input embeddings.
Returns:
Encoder layer output torch.Tensor
"""
@@ -490,10 +489,8 @@ class BartDecoderLayer(nn.Module):
) -> torch.Tensor:
r"""
Args:
decoder_hidden_states
torch.Tensor of *decoder* input embeddings.
encoder_hidden_states
torch.Tensor of *encoder* input embeddings.
decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
encoder_hidden_states: torch.Tensor of *encoder* input embeddings.
Returns:
Decoder layer output torch.Tensor
"""
@@ -584,12 +581,10 @@ class BartEncoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
input_ids: Indices of *encoder* input sequence tokens in the
vocabulary.
Padding will be ignored by default should you provide it.
positions: Positions of *encoder* input sequence tokens.
Returns:
Decoder output torch.Tensor
"""
@@ -663,14 +658,11 @@ class BartDecoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
decoder_input_ids: Indices of *decoder* input sequence tokens
in the vocabulary.
Padding will be ignored by default should you provide it.
decoder_positions: Positions of *decoder* input sequence tokens.
encoder_hidden_states: Tensor of encoder output embeddings.
Returns:
Decoder output torch.Tensor
"""
@@ -732,16 +724,13 @@ class BartModel(nn.Module, SupportsQuant):
encoder_positions: torch.Tensor) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
input_ids: Indices of *decoder* input sequence tokens
in the vocabulary.
Padding will be ignored by default should you provide it.
positions: Positions of *decoder* input sequence tokens.
encoder_input_ids: Indices of *encoder* input sequence tokens
in the vocabulary.
encoder_positions: Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""
@@ -848,14 +837,10 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
input_ids: torch.Tensor of *decoder* input token ids.
positions: torch.Tensor of *decoder* position indices.
encoder_input_ids: torch.Tensor of *encoder* input token ids.
encoder_positions: torch.Tensor of *encoder* position indices.
Returns:
Output torch.Tensor
"""
@@ -912,8 +897,7 @@ class MBartEncoderLayer(BartEncoderLayer):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
hidden_states: torch.Tensor of *encoder* input embeddings.
Returns:
Encoder layer output torch.Tensor
"""
@@ -1035,12 +1019,10 @@ class MBartEncoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
input_ids: Indices of *encoder* input sequence tokens in the
vocabulary.
Padding will be ignored by default should you provide it.
positions: Positions of *encoder* input sequence tokens.
Returns:
Decoder output torch.Tensor
"""
@@ -1116,14 +1098,11 @@ class MBartDecoder(nn.Module):
) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
decoder_input_ids: Indices of *decoder* input sequence tokens
in the vocabulary.
Padding will be ignored by default should you provide it.
decoder_positions: Positions of *decoder* input sequence tokens.
encoder_hidden_states: Tensor of encoder output embeddings.
Returns:
Decoder output torch.Tensor
"""
@@ -1185,16 +1164,13 @@ class MBartModel(nn.Module, SupportsQuant):
encoder_positions: torch.Tensor) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
input_ids: Indices of *decoder* input sequence tokens
in the vocabulary.
Padding will be ignored by default should you provide it.
positions: Positions of *decoder* input sequence tokens.
encoder_input_ids: Indices of *encoder* input sequence tokens
in the vocabulary.
encoder_positions: Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""

View File

@@ -678,7 +678,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
Info:
[Blip2ImageInputs][]

View File

@@ -79,10 +79,8 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
input_ids: torch.Tensor of *decoder* input token ids.
positions: torch.Tensor of *decoder* position indices.
Returns:
Output torch.Tensor
"""
@@ -351,14 +349,10 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
input_ids: torch.Tensor of *decoder* input token ids.
positions: torch.Tensor of *decoder* position indices.
encoder_input_ids: torch.Tensor of *encoder* input token ids.
encoder_positions: torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""

View File

@@ -631,16 +631,14 @@ class Florence2LanguageModel(nn.Module):
) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
input_ids: Indices of *decoder* input sequence tokens
in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
positions: Positions of *decoder* input sequence tokens.
encoder_input_ids: Indices of *encoder* input sequence tokens
in the vocabulary.
encoder_positions: Positions of *encoder* input sequence tokens.
Returns:
Model output torch.Tensor
"""
@@ -699,14 +697,10 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
input_ids: torch.Tensor of *decoder* input token ids.
positions: torch.Tensor of *decoder* position indices.
encoder_input_ids: torch.Tensor of *encoder* input token ids.
encoder_positions: torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""
@@ -1068,14 +1062,10 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
input_ids: torch.Tensor of *decoder* input token ids.
positions: torch.Tensor of *decoder* position indices.
encoder_input_ids: torch.Tensor of *encoder* input token ids.
encoder_positions: torch.Tensor of *encoder* position indices
Returns:
Output torch.Tensor
"""

View File

@@ -1599,17 +1599,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
**NOTE**: If mrope is enabled (default setting for GLM-4V
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
second_per_grid_ts: Tensor `(num_videos)` of video time interval (
in seconds) for each grid along the temporal dimension in the
3D position IDs. `None` if no videos are passed.
intermediate_tensors: Optional intermediate tensors for pipeline
parallelism.
inputs_embeds: Optional pre-computed input embeddings.
**kwargs: Additional keyword arguments.
"""
if intermediate_tensors is not None:
inputs_embeds = None
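
To illustrate the mrope note above, a sketch of the (3, seq_len) positions layout with separate temporal/height/width index rows (values illustrative):

import torch

seq_len = 6
positions = torch.stack([
    torch.arange(seq_len),                   # temporal index per token
    torch.zeros(seq_len, dtype=torch.long),  # height index (nonzero for image tokens)
    torch.zeros(seq_len, dtype=torch.long),  # width index
])
assert positions.shape == (3, seq_len)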