[Docs] Fix warnings in mkdocs build (continued) (#23743)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Signed-off-by: Hyogeun Oh (오효근) <ohg3417@gmail.com>
Authored by Hyogeun Oh (오효근) on 2025-08-28 02:17:29 +09:00, committed by GitHub
parent dd58932280
commit 4e4d017b6f
26 changed files with 56 additions and 50 deletions


@@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator):
    Args:
        absolute_id (int): The absolute block id for the block
            in whole allocator.
    Returns:
        int: The zero-offset block id on certain device.


@@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
    Args:
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
-       block_ids(Optional[Iterable[int]], optional): An optional iterable of
+       block_ids (Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """


@@ -657,7 +657,7 @@ class Scheduler:
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.
        partial_prefill_metadata: information about the partial prefills
            that are currently running
    Returns:
        SchedulerRunningOutputs.


@@ -491,7 +491,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
    query: shape = [num_tokens, num_heads * head_size]
    key: shape = [num_tokens, num_kv_heads * head_size]
    value: shape = [num_tokens, num_kv_heads * head_size]
-   kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size * num_kv_heads * head_size]
        NOTE: kv_cache will be an empty tensor with shape [0]
        for profiling run.
    attn_metadata: Metadata for attention.
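As a quick illustration of the documented layout, a sketch that allocates a tensor with this shape and carves one block back into per-token KV heads (toy sizes, not vLLM's allocation code):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64

# Shape documented above: [2, num_blocks, block_size * num_kv_heads * head_size].
kv_cache = torch.zeros(2, num_blocks, block_size * num_kv_heads * head_size)

# Split into key and value caches along the leading axis.
key_cache, value_cache = kv_cache.unbind(0)

# Per-token view of one block: [block_size, num_kv_heads, head_size].
block0_keys = key_cache[0].view(block_size, num_kv_heads, head_size)
print(block0_keys.shape)  # torch.Size([16, 4, 64])
```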


@@ -438,7 +438,8 @@ class FlashAttentionImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]


@@ -637,11 +637,9 @@ class FlashInferImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache: shape -
-   # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
-   # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
+   kv_cache: KV cache tensor with different possible shapes:
+       - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
+       - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]
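The two layouts differ only in the order of the token and head axes, so one can be viewed as a permutation of the other; a small sketch with toy sizes (illustrative only):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# NHD layout: [num_blocks, 2, block_size, num_kv_heads, head_size]
kv_cache_nhd = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# HND layout swaps the token and head axes:
# [num_blocks, 2, num_kv_heads, block_size, head_size]
kv_cache_hnd = kv_cache_nhd.permute(0, 1, 3, 2, 4)
print(kv_cache_hnd.shape)  # torch.Size([4, 2, 2, 16, 64])
```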


@@ -689,7 +689,8 @@ class FlexAttentionImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]


@@ -235,7 +235,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads * head_size]
    key: shape = [num_tokens, num_kv_heads * head_size]
    value: shape = [num_tokens, num_kv_heads * head_size]
-   kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
+   kv_cache: shape =
+       [num_blocks, block_size, num_kv_heads * 2, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]

@@ -329,7 +330,7 @@ def write_to_kv_cache(
Args:
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
+   kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
    num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
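A toy sketch of the documented shape and of the same unpacking shown above, where the third axis holds the K and V heads combined (sizes are made up for illustration):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# Shape documented above: [num_blocks, block_size, num_kv_heads * 2, head_size].
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)

# Same unpacking pattern as in the snippet above.
_, page_size, num_combined_kv_heads, hs = kv_cache.shape
print(page_size, num_combined_kv_heads, hs)  # 16 4 64
```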


@@ -429,7 +429,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]


@@ -362,7 +362,8 @@ class TreeAttentionImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]


@@ -285,7 +285,8 @@ class TritonAttentionImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]


@@ -330,7 +330,8 @@ class XFormersAttentionImpl(AttentionImpl):
    query: shape = [num_tokens, num_heads, head_size]
    key: shape = [num_tokens, num_kv_heads, head_size]
    value: shape = [num_tokens, num_kv_heads, head_size]
-   kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+   kv_cache: shape =
+       [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: Metadata for attention.
Returns:
    shape = [num_tokens, num_heads * head_size]


@@ -255,9 +255,9 @@ def compute_encoder_budget(
Returns:
    - Compute budget for encoder execution, measured in number of tokens
      from the input sequence.
    - Space budget for encoder cache size, measured in number of tokens
      from the input sequence.
"""
if mm_registry.supports_multimodal_inputs(model_config):
    max_tokens_by_modality = mm_registry \

@@ -303,9 +303,9 @@ def compute_mm_encoder_budget(
Returns:
    - Compute budget for encoder execution, measured in number of tokens
      from the input sequence.
    - Space budget for encoder cache size, measured in number of tokens
      from the input sequence.
"""
if not max_tokens_by_modality:


@@ -119,7 +119,8 @@ class KVCacheCoordinator(ABC):
Args:
    request: The request.
-   num_tokens: The total number of tokens that need to be cached
+   num_computed_tokens: The total number of tokens
+       that need to be cached
        (including tokens that are already cached).
"""
for manager in self.single_type_managers:


@@ -54,14 +54,15 @@ class KVCacheBlocks:
def get_block_ids(
    self,
    allow_none: bool = False,
-):
+) -> Optional[tuple[list[int], ...]]:
    """
    Converts the KVCacheBlocks instance to block_ids.
    Returns:
-       tuple[list[int], ...]: A tuple of lists where
-       * the outer tuple corresponds to KV cache groups
-       * each inner list contains the block_ids of the blocks in that group
+       tuple[list[int], ...]: A tuple of lists where:
+       - the outer tuple corresponds to KV cache groups
+       - each inner list contains the block_ids of the blocks in that
+         group
    """
    if allow_none and all(len(group) == 0 for group in self.blocks):
        return None
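A self-contained sketch of the documented return structure, using a stand-in class and an illustrative `block_id` attribute rather than the real `KVCacheBlocks` internals:

```python
from typing import Optional


class Block:
    def __init__(self, block_id: int):
        self.block_id = block_id


class SketchKVCacheBlocks:
    """Illustrative stand-in for KVCacheBlocks."""

    def __init__(self, blocks: tuple[list[Block], ...]):
        self.blocks = blocks

    def get_block_ids(
        self,
        allow_none: bool = False,
    ) -> Optional[tuple[list[int], ...]]:
        # Outer tuple: one entry per KV cache group.
        # Inner list: the block IDs held by that group.
        if allow_none and all(len(group) == 0 for group in self.blocks):
            return None
        return tuple([blk.block_id for blk in group] for group in self.blocks)


blocks = SketchKVCacheBlocks((
    [Block(0), Block(3)],   # group 0
    [Block(7)],             # group 1
))
print(blocks.get_block_ids())  # ([0, 3], [7])
```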


@@ -8,6 +8,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import (  # noqa
    RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput

@@ -64,7 +65,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
def execute_model(
    self,
-   scheduler_output,
+   scheduler_output: SchedulerOutput,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
    """Execute the model on the Ray workers.


@@ -36,7 +36,7 @@ def setup_multiprocess_prometheus():
    "and vLLM will properly handle cleanup.")
-def get_prometheus_registry():
+def get_prometheus_registry() -> CollectorRegistry:
    """Get the appropriate prometheus registry based on multiprocessing
    configuration.
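For context, a hedged sketch of how such a helper is typically written on top of `prometheus_client`'s documented multiprocess API; the environment-variable check is an assumption here, not necessarily vLLM's exact logic:

```python
import os

from prometheus_client import REGISTRY, CollectorRegistry


def sketch_get_prometheus_registry() -> CollectorRegistry:
    """Illustrative only: pick a registry based on multiprocess mode."""
    if os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
        # In multiprocess mode, metrics from all worker processes are
        # aggregated into a fresh registry via MultiProcessCollector.
        from prometheus_client import multiprocess
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        return registry
    # Single-process mode: the library's default global registry.
    return REGISTRY
```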


@@ -91,7 +91,7 @@ class LogitsProcessor(ABC):
    to each forward pass.
Args:
-   batch_update is non-None iff there have been
-   changes to the batch makeup.
+   batch_update: Non-None iff there have been changes
+       to the batch makeup.
"""
raise NotImplementedError
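A hypothetical sketch of the contract described above: the update method only does work when `batch_update` is non-None, i.e. when the batch makeup changed. The dict structure is invented for illustration:

```python
from typing import Optional


class SketchLogitsProcessor:
    """Illustrative only: apply per-step batch updates lazily."""

    def __init__(self):
        self.per_request_state: dict[int, float] = {}

    def update_state(self, batch_update: Optional[dict]) -> None:
        # batch_update is non-None only when the batch makeup changed.
        if batch_update is None:
            return
        for req_id in batch_update.get("removed", []):
            self.per_request_state.pop(req_id, None)
        for req_id, temperature in batch_update.get("added", []):
            self.per_request_state[req_id] = temperature
```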


@@ -68,7 +68,7 @@ class RejectionSampler(nn.Module):
        different requests are flattened into a single tensor because
        this is the shape of the output logits.
        NOTE: `target_logits` can be updated in place to save memory.
-   bonus_token_ids_tensor (torch.Tensor):
+   bonus_token_ids (torch.Tensor):
        A tensor containing bonus tokens. Shape is [batch_size, 1].
        Bonus tokens are added to the end of the sequence if all
        proposed tokens are accepted. We generate the bonus tokens
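A toy sketch of the bonus-token rule described above: the [batch_size, 1] bonus token is appended only when every proposed token for that request was accepted (pure-Python illustration, not the sampler's kernel; recovery of rejected tokens is omitted):

```python
import torch

# Per request: proposed draft tokens and whether each was accepted.
accepted = torch.tensor([[True, True, True],
                         [True, False, True]])
proposed = torch.tensor([[11, 12, 13],
                         [21, 22, 23]])
bonus_token_ids = torch.tensor([[99], [88]])  # shape [batch_size, 1]

for i in range(accepted.shape[0]):
    out = []
    for tok, ok in zip(proposed[i].tolist(), accepted[i].tolist()):
        if not ok:
            break  # stop at the first rejected draft token
        out.append(tok)
    # The bonus token is appended only if all proposed tokens were accepted.
    if len(out) == proposed.shape[1]:
        out.append(bonus_token_ids[i, 0].item())
    print(out)
# [11, 12, 13, 99]
# [21]
```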


@@ -89,7 +89,7 @@ class Sampler(nn.Module):
    Gather logprobs for topk and sampled/prompt token.
    Args:
-       logits: (num tokens) x (vocab) tensor
+       logprobs: (num tokens) x (vocab) tensor
        num_logprobs: minimum number of logprobs to
            retain per token
        token_ids: prompt tokens (if prompt logprobs)
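A sketch of the documented inputs using plain torch ops with toy sizes (illustrative only): top-k logprobs per position plus the logprob of each sampled/prompt token:

```python
import torch

num_tokens, vocab = 3, 10
logprobs = torch.log_softmax(torch.randn(num_tokens, vocab), dim=-1)
token_ids = torch.tensor([2, 5, 7])  # sampled (or prompt) token per row
num_logprobs = 4

# Top-k logprobs and their token ids per position.
topk_logprobs, topk_ids = logprobs.topk(num_logprobs, dim=-1)

# Logprob of the sampled/prompt token itself, gathered per row.
chosen_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
print(topk_ids.shape, chosen_logprobs.shape)
# torch.Size([3, 4]) torch.Size([3])
```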


@@ -110,7 +110,7 @@ class StructuredOutputBackend(ABC):
Args:
    request_type (StructuredOutputOptions): The type of structured
        output request.
    grammar_spec (str): The grammar specification to compile.
Returns:

@@ -124,7 +124,7 @@ class StructuredOutputBackend(ABC):
Args:
    max_num_seqs (int): The maximum number of sequences for which
        to allocate the bitmask.
"""
@abstractmethod


@@ -525,9 +525,6 @@ class InputBatch:
    Any consecutive empty indices at the very end of the list are not
    filled.
-   Args:
-       empty_req_indices: empty indices which may be filled.
    Returns:
        swaps: list of (from,to) swap tuples for moved requests
        empty_req_indices: indices not filled by condensation
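An illustrative condensation pass, not the actual `InputBatch` code: live requests are moved from the tail into earlier empty slots, producing the (from, to) swap tuples and the leftover empty indices described above:

```python
def sketch_condense(
    live_indices: list[int],
    empty_req_indices: list[int],
) -> tuple[list[tuple[int, int]], list[int]]:
    """Illustrative only: fill earlier empty slots from the tail."""
    swaps: list[tuple[int, int]] = []
    live = sorted(live_indices, reverse=True)   # move from the back first
    empties = sorted(empty_req_indices)         # fill from the front first
    for dst in empties:
        if not live or live[0] <= dst:
            break  # remaining empties are all at (or past) the tail
        src = live.pop(0)
        swaps.append((src, dst))
    filled = {dst for _, dst in swaps}
    unfilled = [i for i in empties if i not in filled]
    return swaps, unfilled


print(sketch_condense(live_indices=[0, 2, 4, 5], empty_req_indices=[1, 3]))
# ([(5, 1), (4, 3)], [])
```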


@@ -2955,7 +2955,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args:
    kv_cache_config: The KV cache config
    kv_cache_raw_tensors: The KV cache buffer of each layer, with
        correct size but uninitialized shape.
Returns:
    Dict[str, torch.Tensor]: A map between layer names to their
        corresponding memory buffer for KV cache.


@@ -552,7 +552,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    return kv_cache_spec

def _get_slot_mapping_metadata(self, num_reqs,
-                              num_scheduled_tokens_per_req):
+                              num_scheduled_tokens_per_req) -> np.ndarray:
    """
    Computes metadata for mapping slots to blocks in the key-value (KV)
    cache for a batch of requests.

@@ -565,15 +565,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args:
    num_reqs (int): Number of requests in the current batch.
    num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
        to be scheduled for each request.
Returns:
    np.ndarray: A 2D array of shape (total_block_len, 3), where each row
    contains:
        - kv_cache_start_index (int): The starting index in the KV cache
          for the corresponding slice.
        - new_kv_start_index (int): The starting index in the new KV
          cache for the corresponding slice.
        - slice_len (int): The length of the slice.
"""
slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
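A toy sketch of building such a (total_block_len, 3) array for two requests; the arithmetic below only splits each request's newly scheduled tokens at block boundaries, while the real metadata additionally maps indices through each request's block table (omitted here):

```python
import numpy as np

block_size = 16
# Per request: tokens already in the cache and tokens scheduled this step.
num_computed_tokens = np.array([5, 0])
num_scheduled_tokens_per_req = np.array([20, 7])

rows = []
new_kv_start = 0
for computed, scheduled in zip(num_computed_tokens,
                               num_scheduled_tokens_per_req):
    start, end = int(computed), int(computed + scheduled)
    while start < end:
        # Each slice stays inside a single block of the KV cache.
        slice_len = min(block_size - start % block_size, end - start)
        rows.append((start, new_kv_start, slice_len))
        start += slice_len
        new_kv_start += slice_len

slot_mapping_metadata = np.array(rows)  # shape: (total_block_len, 3)
print(slot_mapping_metadata)
# [[ 5  0 11]
#  [16 11  9]
#  [ 0 20  7]]
```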


@@ -172,10 +172,10 @@ def scatter_mm_placeholders(
Args:
    embeds: The multimodal embeddings.
        Shape: `(num_embeds, embed_dim)`
    is_embed: A boolean mask indicating which positions in the placeholder
        tokens need to be filled with multimodal embeddings.
        Shape: `(num_placeholders, num_embeds)`
"""
if is_embed is None:
    return embeds

@@ -278,7 +278,7 @@ def bind_kv_cache(
Args:
    kv_caches: The allocated kv_caches with layer names as keys.
    forward_context: The global forward context containing all Attention
        layers with layer names as keys.
    runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner


@@ -37,7 +37,7 @@ class WorkerBase(WorkerBaseV0):
    rank: Global rank in distributed setup
    distributed_init_method: Distributed initialization method
    is_driver_worker: Whether this worker handles driver
        responsibilities
"""
# Configuration storage
super().__init__(vllm_config=vllm_config)