[Docs] Fix warnings in mkdocs build (continued) (#23743)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
Signed-off-by: Hyogeun Oh (오효근) <ohg3417@gmail.com>
This commit is contained in:
Hyogeun Oh (오효근) 2025-08-28 02:17:29 +09:00 committed by GitHub
parent dd58932280
commit 4e4d017b6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 56 additions and 50 deletions

View File

@ -491,7 +491,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache: shape =
[2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.

View File

@ -438,7 +438,8 @@ class FlashAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -637,11 +637,9 @@ class FlashInferImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape -
# NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
# HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
kv_cache: KV cache tensor with different possible shapes:
- NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
- HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -689,7 +689,8 @@ class FlexAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -235,7 +235,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
kv_cache: shape =
[num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@ -329,7 +330,7 @@ def write_to_kv_cache(
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape

View File

@ -429,7 +429,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -362,7 +362,8 @@ class TreeAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -285,7 +285,8 @@ class TritonAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -330,7 +330,8 @@ class XFormersAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]

View File

@ -119,7 +119,8 @@ class KVCacheCoordinator(ABC):
Args:
request: The request.
num_tokens: The total number of tokens that need to be cached
num_computed_tokens: The total number of tokens
that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:

View File

@ -54,14 +54,15 @@ class KVCacheBlocks:
def get_block_ids(
self,
allow_none: bool = False,
):
) -> Optional[tuple[list[int], ...]]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
tuple[list[int], ...]: A tuple of lists where
* the outer tuple corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
tuple[list[int], ...]: A tuple of lists where:
- the outer tuple corresponds to KV cache groups
- each inner list contains the block_ids of the blocks in that
group
"""
if allow_none and all(len(group) == 0 for group in self.blocks):
return None

View File

@ -8,6 +8,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.executor.ray_distributed_executor import ( # noqa
RayDistributedExecutor as RayDistributedExecutorV0)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
@ -64,7 +65,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
def execute_model(
self,
scheduler_output,
scheduler_output: SchedulerOutput,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
"""Execute the model on the Ray workers.

View File

@ -36,7 +36,7 @@ def setup_multiprocess_prometheus():
"and vLLM will properly handle cleanup.")
def get_prometheus_registry():
def get_prometheus_registry() -> CollectorRegistry:
"""Get the appropriate prometheus registry based on multiprocessing
configuration.

View File

@ -91,7 +91,7 @@ class LogitsProcessor(ABC):
to each forward pass.
Args:
batch_update is non-None iff there have been
changes to the batch makeup.
batch_update: Non-None iff there have been changes
to the batch makeup.
"""
raise NotImplementedError

View File

@ -68,7 +68,7 @@ class RejectionSampler(nn.Module):
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
bonus_token_ids (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens

View File

@ -89,7 +89,7 @@ class Sampler(nn.Module):
Gather logprobs for topk and sampled/prompt token.
Args:
logits: (num tokens) x (vocab) tensor
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)

View File

@ -525,9 +525,6 @@ class InputBatch:
Any consecutive empty indices at the very end of the list are not
filled.
Args:
empty_req_indices: empty indices which may be filled.
Returns:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation

View File

@ -552,7 +552,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return kv_cache_spec
def _get_slot_mapping_metadata(self, num_reqs,
num_scheduled_tokens_per_req):
num_scheduled_tokens_per_req) -> np.ndarray:
"""
Computes metadata for mapping slots to blocks in the key-value (KV)
cache for a batch of requests.