diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py
index 638ec6fb0b082..e523090aa2172 100644
--- a/vllm/v1/worker/gpu/async_utils.py
+++ b/vllm/v1/worker/gpu/async_utils.py
@@ -7,6 +7,7 @@ import torch
 from vllm.v1.outputs import (
     AsyncModelRunnerOutput,
+    LogprobsTensors,
     ModelRunnerOutput,
     SamplerOutput,
 )
 
@@ -46,15 +47,18 @@ class AsyncOutput(AsyncModelRunnerOutput):
                 "cpu", non_blocking=True
             )
             if sampler_output.logprobs_tensors is not None:
-                self.logprobs_tensors = (
+                self.logprobs_tensors: LogprobsTensors | None = (
                     sampler_output.logprobs_tensors.to_cpu_nonblocking()
                 )
             else:
                 self.logprobs_tensors = None
-            self.prompt_logprobs_dict = {}
+            self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
             if self.model_runner_output.prompt_logprobs_dict:
                 for k, v in self.model_runner_output.prompt_logprobs_dict.items():
-                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                    if v is not None:
+                        self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                    else:
+                        self.prompt_logprobs_dict[k] = None
             self.copy_event.record(self.copy_stream)
 
     def get_output(self) -> ModelRunnerOutput:
@@ -64,12 +68,10 @@ class AsyncOutput(AsyncModelRunnerOutput):
         # the existing model runner.
         # Going forward, we should keep the data structures as NumPy arrays
         # rather than Python lists.
-        sampled_token_ids_np = self.sampled_token_ids.numpy()
-        num_reqs = sampled_token_ids_np.shape[0]
-        sampled_token_ids: list[np.ndarray] = [
-            sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
-            for i in range(num_reqs)
-        ]
+        sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
+        num_reqs = len(sampled_token_ids)
+        for i in range(num_reqs):
+            del sampled_token_ids[i][self.num_sampled_tokens[i] :]
         self.model_runner_output.sampled_token_ids = sampled_token_ids
 
         if self.logprobs_tensors is not None:
diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 8850c18092299..222db565dff17 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import Any
+from typing import Any, cast
 
 import torch
 
@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
 )
 from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
 )
@@ -22,7 +23,8 @@ from vllm.v1.worker.utils import bind_kv_cache
 
 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
     kv_cache_spec: dict[str, KVCacheSpec] = {}
-    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
+    layer_type = cast(type[Any], AttentionLayerBase)
+    attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
     for layer_name, attn_module in attn_layers.items():
         # Skip modules that don't need KV cache (eg encoder-only attention)
         if spec := attn_module.get_kv_cache_spec(vllm_config):
@@ -35,16 +37,15 @@ def init_attn_backend(
     vllm_config: VllmConfig,
     device: torch.device,
 ):
-    attn_backends: dict[str, AttentionBackend] = {}
+    attn_backends: dict[str, type[AttentionBackend]] = {}
    attn_metadata_builders: list[AttentionMetadataBuilder] = []
     flashinfer_workspace: torch.Tensor | None = None
 
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         layer_names = kv_cache_group_spec.layer_names
         any_layer_name = next(iter(layer_names))
-        attn_layers = get_layers_from_vllm_config(
-            vllm_config, AttentionLayerBase, layer_names
-        )
+        layer_type = cast(type[Any], AttentionLayerBase)
+        attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
         attn_backend = attn_layers[any_layer_name].get_attn_backend()
         for layer_name in layer_names:
             attn_backends[layer_name] = attn_backend
@@ -93,6 +94,7 @@ def _reshape_kv_cache(
     kv_caches: dict[str, torch.Tensor] = {}
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+        assert isinstance(kv_cache_spec, AttentionSpec)
         for layer_name in kv_cache_group_spec.layer_names:
             raw_tensor = kv_cache_raw_tensors[layer_name]
             assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index 7fd1f76669f48..31a706475243c 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -34,8 +34,16 @@ class CudaGraphManager:
 
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+        if self.compilation_config.cudagraph_mode is None:
+            self.cudagraph_mode = CUDAGraphMode.NONE
+        else:
+            self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        if self.compilation_config.cudagraph_capture_sizes is not None:
+            self.cudagraph_sizes = sorted(
+                self.compilation_config.cudagraph_capture_sizes
+            )
+        else:
+            self.cudagraph_sizes = []
         self.padded_sizes = self._init_padded_sizes()
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 08aad9ddd06b3..9ca37ff282d82 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -329,8 +329,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         torch.cuda.synchronize()
 
     def update_states(self, scheduler_output: SchedulerOutput) -> None:
-        for req_id in scheduler_output.preempted_req_ids:
-            self.req_states.remove_request(req_id)
+        if scheduler_output.preempted_req_ids is not None:
+            for req_id in scheduler_output.preempted_req_ids:
+                self.req_states.remove_request(req_id)
 
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)
@@ -346,6 +347,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Add new requests.
         for new_req_data in scheduler_output.scheduled_new_reqs:
+            assert new_req_data.prompt_token_ids is not None
+            assert new_req_data.prefill_token_ids is not None
+            assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
             self.req_states.add_request(
                 req_id=req_id,
@@ -398,8 +402,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Decode first, then prefill.
         # batch_idx -> req_id
         req_ids = sorted(
-            scheduler_output.num_scheduled_tokens,
-            key=scheduler_output.num_scheduled_tokens.get,
+            scheduler_output.num_scheduled_tokens.keys(),
+            key=lambda k: scheduler_output.num_scheduled_tokens[k],
         )
         num_scheduled_tokens = np.array(
             [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
         )
@@ -637,9 +641,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         model_runner_output = ModelRunnerOutput(
             req_ids=input_batch.req_ids,
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
-            sampled_token_ids=None,
+            sampled_token_ids=None,  # type: ignore
             logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
             pooler_output=[],
             kv_connector_output=None,
             num_nans_in_logits=None,
diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index e916aadb6b5a0..55f98ca6bb6a3 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -8,8 +8,8 @@ import triton.language as tl
 
 from vllm.config.model import LogprobsMode
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.states import SamplingMetadata
 
 
 class Sampler:
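
Note (not part of the diff): the async_utils.py hunk above replaces per-request NumPy slicing in AsyncOutput.get_output() with a single tolist() conversion followed by in-place truncation of each row's padding. The standalone sketch below illustrates that pattern only; the helper name to_ragged_lists and the explicit .cpu() call are illustrative assumptions, not vLLM API.

import numpy as np
import torch


def to_ragged_lists(
    sampled_token_ids: torch.Tensor,  # padded [num_reqs, max_sampled] int tensor
    num_sampled_tokens: np.ndarray,   # [num_reqs] count of valid tokens per request
) -> list[list[int]]:
    # One bulk conversion to Python lists, then cheap in-place deletion of the
    # padded tail of each row, mirroring the new code path in get_output().
    token_lists: list[list[int]] = sampled_token_ids.cpu().tolist()
    for i in range(len(token_lists)):
        del token_lists[i][int(num_sampled_tokens[i]) :]
    return token_lists


# Example: rows padded to width 3, with 1, 2, and 3 valid tokens respectively.
ids = torch.tensor([[7, 0, 0], [4, 5, 0], [1, 2, 3]])
counts = np.array([1, 2, 3], dtype=np.int32)
assert to_ragged_lists(ids, counts) == [[7], [4, 5], [1, 2, 3]]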