[Docs] Fix warnings in mkdocs build (continued) (#24740)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Hyogeun Oh (오효근) 2025-09-12 22:43:15 +09:00 committed by GitHub
parent bcb06d7baf
commit 41f17bf290
10 changed files with 121 additions and 176 deletions


@@ -144,9 +144,9 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """Quantize a Tensor with torchao quantization specified by torchao_config

     Args:
-        `param`: weight parameter of the linear module
-        `torchao_config`: type of quantization and their arguments we want to
+        param: weight parameter of the linear module
+        torchao_config: type of quantization and their arguments we want to
             use to quantize the Tensor
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_

@@ -172,8 +172,8 @@ class TorchAOLinearMethod(LinearMethodBase):
     """Linear method for torchao.

     Args:
-        torchao_config: The torchao quantization config, a string
-            that encodes the type of quantization and all relevant arguments.
+        quant_config: The torchao quantization config, a string that encodes
+            the type of quantization and all relevant arguments.
     """

     def __init__(self, quant_config: TorchAOConfig):
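To make the fixed docstring concrete, here is a hedged sketch of quantizing a bare weight tensor with a torchao config, assuming a recent torchao release. The wrapper helper and the `Int8WeightOnlyConfig` choice are illustrative assumptions, not vLLM's exact implementation:

```python
# Illustrative sketch: torchao's quantize_ operates on modules, so wrap the
# bare weight in a throwaway nn.Linear, quantize it in place with an
# AOBaseConfig instance, and take the quantized parameter back out.
import torch
from torch import nn
from torchao.quantization import Int8WeightOnlyConfig, quantize_

def quantize_weight(param: torch.Tensor) -> torch.Tensor:
    out_features, in_features = param.shape
    dummy = nn.Linear(in_features, out_features, bias=False)
    dummy.weight = nn.Parameter(param)
    quantize_(dummy, Int8WeightOnlyConfig())  # mutates dummy.weight in place
    return dummy.weight
```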


@@ -423,7 +423,7 @@ def w8a8_block_int8_matmul(
         Bs: The per-block quantization scale for `B`.
         block_size: The block size for per-block quantization. It should be
             2-dim, e.g., [128, 128].
-        output_dytpe: The dtype of the returned tensor.
+        output_dtype: The dtype of the returned tensor.

     Returns:
         torch.Tensor: The result of matmul.
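A hedged calling-convention example for the corrected docstring. The shapes follow the documented `[128, 128]` block layout, but the scale-tensor layouts and the import path are assumptions based on the Args text, not verified against the source:

```python
# Hypothetical call to w8a8_block_int8_matmul (scale layouts assumed):
# A is int8 with per-token-group scales As; B is int8 with per-block scales Bs.
import torch
# Import path is an assumption about where this kernel lives in vLLM.
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    w8a8_block_int8_matmul)

M, K, N, blk = 16, 512, 1024, 128
A = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")
As = torch.rand(M, K // blk, device="cuda")         # one scale per K-block of A
Bs = torch.rand(N // blk, K // blk, device="cuda")  # one scale per [128, 128] block of B
out = w8a8_block_int8_matmul(A, B, As, Bs, block_size=[blk, blk],
                             output_dtype=torch.float16)  # expected shape (M, N)
```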


@@ -135,8 +135,8 @@ def triton_mrope(
     """Qwen2VL mrope kernel.

     Args:
-        query: [num_tokens, num_heads * head_size]
-        key: [num_tokens, num_kv_heads * head_size]
+        q: [num_tokens, num_heads * head_size]
+        k: [num_tokens, num_kv_heads * head_size]
         cos: [3, num_tokens, head_size //2 ]
             (T/H/W positions with multimodal inputs)
         sin: [3, num_tokens, head_size //2 ]
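The renamed `q`/`k` arguments and the triple cos/sin tables are easiest to see as shapes. The tensors below are illustrative values only, matching the documented layouts:

```python
# Illustrative argument shapes for the mrope kernel documented above.
import torch

num_tokens, num_heads, num_kv_heads, head_size = 8, 4, 2, 64
q = torch.randn(num_tokens, num_heads * head_size)
k = torch.randn(num_tokens, num_kv_heads * head_size)
# One rotary table per axis: temporal (T), height (H), and width (W).
cos = torch.randn(3, num_tokens, head_size // 2)
sin = torch.randn(3, num_tokens, head_size // 2)
```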


@@ -171,51 +171,52 @@ class TensorizerConfig(MutableMapping):
     _is_sharded: bool = field(init=False, default=False)
     _fields: ClassVar[tuple[str, ...]]
     _keys: ClassVar[frozenset[str]]
-    """
-    Args for the TensorizerConfig class. These are used to configure the
-    behavior of model serialization and deserialization using Tensorizer.
-
-    Args:
-        tensorizer_uri: Path to serialized model tensors. Can be a local file
-            path or a S3 URI. This is a required field unless lora_dir is
-            provided and the config is meant to be used for the
-            `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
-            `lora_dir` is passed to this object's initializer, this is a required
-            argument.
-        tensorizer_dir: Path to a directory containing serialized model tensors,
-            and all other potential model artifacts to load the model, such as
-            configs and tokenizer files. Can be passed instead of `tensorizer_uri`
-            where the `model.tensors` file will be assumed to be in this
-            directory.
-        vllm_tensorized: If True, indicates that the serialized model is a
-            vLLM model. This is used to determine the behavior of the
-            TensorDeserializer when loading tensors from a serialized model.
-            It is far faster to deserialize a vLLM model as it utilizes
-            tensorizer's optimized GPU loading. Note that this is now
-            deprecated, as serialized vLLM models are now automatically
-            inferred as vLLM models.
-        verify_hash: If True, the hashes of each tensor will be verified against
-            the hashes stored in the metadata. A `HashMismatchError` will be
-            raised if any of the hashes do not match.
-        num_readers: Controls how many threads are allowed to read concurrently
-            from the source file. Default is `None`, which will dynamically set
-            the number of readers based on the number of available
-            resources and model size. This greatly increases performance.
-        encryption_keyfile: File path to a binary file containing a
-            binary key to use for decryption. `None` (the default) means
-            no decryption. See the example script in
-            examples/others/tensorize_vllm_model.py.
-        s3_access_key_id: The access key for the S3 bucket. Can also be set via
-            the S3_ACCESS_KEY_ID environment variable.
-        s3_secret_access_key: The secret access key for the S3 bucket. Can also
-            be set via the S3_SECRET_ACCESS_KEY environment variable.
-        s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
-            S3_ENDPOINT_URL environment variable.
-        lora_dir: Path to a directory containing LoRA adapter artifacts for
-            serialization or deserialization. When serializing LoRA adapters
-            this is the only necessary parameter to pass to this object's
-            initializer.
-    """
+    """Configuration class for Tensorizer settings.
+
+    These settings configure the behavior of model serialization and
+    deserialization using Tensorizer.
+
+    Attributes:
+        tensorizer_uri: Path to serialized model tensors. Can be a local file
+            path or a S3 URI. This is a required field unless lora_dir is
+            provided and the config is meant to be used for the
+            `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
+            `lora_dir` is passed to this object's initializer, this is
+            a required argument.
+        tensorizer_dir: Path to a directory containing serialized model tensors,
+            and all other potential model artifacts to load the model, such as
+            configs and tokenizer files. Can be passed instead of
+            `tensorizer_uri` where the `model.tensors` file will be assumed
+            to be in this directory.
+        vllm_tensorized: If True, indicates that the serialized model is a
+            vLLM model. This is used to determine the behavior of the
+            TensorDeserializer when loading tensors from a serialized model.
+            It is far faster to deserialize a vLLM model as it utilizes
+            tensorizer's optimized GPU loading. Note that this is now
+            deprecated, as serialized vLLM models are now automatically
+            inferred as vLLM models.
+        verify_hash: If True, the hashes of each tensor will be verified
+            against the hashes stored in the metadata. A `HashMismatchError`
+            will be raised if any of the hashes do not match.
+        num_readers: Controls how many threads are allowed to read concurrently
+            from the source file. Default is `None`, which will dynamically set
+            the number of readers based on the number of available
+            resources and model size. This greatly increases performance.
+        encryption_keyfile: File path to a binary file containing a
+            binary key to use for decryption. `None` (the default) means
+            no decryption. See the example script in
+            examples/others/tensorize_vllm_model.py.
+        s3_access_key_id: The access key for the S3 bucket. Can also be set via
+            the S3_ACCESS_KEY_ID environment variable.
+        s3_secret_access_key: The secret access key for the S3 bucket. Can also
+            be set via the S3_SECRET_ACCESS_KEY environment variable.
+        s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
+            S3_ENDPOINT_URL environment variable.
+        lora_dir: Path to a directory containing LoRA adapter artifacts for
+            serialization or deserialization. When serializing LoRA adapters
+            this is the only necessary parameter to pass to this object's
+            initializer.
+    """

     def __post_init__(self):
         # check if the configuration is for a sharded vLLM model
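For context on the fields documented above, here is a hedged usage sketch of loading a tensorized model through vLLM; the S3 URI is hypothetical, and only `tensorizer_uri` and `num_readers` from the attribute list are exercised:

```python
# Minimal sketch, assuming a Tensorizer-serialized model is available at the
# (hypothetical) S3 URI below.
from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/opt-125m/model.tensors",  # hypothetical path
    num_readers=None,  # None = auto-tune concurrent reader threads
)
llm = LLM(
    model="facebook/opt-125m",
    load_format="tensorizer",
    model_loader_extra_config=config,
)
```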


@@ -143,16 +143,8 @@ class AriaProjector(nn.Module):
     projects ViT's outputs into MoE's inputs.

     Args:
-        patch_to_query_dict (dict): Maps patch numbers to their corresponding
-            query numbers,
-            e.g., {1225: 128, 4900: 256}. This allows for different query sizes
-            based on image resolution.
-        embed_dim (int): Embedding dimension.
-        num_heads (int): Number of attention heads.
-        kv_dim (int): Dimension of key and value.
-        ff_dim (int): Hidden dimension of the feed-forward network.
-        output_dim (int): Output dimension.
-        norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
+        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
+            containing projector configuration parameters.

     Outputs:
         A tensor with the shape of (batch_size, query_number, output_dim)

@@ -282,8 +274,8 @@ class AriaTextMoELayer(nn.Module):
     Forward pass of the MoE Layer.

     Args:
-        hidden_states (torch.Tensor): Input tensor of shape (batch_size,
-            sequence_length, hidden_size).
+        hidden_states: Input tensor of shape
+            (batch_size, sequence_length, hidden_size).

     Returns:
         torch.Tensor: Output tensor after passing through the MoE layer.
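The MoE layer's shape contract from the docstring above, as a runnable check; `moe_layer` stands in for an already-constructed `AriaTextMoELayer` instance and is hypothetical here:

```python
# Shape contract of AriaTextMoELayer.forward (values illustrative):
# (batch_size, sequence_length, hidden_size) in -> same shape out.
import torch

batch_size, seq_len, hidden_size = 2, 16, 1024
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
output = moe_layer(hidden_states)  # moe_layer: hypothetical AriaTextMoELayer
assert output.shape == (batch_size, seq_len, hidden_size)
```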


@@ -401,8 +401,7 @@ class BartEncoderLayer(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
-            hidden_states
-                torch.Tensor of *encoder* input embeddings.
+            hidden_states: torch.Tensor of *encoder* input embeddings.
         Returns:
             Encoder layer output torch.Tensor
         """

@@ -490,10 +489,8 @@
     ) -> torch.Tensor:
         r"""
         Args:
-            decoder_hidden_states
-                torch.Tensor of *decoder* input embeddings.
-            encoder_hidden_states
-                torch.Tensor of *encoder* input embeddings.
+            decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
+            encoder_hidden_states: torch.Tensor of *encoder* input embeddings.
         Returns:
             Decoder layer output torch.Tensor
         """

@@ -584,12 +581,10 @@
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                Indices of *encoder* input sequence tokens in the vocabulary.
-                Padding will be ignored by default should you
-                provide it.
-            positions
-                Positions of *encoder* input sequence tokens.
+            input_ids: Indices of *encoder* input sequence tokens in the
+                vocabulary.
+                Padding will be ignored by default should you provide it.
+            positions: Positions of *encoder* input sequence tokens.
         Returns:
             Decoder output torch.Tensor
         """

@@ -663,14 +658,11 @@
     ) -> torch.Tensor:
         r"""
         Args:
-            decoder_input_ids
-                Indices of *decoder* input sequence tokens in the vocabulary.
-                Padding will be ignored by default should you
-                provide it.
-            decoder_positions
-                Positions of *decoder* input sequence tokens.
-            encoder_hidden_states:
-                Tensor of encoder output embeddings
+            decoder_input_ids: Indices of *decoder* input sequence tokens
+                in the vocabulary.
+                Padding will be ignored by default should you provide it.
+            decoder_positions: Positions of *decoder* input sequence tokens.
+            encoder_hidden_states: Tensor of encoder output embeddings.
         Returns:
             Decoder output torch.Tensor
         """

@@ -732,16 +724,13 @@ class BartModel(nn.Module, SupportsQuant):
                 encoder_positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                Indices of *decoder* input sequence tokens in the vocabulary.
-                Padding will be ignored by default should you
-                provide it.
-            positions
-                Positions of *decoder* input sequence tokens.
-            encoder_input_ids
-                Indices of *encoder* input sequence tokens in the vocabulary.
-            encoder_positions:
-                Positions of *encoder* input sequence tokens.
+            input_ids: Indices of *decoder* input sequence tokens
+                in the vocabulary.
+                Padding will be ignored by default should you provide it.
+            positions: Positions of *decoder* input sequence tokens.
+            encoder_input_ids: Indices of *encoder* input sequence tokens
+                in the vocabulary.
+            encoder_positions: Positions of *encoder* input sequence tokens.
         Returns:
             Model output torch.Tensor
         """

@@ -848,14 +837,10 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                torch.Tensor of *decoder* input token ids.
-            positions
-                torch.Tensor of *decoder* position indices.
-            encoder_input_ids
-                torch.Tensor of *encoder* input token ids.
-            encoder_positions
-                torch.Tensor of *encoder* position indices
+            input_ids: torch.Tensor of *decoder* input token ids.
+            positions: torch.Tensor of *decoder* position indices.
+            encoder_input_ids: torch.Tensor of *encoder* input token ids.
+            encoder_positions: torch.Tensor of *encoder* position indices.
         Returns:
             Output torch.Tensor
         """

@@ -912,8 +897,7 @@ class MBartEncoderLayer(BartEncoderLayer):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
-            hidden_states
-                torch.Tensor of *encoder* input embeddings.
+            hidden_states: torch.Tensor of *encoder* input embeddings.
         Returns:
             Encoder layer output torch.Tensor
         """

@@ -1035,12 +1019,10 @@ class MBartEncoder(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                Indices of *encoder* input sequence tokens in the vocabulary.
-                Padding will be ignored by default should you
-                provide it.
-            positions
-                Positions of *encoder* input sequence tokens.
+            input_ids: Indices of *encoder* input sequence tokens in the
+                vocabulary.
+                Padding will be ignored by default should you provide it.
+            positions: Positions of *encoder* input sequence tokens.
         Returns:
             Decoder output torch.Tensor
         """

@@ -1116,14 +1098,11 @@ class MBartDecoder(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            decoder_input_ids
-                Indices of *decoder* input sequence tokens in the vocabulary.
-                Padding will be ignored by default should you
-                provide it.
-            decoder_positions
-                Positions of *decoder* input sequence tokens.
-            encoder_hidden_states:
-                Tensor of encoder output embeddings
+            decoder_input_ids: Indices of *decoder* input sequence tokens
+                in the vocabulary.
+                Padding will be ignored by default should you provide it.
+            decoder_positions: Positions of *decoder* input sequence tokens.
+            encoder_hidden_states: Tensor of encoder output embeddings.
         Returns:
             Decoder output torch.Tensor
         """

@@ -1185,16 +1164,13 @@ class MBartModel(nn.Module, SupportsQuant):
                 encoder_positions: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                Indices of *decoder* input sequence tokens in the vocabulary.
-                Padding will be ignored by default should you
-                provide it.
-            positions
-                Positions of *decoder* input sequence tokens.
-            encoder_input_ids
-                Indices of *encoder* input sequence tokens in the vocabulary.
-            encoder_positions:
-                Positions of *encoder* input sequence tokens.
+            input_ids: Indices of *decoder* input sequence tokens
+                in the vocabulary.
+                Padding will be ignored by default should you provide it.
+            positions: Positions of *decoder* input sequence tokens.
+            encoder_input_ids: Indices of *encoder* input sequence tokens
+                in the vocabulary.
+            encoder_positions: Positions of *encoder* input sequence tokens.
         Returns:
             Model output torch.Tensor
         """


@@ -678,7 +678,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         Args:
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
-            pixel_values: The pixels in each input image.

         Info:
             [Blip2ImageInputs][]


@@ -79,10 +79,8 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                torch.Tensor of *decoder* input token ids.
-            positions
-                torch.Tensor of *decoder* position indices.
+            input_ids: torch.Tensor of *decoder* input token ids.
+            positions: torch.Tensor of *decoder* position indices.
         Returns:
             Output torch.Tensor
         """

@@ -351,14 +349,10 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                torch.Tensor of *decoder* input token ids.
-            positions
-                torch.Tensor of *decoder* position indices.
-            encoder_input_ids
-                torch.Tensor of *encoder* input token ids.
-            encoder_positions
-                torch.Tensor of *encoder* position indices
+            input_ids: torch.Tensor of *decoder* input token ids.
+            positions: torch.Tensor of *decoder* position indices.
+            encoder_input_ids: torch.Tensor of *encoder* input token ids.
+            encoder_positions: torch.Tensor of *encoder* position indices
         Returns:
             Output torch.Tensor
         """


@@ -631,16 +631,14 @@ class Florence2LanguageModel(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                Indices of *decoder* input sequence tokens in the vocabulary.
+            input_ids: Indices of *decoder* input sequence tokens
+                in the vocabulary.
                 Padding will be ignored by default should you
                 provide it.
-            positions
-                Positions of *decoder* input sequence tokens.
-            encoder_input_ids
-                Indices of *encoder* input sequence tokens in the vocabulary.
-            encoder_positions:
-                Positions of *encoder* input sequence tokens.
+            positions: Positions of *decoder* input sequence tokens.
+            encoder_input_ids: Indices of *encoder* input sequence tokens
+                in the vocabulary.
+            encoder_positions: Positions of *encoder* input sequence tokens.
         Returns:
             Model output torch.Tensor
         """

@@ -699,14 +697,10 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                torch.Tensor of *decoder* input token ids.
-            positions
-                torch.Tensor of *decoder* position indices.
-            encoder_input_ids
-                torch.Tensor of *encoder* input token ids.
-            encoder_positions
-                torch.Tensor of *encoder* position indices
+            input_ids: torch.Tensor of *decoder* input token ids.
+            positions: torch.Tensor of *decoder* position indices.
+            encoder_input_ids: torch.Tensor of *encoder* input token ids.
+            encoder_positions: torch.Tensor of *encoder* position indices
         Returns:
             Output torch.Tensor
         """

@@ -1068,14 +1062,10 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         r"""
         Args:
-            input_ids
-                torch.Tensor of *decoder* input token ids.
-            positions
-                torch.Tensor of *decoder* position indices.
-            encoder_input_ids
-                torch.Tensor of *encoder* input token ids.
-            encoder_positions
-                torch.Tensor of *encoder* position indices
+            input_ids: torch.Tensor of *decoder* input token ids.
+            positions: torch.Tensor of *decoder* position indices.
+            encoder_input_ids: torch.Tensor of *encoder* input token ids.
+            encoder_positions: torch.Tensor of *encoder* position indices
        Returns:
            Output torch.Tensor
        """


@@ -1599,17 +1599,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                 **NOTE**: If mrope is enabled (default setting for GLM-4V
                 opensource models), the shape will be `(3, seq_len)`,
                 otherwise it will be `(seq_len,).
-            pixel_values: Pixel values to be fed to a model.
-                `None` if no images are passed.
-            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
-                `None` if no images are passed.
-            pixel_values_videos: Pixel values of videos to be fed to a model.
-                `None` if no videos are passed.
-            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
-                `None` if no videos are passed.
-            second_per_grid_ts: Tensor `(num_videos)` of video time interval (
-                in seconds) for each grid along the temporal dimension in the
-                3D position IDs. `None` if no videos are passed.
+            intermediate_tensors: Optional intermediate tensors for pipeline
+                parallelism.
+            inputs_embeds: Optional pre-computed input embeddings.
+            **kwargs: Additional keyword arguments.
         """
         if intermediate_tensors is not None:
             inputs_embeds = None
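The positions-shape note retained in this docstring is worth seeing as tensors; the values below are illustrative only:

```python
# With mrope enabled, positions carry one row each of temporal (T), height
# (H), and width (W) indices; otherwise they are a flat range.
import torch

seq_len = 6
positions_mrope = torch.zeros(3, seq_len, dtype=torch.long)  # shape (3, seq_len)
positions_plain = torch.arange(seq_len)                      # shape (seq_len,)
```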