From 56669c1f293d5c53b6a19ddf2f78802fa9fff2c2 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:36:07 -0500 Subject: [PATCH] [CI] Fix mypy for `vllm/v1/worker` (#29037) Signed-off-by: yewentao256 --- tools/pre_commit/mypy.py | 2 +- vllm/model_executor/utils.py | 2 +- vllm/multimodal/utils.py | 4 +- vllm/v1/worker/cpu_worker.py | 12 +- vllm/v1/worker/gpu_model_runner.py | 128 +++++++++++------- vllm/v1/worker/gpu_ubatch_wrapper.py | 16 ++- vllm/v1/worker/gpu_worker.py | 62 +++++---- .../worker/kv_connector_model_runner_mixin.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 28 +++- vllm/v1/worker/tpu_worker.py | 5 +- vllm/v1/worker/utils.py | 8 +- vllm/v1/worker/worker_base.py | 2 + vllm/v1/worker/xpu_worker.py | 9 +- 13 files changed, 178 insertions(+), 102 deletions(-) diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 8d04848f8f780..34f6e8c928ffb 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -38,6 +38,7 @@ FILES = [ "vllm/usage", "vllm/v1/core", "vllm/v1/engine", + "vllm/v1/worker", ] # After fixing errors resulting from changing follow_imports @@ -62,7 +63,6 @@ SEPARATE_GROUPS = [ "vllm/v1/sample", "vllm/v1/spec_decode", "vllm/v1/structured_output", - "vllm/v1/worker", ] # TODO(woosuk): Include the code from Megatron and HuggingFace. diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 759b809433b14..8aad59e84ff25 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -10,7 +10,7 @@ import torch from vllm.utils.torch_utils import is_torch_equal_or_newer -def set_random_seed(seed: int) -> None: +def set_random_seed(seed: int | None) -> None: from vllm.platforms import current_platform current_platform.seed_everything(seed) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3f55c46ca334d..ac89bdacc01d5 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -3,7 +3,7 @@ import asyncio import atexit -from collections.abc import Iterable, Set +from collections.abc import Generator, Set from concurrent.futures import ThreadPoolExecutor from itertools import groupby from pathlib import Path @@ -403,7 +403,7 @@ def group_mm_kwargs_by_modality( pin_memory: bool = False, merge_by_field_config: bool | None = None, multimodal_cpu_fields: Set[str] = frozenset(), -) -> Iterable[tuple[str, int, BatchedTensorInputs]]: +) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]: """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same modality together into the same `MultiModalKwargs` instance. diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 4420a057d1e58..b080fea1d2dd6 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -3,6 +3,7 @@ import os import platform from collections.abc import Callable +from typing import Any import torch @@ -37,6 +38,9 @@ class CPUWorker(Worker): self.parallel_config.disable_custom_all_reduce = True + # Torch profiler. 
Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + self.profiler: Any | None = None if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" @@ -80,13 +84,13 @@ class CPUWorker(Worker): self.local_omp_cpuid = "nobind" else: local_dp_rank = self.parallel_config.data_parallel_rank_local - omp_cpuids = omp_cpuids.split("|") + omp_cpuids_list = omp_cpuids.split("|") if local_dp_rank is not None: world_size = self.parallel_config.world_size - omp_cpuids = omp_cpuids[ + omp_cpuids_list = omp_cpuids_list[ local_dp_rank * world_size : (local_dp_rank + 1) * world_size ] - self.local_omp_cpuid = omp_cpuids[self.rank] + self.local_omp_cpuid = omp_cpuids_list[self.rank] if self.local_omp_cpuid != "nobind": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) @@ -120,7 +124,7 @@ class CPUWorker(Worker): pass def determine_available_memory(self) -> int: - return self.cache_config.cpu_kvcache_space_bytes # type: ignore + return self.cache_config.cpu_kvcache_space_bytes or 0 def compile_or_warm_up_model(self) -> None: # Reset the seed to ensure that the random state is not affected by diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4b0a08ab57e16..a7fa68b20ac50 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,7 +5,7 @@ import gc import itertools import time from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy from functools import reduce @@ -53,6 +53,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import ( + SupportsMRoPE, SupportsMultiModal, is_mixture_of_experts, supports_eagle3, @@ -126,6 +127,7 @@ from vllm.v1.outputs import ( ) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs +from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -404,7 +406,10 @@ class GPUModelRunner( # solution, we initialize the input batch here, and re-initialize it # in `initialize_kv_cache` if the block_sizes here is different from # the block_sizes in the kv cache config. - custom_logitsprocs = model_config.logits_processors + logits_processors = model_config.logits_processors + custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( + tuple(logits_processors) if logits_processors is not None else () + ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoer @@ -959,9 +964,13 @@ class GPUModelRunner( def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() assert supports_mrope(model), "M-RoPE support is not implemented." + assert req_state.prompt_token_ids is not None, ( + "M-RoPE requires prompt_token_ids to be available." 
+ ) + mrope_model = cast(SupportsMRoPE, model) req_state.mrope_positions, req_state.mrope_position_delta = ( - model.get_mrope_input_positions( + mrope_model.get_mrope_input_positions( req_state.prompt_token_ids, req_state.mm_features, ) @@ -1762,6 +1771,7 @@ class GPUModelRunner( dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + completion_part_len + assert req.mrope_position_delta is not None MRotaryEmbedding.get_next_input_positions_tensor( out=self.mrope_positions.np, out_offset=dst_start, @@ -1907,6 +1917,8 @@ class GPUModelRunner( for mm_input_id in encoder_input_ids: mm_feature = req_state.mm_features[mm_input_id] + if mm_feature.data is None: + continue mm_hash = mm_feature.identifier mm_kwargs.append(mm_feature.data) mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) @@ -1930,7 +1942,7 @@ class GPUModelRunner( # multimodal inputs. The proper solution should be reordering the # encoder outputs. model = cast(SupportsMultiModal, self.model) - encoder_outputs = [] + encoder_outputs: list[torch.Tensor] = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( mm_kwargs, device=self.device, @@ -1938,7 +1950,7 @@ class GPUModelRunner( merge_by_field_config=model.merge_by_field_config, multimodal_cpu_fields=model.multimodal_cpu_fields, ): - curr_group_outputs = [] + curr_group_outputs: list[torch.Tensor] = [] # EVS-related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when @@ -1980,7 +1992,7 @@ class GPUModelRunner( # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) + curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -2180,7 +2192,7 @@ class GPUModelRunner( def sync_and_slice_intermediate_tensors( self, num_tokens: int, - intermediate_tensors: IntermediateTensors, + intermediate_tensors: IntermediateTensors | None, sync_self: bool, ) -> IntermediateTensors: assert self.intermediate_tensors is not None @@ -2397,6 +2409,7 @@ class GPUModelRunner( if is_first_rank: intermediate_tensors = None else: + assert intermediate_tensors is not None intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True ) @@ -2765,14 +2778,14 @@ class GPUModelRunner( uniform_decode = ( max_num_scheduled_tokens == self.uniform_decode_query_len ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - batch_descriptor = BatchDescriptor( + batch_desc = BatchDescriptor( num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, ) cudagraph_runtime_mode, batch_descriptor = ( self.cudagraph_dispatcher.dispatch( - batch_descriptor, + batch_desc, use_cascade_attn=cascade_attn_prefix_lens is not None, ) ) @@ -2856,15 +2869,15 @@ class GPUModelRunner( else: logits = self.model.compute_logits(sample_hidden_states) - model_output_broadcast_data = {} + model_output_broadcast_data: dict[str, Any] = {} if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + broadcasted = get_pp_group().broadcast_tensor_dict( model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 ) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] + assert 
broadcasted is not None + logits = broadcasted["logits"] self.execute_model_state = ExecuteModelState( scheduler_output, @@ -2889,7 +2902,7 @@ class GPUModelRunner( if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. if not kv_connector_output: - return None # noqa + return None # type: ignore[return-value] # In case of PP with kv transfer, we need to pass through the # kv_connector_output @@ -2941,33 +2954,37 @@ class GPUModelRunner( spec_decode_common_attn_metadata, ) + spec_config = self.speculative_config use_padded_batch_for_eagle = ( - self.speculative_config - and self.speculative_config.use_eagle() - and not self.speculative_config.disable_padded_drafter_batch + spec_config is not None + and spec_config.use_eagle() + and not spec_config.disable_padded_drafter_batch ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len if ( - self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len is not None + spec_config is not None + and spec_config.draft_model_config is not None + and spec_config.draft_model_config.max_model_len is not None ): effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len + spec_config.draft_model_config.max_model_len ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= effective_drafter_max_model_len ) if use_padded_batch_for_eagle: + assert self.speculative_config is not None + assert isinstance(self.drafter, EagleProposer) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. 
propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( spec_decode_common_attn_metadata, @@ -3105,7 +3122,9 @@ class GPUModelRunner( common_attn_metadata: CommonAttentionMetadata, ) -> torch.Tensor | list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if self.speculative_config.method == "ngram": + spec_config = self.speculative_config + assert spec_config is not None + if spec_config.method == "ngram": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( @@ -3115,11 +3134,11 @@ class GPUModelRunner( self.input_batch.token_ids_cpu, self.input_batch.spec_decode_unsupported_reqs, ) - elif self.speculative_config.method == "suffix": + elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) - elif self.speculative_config.method == "medusa": + elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -3144,10 +3163,10 @@ class GPUModelRunner( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif self.speculative_config.use_eagle(): + elif spec_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - if self.speculative_config.disable_padded_drafter_batch: + if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. @@ -3197,7 +3216,7 @@ class GPUModelRunner( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - if self.speculative_config.disable_padded_drafter_batch: + if spec_config.disable_padded_drafter_batch: token_indices_to_sample = None common_attn_metadata, token_indices = self.drafter.prepare_inputs( common_attn_metadata, @@ -3292,9 +3311,12 @@ class GPUModelRunner( and is_mixture_of_experts(self.drafter.model) and self.parallel_config.enable_eplb ): + spec_config = self.vllm_config.speculative_config + assert spec_config is not None + assert spec_config.draft_model_config is not None logger.info_once( "EPLB is enabled for drafter model %s.", - self.vllm_config.speculative_config.draft_model_config.model, + spec_config.draft_model_config.model, ) global_expert_load = ( @@ -3311,7 +3333,7 @@ class GPUModelRunner( self.eplb_state = EplbState(self.parallel_config, self.device) self.eplb_state.add_model( self.drafter.model, - self.vllm_config.speculative_config.draft_model_config, + spec_config.draft_model_config, global_expert_load, old_global_expert_indices, rank_mapping, @@ -3346,9 +3368,11 @@ class GPUModelRunner( scope="local", ) prepare_communication_buffer_for_model(self.model) + mm_config = self.model_config.multimodal_config self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model()) - and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + and mm_config is not None + and mm_config.is_multimodal_pruning_enabled() ) if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: @@ -3383,15 +3407,14 @@ class GPUModelRunner( # CudagraphWraper and CudagraphDispatcher of vllm. 
# wrap the model with full cudagraph wrapper if needed. - if ( - self.compilation_config.cudagraph_mode.has_full_cudagraphs() - and not self.parallel_config.enable_dbo - ): + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) elif self.parallel_config.enable_dbo: - if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + if cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( self.model, self.vllm_config, CUDAGraphMode.FULL, self.device ) @@ -4071,7 +4094,8 @@ class GPUModelRunner( def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - if self.model_config.multimodal_config.skip_mm_profiling: + mm_config = self.model_config.multimodal_config + if mm_config is not None and mm_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " "encoder cache." @@ -4333,8 +4357,9 @@ class GPUModelRunner( def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]: + layer_type = cast(type[Any], AttentionLayerBase) layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names + self.vllm_config, layer_type, kv_cache_group_spec.layer_names ) attn_backends = {} attn_backend_layers = defaultdict(list) @@ -4349,7 +4374,7 @@ class GPUModelRunner( if layer_name in self.kv_sharing_fast_prefill_eligible_layers: attn_backend = create_fast_prefill_custom_backend( "FastPrefill", - attn_backend, + attn_backend, # type: ignore[arg-type] ) full_cls_name = attn_backend.full_cls_name() @@ -4448,6 +4473,7 @@ class GPUModelRunner( min_cg_backend_name = attn_backend.__name__ # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None # check cudagraph for mixed batch is supported if ( cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL @@ -4562,12 +4588,17 @@ class GPUModelRunner( self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size ) - self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes + capture_sizes = self.compilation_config.cudagraph_capture_sizes + self.cudagraph_batch_sizes = ( + capture_sizes if capture_sizes is not None else [] + ) # Trigger cudagraph dispatching keys initialization after # resolved cudagraph mode. 
+ cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, self.uniform_decode_query_len + cudagraph_mode, self.uniform_decode_query_len ) def calculate_reorder_batch_threshold(self) -> None: @@ -4579,7 +4610,7 @@ class GPUModelRunner( """ min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) - reorder_batch_thresholds = [ + reorder_batch_thresholds: list[int | None] = [ group.get_metadata_builder().reorder_batch_threshold for group in self._attn_group_iterator() ] @@ -4588,7 +4619,7 @@ class GPUModelRunner( if len(reorder_batch_thresholds) == 0: self.reorder_batch_threshold = None return - self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) + self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] @staticmethod def select_common_block_size( @@ -5048,12 +5079,16 @@ class GPUModelRunner( kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) if self.dcp_world_size > 1: - layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) + layer_type = cast(type[Any], AttentionLayerBase) + layers = get_layers_from_vllm_config(self.vllm_config, layer_type) for layer in layers.values(): - assert layer.impl.need_to_return_lse_for_decode, ( + layer_impl = getattr(layer, "impl", None) + if layer_impl is None: + continue + assert layer_impl.need_to_return_lse_for_decode, ( "DCP requires attention impls to return" " the softmax lse for decode, but the impl " - f"{layer.impl.__class__.__name__} " + f"{layer_impl.__class__.__name__} " "does not return the softmax lse for decode." ) @@ -5094,7 +5129,8 @@ class GPUModelRunner( if has_ec_transfer() and get_ec_transfer().is_producer: return {} kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) + layer_type = cast(type[Any], AttentionLayerBase) + attn_layers = get_layers_from_vllm_config(self.vllm_config, layer_type) for layer_name, attn_module in attn_layers.items(): if isinstance(attn_module, Attention) and ( kv_tgt_layer := attn_module.kv_sharing_target_layer_name diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 9de123263755b..2ce2b64512560 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -121,18 +121,24 @@ class UBatchWrapper: @staticmethod def _create_sm_control_context(vllm_config: VllmConfig): - comm_sms = envs.VLLM_DBO_COMM_SMS + comm_sms: int = envs.VLLM_DBO_COMM_SMS set_comm_sms = lambda sms: None if vllm_config.parallel_config.enable_expert_parallel: # Currently only DeepEP highthroughput supports SM control so this # only affects that case. 
- all2all_manager = get_ep_group().device_communicator.all2all_manager + ep_group = get_ep_group() + device_communicator = ep_group.device_communicator + all2all_manager = None + if device_communicator is not None: + all2all_manager = device_communicator.all2all_manager - if all2all_manager.max_sms_used() is not None: - comm_sms = min(comm_sms, all2all_manager.max_sms_used()) + if all2all_manager is not None: + max_sms_used = all2all_manager.max_sms_used() + if max_sms_used is not None: + comm_sms = min(comm_sms, max_sms_used) - if comm_sms > 0: + if comm_sms > 0 and all2all_manager is not None: set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms) # TODO(lucas): support other kernels besides DeepGEMM diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 18cbc38262793..f1fd5be966c37 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,7 +6,7 @@ import gc import os from contextlib import AbstractContextManager, nullcontext from types import NoneType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch import torch.distributed @@ -87,8 +87,10 @@ class Worker(WorkerBase): # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} - # Torch profiler. Enabled and configured through env vars: + # Torch/CUDA profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + # VLLM_TORCH_CUDA_PROFILE=1 + self.profiler: Any | None = None if envs.VLLM_TORCH_PROFILER_DIR: worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" self.profiler = TorchProfilerWrapper( @@ -146,17 +148,17 @@ class Worker(WorkerBase): assert allocator.get_current_usage() == 0, ( "Sleep mode can only be used for one instance per process." ) - context = allocator.use_memory_pool(tag=tag) + return allocator.use_memory_pool(tag=tag) else: - context = nullcontext() - return context + return nullcontext() def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks def init_device(self): - if self.device_config.device.type == "cuda": + device = self.device_config.device + if isinstance(device, torch.device) and device.type == "cuda": # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) if ( @@ -375,23 +377,21 @@ class Worker(WorkerBase): from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() - context = allocator.use_memory_pool(tag="kv_cache") + with allocator.use_memory_pool(tag="kv_cache"): + self.model_runner.initialize_kv_cache(kv_cache_config) else: - context = nullcontext() - with context: self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes, # but users still want to compile for better performance, # e.g. for the max-num-batched token size in chunked prefill. 
- warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + compile_sizes = self.vllm_config.compilation_config.compile_sizes + warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] if not self.model_config.enforce_eager: - warmup_sizes = [ - x - for x in warmup_sizes - if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes - ] + capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + if capture_sizes is not None: + warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes] # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) @@ -532,12 +532,12 @@ class Worker(WorkerBase): ) } if forward_pass and not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors, - ) + tensor_dict = get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, ) + assert tensor_dict is not None + intermediate_tensors = IntermediateTensors(tensor_dict) with self.annotate_profile(scheduler_output): output = self.model_runner.execute_model( @@ -605,7 +605,7 @@ class Worker(WorkerBase): assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( execute_shuffle=True, - global_expert_load=None, + global_expert_loads=None, rank_mapping=rank_mapping, ) torch.cuda.synchronize() @@ -661,7 +661,7 @@ class Worker(WorkerBase): def _reconfigure_moe( self, old_ep_size: int, new_ep_size: int - ) -> torch.Tensor | None: + ) -> list[torch.Tensor] | None: """ Reconfigure MoE modules with provided reconfig_request @@ -728,26 +728,29 @@ class Worker(WorkerBase): num_local_physical_experts = num_local_experts assert self.model_runner.eplb_state is not None new_physical_experts = ( - self.model_runner.eplb_state.physical_to_logical_map.shape[1] + self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined] ) parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - - self.model_runner.eplb_state.logical_replica_count.shape[1] + - self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined] ) global_expert_loads = None else: - num_local_physical_experts = torch.tensor( + num_local_physical_experts_tensor = torch.tensor( [num_local_experts], dtype=torch.int32, device="cpu" ) torch.distributed.broadcast( - num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + num_local_physical_experts_tensor, + group=get_ep_group().cpu_group, + group_src=0, ) - num_local_physical_experts = num_local_physical_experts.item() + num_local_physical_experts = int(num_local_physical_experts_tensor.item()) new_physical_experts = num_local_physical_experts * new_ep_size assert self.model_runner.eplb_state is not None - global_expert_loads = self.model_runner.eplb_state.rearrange( + global_expert_loads_any = self.model_runner.eplb_state.rearrange( execute_shuffle=False ) + global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any) parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - global_expert_loads[0].shape[1] ) @@ -849,8 +852,9 @@ def init_worker_distributed_environment( init_batch_invariance() set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + init_method = distributed_init_method or "env://" 
init_distributed_environment( - parallel_config.world_size, rank, distributed_init_method, local_rank, backend + parallel_config.world_size, rank, init_method, local_rank, backend ) ensure_model_parallel_initialized( diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index e59361f21372a..ff047d8d03f0e 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -59,7 +59,7 @@ class KVConnectorModelRunnerMixin: @staticmethod def ensure_kv_transfer_shutdown() -> None: # has_kv_transfer_group can be None during interpreter shutdown. - if has_kv_transfer_group and has_kv_transfer_group(): + if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function] ensure_kv_transfer_shutdown() @staticmethod diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 923c31c187f31..450160d28649f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -572,7 +572,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): format. Layers that do not need KV cache are not included. """ - layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) + layers = get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) block_size = self.vllm_config.cache_config.block_size cache_dtype_str = self.vllm_config.cache_config.cache_dtype @@ -725,7 +728,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_id = self.input_batch.req_ids[i] assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - if not use_max_model_len and num_tokens > self.most_model_len: + if ( + not use_max_model_len + and self.most_model_len is not None + and num_tokens > self.most_model_len + ): use_max_model_len = True num_scheduled_tokens_per_req.append(num_tokens) if use_max_model_len: @@ -737,6 +744,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: end_index = num_reqs else: + assert self.num_reqs_most_model_len is not None if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ : self.num_reqs_most_model_len @@ -829,6 +837,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ].to(self.device) seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) else: + assert self.num_reqs_most_model_len is not None block_tables = self.block_table_cpu[ : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req ] @@ -931,6 +940,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for mm_input_id in encoder_input_ids: mm_feature = req_state.mm_features[mm_input_id] + if mm_feature.data is None: + continue mm_hash = mm_feature.identifier mm_kwargs.append(mm_feature.data) mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) @@ -1114,7 +1125,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> ModelRunnerOutput: if self.scheduler_output is None: # Nothing to do (PP non-final rank case), output isn't used. 
- return None # noqa + return None # type: ignore[return-value] scheduler_output = self.scheduler_output mm_embed_inputs = self.mm_embed_inputs self.scheduler_output = None @@ -1696,7 +1707,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> None: # Profile with multimodal encoder & encoder cache. if self.supports_mm_inputs: - if self.model_config.multimodal_config.skip_mm_profiling: + mm_config = self.model_config.multimodal_config + if mm_config is not None and mm_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " "encoder cache." @@ -2166,5 +2178,9 @@ def replace_set_lora(model): if isinstance(module, BaseLayerWithLoRA): module._original_set_lora = module.set_lora module._original_reset_lora = module.reset_lora - module.set_lora = _tpu_set_lora.__get__(module, module.__class__) - module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) + module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign] + module, module.__class__ + ) + module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign] + module, module.__class__ + ) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index a716a9c3aa822..569b2aaa766e4 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -141,8 +141,7 @@ class TPUWorker: # Set random seed. set_random_seed(self.model_config.seed) - if self.model_config.seed is not None: - xm.set_rng_state(self.model_config.seed, self.device) + xm.set_rng_state(self.model_config.seed, self.device) # Increase the cache size limit, which is the maximum number of # dynamo graphs that can be compiled. @@ -332,7 +331,7 @@ class TPUWorker: world_size=parallel_config.world_size, rank=rank, local_rank=local_rank, - distributed_init_method=distributed_init_method, + distributed_init_method=distributed_init_method or "env://", backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 9e99ea964ee08..92e4ce3abdba3 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -280,7 +280,7 @@ def bind_kv_cache( kv_caches: dict[str, torch.Tensor], forward_context: dict[str, "Attention"], runner_kv_caches: list[torch.Tensor], - num_attn_module: int | None = 1, + num_attn_module: int = 1, ) -> None: """ Bind the allocated KV cache to both ModelRunner and forward context so @@ -362,5 +362,7 @@ def is_residual_scattered_for_sp( or vllm_config.compilation_config.use_inductor_graph_partition ): return True - - return num_input_tokens in vllm_config.compilation_config.compile_sizes + compile_sizes = vllm_config.compilation_config.compile_sizes + if compile_sizes is None: + return False + return num_input_tokens in compile_sizes diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 16f321c080779..57e7037e946ec 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -315,10 +315,12 @@ class WorkerWrapperBase: def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: kv_cache_config = kv_cache_configs[self.global_rank] + assert self.vllm_config is not None with set_current_vllm_config(self.vllm_config): self.worker.initialize_from_config(kv_cache_config) # type: ignore def init_device(self): + assert self.vllm_config is not None with set_current_vllm_config(self.vllm_config): # To make vLLM config available during device initialization self.worker.init_device() # type: ignore diff --git 
a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 26c6f8d06bdcd..4d7864e90496a 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from typing import Any import torch import torch.distributed @@ -37,6 +38,7 @@ class XPUWorker(Worker): # Torch profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + self.profiler: Any | None = None if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" @@ -148,7 +150,12 @@ class XPUWorker(Worker): return int(available_kv_cache_memory) def init_device(self): - if self.device_config.device.type == "xpu" and current_platform.is_xpu(): + device = self.device_config.device + if ( + isinstance(device, torch.device) + and device.type == "xpu" + and current_platform.is_xpu() + ): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype)
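
Note (illustrative sketch, not part of the patch): the fixes above are mostly mypy narrowings of Optional values, via explicit None checks, `assert x is not None`, and `typing.cast`, rather than behavior changes. A minimal standalone example of those three patterns, using hypothetical names (`CompilationConfig`, `capture_sizes`) that only loosely mirror the vLLM config fields touched in this diff:

from typing import cast


class CompilationConfig:
    """Toy config with an Optional field (illustrative only)."""

    def __init__(self, capture_sizes: list[int] | None = None) -> None:
        self.capture_sizes = capture_sizes


def warmup_sizes(cfg: CompilationConfig) -> list[int]:
    # Pattern 1: explicit None check; mypy narrows list[int] | None to list[int].
    sizes = cfg.capture_sizes
    if sizes is None:
        return []
    return sorted(sizes, reverse=True)


def largest_size(cfg: CompilationConfig) -> int:
    # Pattern 2: assert-based narrowing where None would be a programming error.
    assert cfg.capture_sizes is not None, "capture sizes must be resolved first"
    return max(cfg.capture_sizes)


def to_int_list(value: object) -> list[int]:
    # Pattern 3: cast() when the runtime type is known but not statically provable.
    return cast(list[int], value)


if __name__ == "__main__":
    cfg = CompilationConfig([1, 2, 4, 8])
    print(warmup_sizes(cfg), largest_size(cfg), to_int_list([16, 32]))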