[Chore] Fix pre-commit error after #25266 (#29190)

Author: Woosuk Kwon, 2025-11-21 10:21:40 -08:00, committed by GitHub
parent ceca060501
commit 1bed891f72
5 changed files with 40 additions and 24 deletions


@@ -7,6 +7,7 @@ import torch
 from vllm.v1.outputs import (
     AsyncModelRunnerOutput,
+    LogprobsTensors,
     ModelRunnerOutput,
     SamplerOutput,
 )
@@ -46,15 +47,18 @@ class AsyncOutput(AsyncModelRunnerOutput):
             "cpu", non_blocking=True
         )
         if sampler_output.logprobs_tensors is not None:
-            self.logprobs_tensors = (
+            self.logprobs_tensors: LogprobsTensors | None = (
                 sampler_output.logprobs_tensors.to_cpu_nonblocking()
             )
         else:
             self.logprobs_tensors = None
-        self.prompt_logprobs_dict = {}
+        self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
         if self.model_runner_output.prompt_logprobs_dict:
             for k, v in self.model_runner_output.prompt_logprobs_dict.items():
-                self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                if v is not None:
+                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                else:
+                    self.prompt_logprobs_dict[k] = None
         self.copy_event.record(self.copy_stream)

     def get_output(self) -> ModelRunnerOutput:
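
The explicit `LogprobsTensors | None` and `dict[str, LogprobsTensors | None]` annotations above are there so the type checker run by pre-commit sees the union type, not just the type of the first assignment it encounters. A minimal standalone sketch of that idiom (the `Payload` class and its field are made up for illustration):

# Sketch: annotate a branch-dependent attribute once, at its first
# assignment, so a type checker infers `list[int] | None` for it.
class Payload:
    def __init__(self, data: list[int] | None) -> None:
        if data is not None:
            self.data: list[int] | None = list(data)
        else:
            self.data = None
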
@@ -64,12 +68,10 @@ class AsyncOutput(AsyncModelRunnerOutput):
         # the existing model runner.
         # Going forward, we should keep the data structures as NumPy arrays
         # rather than Python lists.
-        sampled_token_ids_np = self.sampled_token_ids.numpy()
-        num_reqs = sampled_token_ids_np.shape[0]
-        sampled_token_ids: list[np.ndarray] = [
-            sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
-            for i in range(num_reqs)
-        ]
+        sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
+        num_reqs = len(sampled_token_ids)
+        for i in range(num_reqs):
+            del sampled_token_ids[i][self.num_sampled_tokens[i] :]
         self.model_runner_output.sampled_token_ids = sampled_token_ids
         if self.logprobs_tensors is not None:
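
For context on the rewritten block above: `Tensor.tolist()` yields nested Python lists, and `del row[n:]` then truncates each row in place to that request's number of sampled tokens, matching what the earlier NumPy slicing produced. A small self-contained sketch of the same pattern, with invented token values and counts:

import torch

# Each row holds up to three sampled ids; num_sampled_tokens says how many
# of them are valid for that request.
sampled = torch.tensor([[11, 12, 13], [21, 22, 23]])
num_sampled_tokens = [2, 1]

token_ids: list[list[int]] = sampled.tolist()
for i in range(len(token_ids)):
    # Truncate the row in place to its valid prefix.
    del token_ids[i][num_sampled_tokens[i]:]

print(token_ids)  # [[11, 12], [21]]
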


@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Sequence
-from typing import Any
+from typing import Any, cast

 import torch
@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
 )
 from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
 )
@@ -22,7 +23,8 @@ from vllm.v1.worker.utils import bind_kv_cache
 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
     kv_cache_spec: dict[str, KVCacheSpec] = {}
-    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
+    layer_type = cast(type[Any], AttentionLayerBase)
+    attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
     for layer_name, attn_module in attn_layers.items():
         # Skip modules that don't need KV cache (eg encoder-only attention)
         if spec := attn_module.get_kv_cache_spec(vllm_config):
@@ -35,16 +37,15 @@ def init_attn_backend(
     vllm_config: VllmConfig,
     device: torch.device,
 ):
-    attn_backends: dict[str, AttentionBackend] = {}
+    attn_backends: dict[str, type[AttentionBackend]] = {}
     attn_metadata_builders: list[AttentionMetadataBuilder] = []
     flashinfer_workspace: torch.Tensor | None = None

     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         layer_names = kv_cache_group_spec.layer_names
         any_layer_name = next(iter(layer_names))
-        attn_layers = get_layers_from_vllm_config(
-            vllm_config, AttentionLayerBase, layer_names
-        )
+        layer_type = cast(type[Any], AttentionLayerBase)
+        attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
         attn_backend = attn_layers[any_layer_name].get_attn_backend()
         for layer_name in layer_names:
             attn_backends[layer_name] = attn_backend
@@ -93,6 +94,7 @@ def _reshape_kv_cache(
     kv_caches: dict[str, torch.Tensor] = {}
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+        assert isinstance(kv_cache_spec, AttentionSpec)
         for layer_name in kv_cache_group_spec.layer_names:
             raw_tensor = kv_cache_raw_tensors[layer_name]
             assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
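
The `cast(type[Any], AttentionLayerBase)` lines above (and the new `assert isinstance(..., AttentionSpec)` guard) exist only to satisfy the type checker; `cast` is a no-op at runtime. A generic sketch of the idea, where `LayerBase` and `collect_instances` are made-up stand-ins rather than the real vLLM types and helper:

from abc import ABC, abstractmethod
from typing import Any, cast

class LayerBase(ABC):  # stand-in for an abstract layer interface
    @abstractmethod
    def name(self) -> str: ...

class MyLayer(LayerBase):
    def name(self) -> str:
        return "my_layer"

def collect_instances(objs: list[Any], layer_type: type[Any]) -> list[Any]:
    # Hypothetical helper: keep only objects of the requested type.
    return [o for o in objs if isinstance(o, layer_type)]

# cast() changes nothing at runtime; it only tells the checker to treat
# the abstract class as an acceptable `type[Any]` argument.
layer_type = cast(type[Any], LayerBase)
print(collect_instances([MyLayer(), object()], layer_type))  # [<MyLayer ...>]
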


@@ -34,8 +34,16 @@ class CudaGraphManager:
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+        if self.compilation_config.cudagraph_mode is None:
+            self.cudagraph_mode = CUDAGraphMode.NONE
+        else:
+            self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        if self.compilation_config.cudagraph_capture_sizes is not None:
+            self.cudagraph_sizes = sorted(
+                self.compilation_config.cudagraph_capture_sizes
+            )
+        else:
+            self.cudagraph_sizes = []
         self.padded_sizes = self._init_padded_sizes()

         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
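
The expanded constructor above simply gives both compilation-config fields explicit defaults when they are `None`, so the resulting attributes are never Optional. A condensed sketch of the same guard pattern, using an illustrative `Config` dataclass and `Mode` enum rather than the real vLLM types:

from dataclasses import dataclass
from enum import Enum

class Mode(Enum):
    NONE = 0
    FULL = 1

@dataclass
class Config:
    mode: Mode | None = None
    capture_sizes: list[int] | None = None

def resolve(cfg: Config) -> tuple[Mode, list[int]]:
    # Fall back to explicit defaults so callers never see None.
    mode = Mode.NONE if cfg.mode is None else cfg.mode
    sizes = sorted(cfg.capture_sizes) if cfg.capture_sizes is not None else []
    return mode, sizes

print(resolve(Config()))                      # Mode.NONE and []
print(resolve(Config(Mode.FULL, [4, 1, 2])))  # Mode.FULL and [1, 2, 4]
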


@@ -329,8 +329,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         torch.cuda.synchronize()

     def update_states(self, scheduler_output: SchedulerOutput) -> None:
-        for req_id in scheduler_output.preempted_req_ids:
-            self.req_states.remove_request(req_id)
+        if scheduler_output.preempted_req_ids is not None:
+            for req_id in scheduler_output.preempted_req_ids:
+                self.req_states.remove_request(req_id)
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)
@@ -346,6 +347,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Add new requests.
         for new_req_data in scheduler_output.scheduled_new_reqs:
+            assert new_req_data.prompt_token_ids is not None
+            assert new_req_data.prefill_token_ids is not None
+            assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
             self.req_states.add_request(
                 req_id=req_id,
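
The two changes above use the standard ways to narrow Optional values for mypy: guard the whole loop with an `is not None` check, or assert the invariant before use. A tiny sketch of both idioms on a made-up `Request` type:

from dataclasses import dataclass

@dataclass
class Request:  # illustrative stand-in for the scheduler's request data
    req_id: str
    prompt_token_ids: list[int] | None = None

def remove_preempted(preempted: set[str] | None, states: dict[str, Request]) -> None:
    # Idiom 1: guard the loop when the collection itself may be None.
    if preempted is not None:
        for req_id in preempted:
            states.pop(req_id, None)

def first_token(req: Request) -> int:
    # Idiom 2: assert the invariant; the field is non-None below this line.
    assert req.prompt_token_ids is not None
    return req.prompt_token_ids[0]
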
@@ -398,8 +402,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Decode first, then prefill.
         # batch_idx -> req_id
         req_ids = sorted(
-            scheduler_output.num_scheduled_tokens,
-            key=scheduler_output.num_scheduled_tokens.get,
+            scheduler_output.num_scheduled_tokens.keys(),
+            key=lambda k: scheduler_output.num_scheduled_tokens[k],
         )
         num_scheduled_tokens = np.array(
             [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
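
Replacing `dict.get` with an explicit lambda keeps the sort behavior identical while avoiding the Optional return type of `get`, which strict type checking rejects as a sort key. A quick standalone example of sorting request ids by their scheduled token counts (the values are invented):

# Sort request ids by how many tokens each has scheduled, ascending.
num_scheduled_tokens = {"req-a": 8, "req-b": 1, "req-c": 4}

req_ids = sorted(
    num_scheduled_tokens.keys(),
    key=lambda k: num_scheduled_tokens[k],
)
print(req_ids)  # ['req-b', 'req-c', 'req-a']
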
@@ -637,9 +641,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         model_runner_output = ModelRunnerOutput(
             req_ids=input_batch.req_ids,
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
-            sampled_token_ids=None,
+            sampled_token_ids=None,  # type: ignore
             logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
             pooler_output=[],
             kv_connector_output=None,
             num_nans_in_logits=None,


@@ -8,8 +8,8 @@ import triton.language as tl
 from vllm.config.model import LogprobsMode
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.states import SamplingMetadata


 class Sampler: