[CI] Fix mypy for vllm/v1/worker (#29037)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-11-20 22:36:07 -05:00 committed by GitHub
parent 3f5f36da3f
commit 56669c1f29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 178 additions and 102 deletions

View File

@ -38,6 +38,7 @@ FILES = [
"vllm/usage", "vllm/usage",
"vllm/v1/core", "vllm/v1/core",
"vllm/v1/engine", "vllm/v1/engine",
"vllm/v1/worker",
] ]
# After fixing errors resulting from changing follow_imports # After fixing errors resulting from changing follow_imports
@ -62,7 +63,6 @@ SEPARATE_GROUPS = [
"vllm/v1/sample", "vllm/v1/sample",
"vllm/v1/spec_decode", "vllm/v1/spec_decode",
"vllm/v1/structured_output", "vllm/v1/structured_output",
"vllm/v1/worker",
] ]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.

View File

@ -10,7 +10,7 @@ import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int | None) -> None:
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.seed_everything(seed) current_platform.seed_everything(seed)

View File

@ -3,7 +3,7 @@
import asyncio import asyncio
import atexit import atexit
from collections.abc import Iterable, Set from collections.abc import Generator, Set
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from itertools import groupby from itertools import groupby
from pathlib import Path from pathlib import Path
@ -403,7 +403,7 @@ def group_mm_kwargs_by_modality(
pin_memory: bool = False, pin_memory: bool = False,
merge_by_field_config: bool | None = None, merge_by_field_config: bool | None = None,
multimodal_cpu_fields: Set[str] = frozenset(), multimodal_cpu_fields: Set[str] = frozenset(),
) -> Iterable[tuple[str, int, BatchedTensorInputs]]: ) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance. modality together into the same `MultiModalKwargs` instance.

View File

@ -3,6 +3,7 @@
import os import os
import platform import platform
from collections.abc import Callable from collections.abc import Callable
from typing import Any
import torch import torch
@ -37,6 +38,9 @@ class CPUWorker(Worker):
self.parallel_config.disable_custom_all_reduce = True self.parallel_config.disable_custom_all_reduce = True
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR: if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
@ -80,13 +84,13 @@ class CPUWorker(Worker):
self.local_omp_cpuid = "nobind" self.local_omp_cpuid = "nobind"
else: else:
local_dp_rank = self.parallel_config.data_parallel_rank_local local_dp_rank = self.parallel_config.data_parallel_rank_local
omp_cpuids = omp_cpuids.split("|") omp_cpuids_list = omp_cpuids.split("|")
if local_dp_rank is not None: if local_dp_rank is not None:
world_size = self.parallel_config.world_size world_size = self.parallel_config.world_size
omp_cpuids = omp_cpuids[ omp_cpuids_list = omp_cpuids_list[
local_dp_rank * world_size : (local_dp_rank + 1) * world_size local_dp_rank * world_size : (local_dp_rank + 1) * world_size
] ]
self.local_omp_cpuid = omp_cpuids[self.rank] self.local_omp_cpuid = omp_cpuids_list[self.rank]
if self.local_omp_cpuid != "nobind": if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
@ -120,7 +124,7 @@ class CPUWorker(Worker):
pass pass
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes # type: ignore return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by

View File

@ -5,7 +5,7 @@ import gc
import itertools import itertools
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy, deepcopy from copy import copy, deepcopy
from functools import reduce from functools import reduce
@ -53,6 +53,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
SupportsMRoPE,
SupportsMultiModal, SupportsMultiModal,
is_mixture_of_experts, is_mixture_of_experts,
supports_eagle3, supports_eagle3,
@ -126,6 +127,7 @@ from vllm.v1.outputs import (
) )
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
@ -404,7 +406,10 @@ class GPUModelRunner(
# solution, we initialize the input batch here, and re-initialize it # solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from # in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config. # the block_sizes in the kv cache config.
custom_logitsprocs = model_config.logits_processors logits_processors = model_config.logits_processors
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
tuple(logits_processors) if logits_processors is not None else ()
)
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoer # We need to use the encoder length for encoder-decoer
@ -959,9 +964,13 @@ class GPUModelRunner(
def _init_mrope_positions(self, req_state: CachedRequestState): def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model() model = self.get_model()
assert supports_mrope(model), "M-RoPE support is not implemented." assert supports_mrope(model), "M-RoPE support is not implemented."
assert req_state.prompt_token_ids is not None, (
"M-RoPE requires prompt_token_ids to be available."
)
mrope_model = cast(SupportsMRoPE, model)
req_state.mrope_positions, req_state.mrope_position_delta = ( req_state.mrope_positions, req_state.mrope_position_delta = (
model.get_mrope_input_positions( mrope_model.get_mrope_input_positions(
req_state.prompt_token_ids, req_state.prompt_token_ids,
req_state.mm_features, req_state.mm_features,
) )
@ -1762,6 +1771,7 @@ class GPUModelRunner(
dst_start = mrope_pos_ptr dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len dst_end = mrope_pos_ptr + completion_part_len
assert req.mrope_position_delta is not None
MRotaryEmbedding.get_next_input_positions_tensor( MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np, out=self.mrope_positions.np,
out_offset=dst_start, out_offset=dst_start,
@ -1907,6 +1917,8 @@ class GPUModelRunner(
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id] mm_feature = req_state.mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data) mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
@ -1930,7 +1942,7 @@ class GPUModelRunner(
# multimodal inputs. The proper solution should be reordering the # multimodal inputs. The proper solution should be reordering the
# encoder outputs. # encoder outputs.
model = cast(SupportsMultiModal, self.model) model = cast(SupportsMultiModal, self.model)
encoder_outputs = [] encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs, mm_kwargs,
device=self.device, device=self.device,
@ -1938,7 +1950,7 @@ class GPUModelRunner(
merge_by_field_config=model.merge_by_field_config, merge_by_field_config=model.merge_by_field_config,
multimodal_cpu_fields=model.multimodal_cpu_fields, multimodal_cpu_fields=model.multimodal_cpu_fields,
): ):
curr_group_outputs = [] curr_group_outputs: list[torch.Tensor] = []
# EVS-related change. # EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when # (ekhvedchenia): Temporary hack to limit peak memory usage when
@ -1980,7 +1992,7 @@ class GPUModelRunner(
# 2. A list or tuple (length: num_items) of tensors, # 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature # each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items. # size is dynamic depending on the input multimodal items.
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment]
sanity_check_mm_encoder_outputs( sanity_check_mm_encoder_outputs(
curr_group_outputs, curr_group_outputs,
@ -2180,7 +2192,7 @@ class GPUModelRunner(
def sync_and_slice_intermediate_tensors( def sync_and_slice_intermediate_tensors(
self, self,
num_tokens: int, num_tokens: int,
intermediate_tensors: IntermediateTensors, intermediate_tensors: IntermediateTensors | None,
sync_self: bool, sync_self: bool,
) -> IntermediateTensors: ) -> IntermediateTensors:
assert self.intermediate_tensors is not None assert self.intermediate_tensors is not None
@ -2397,6 +2409,7 @@ class GPUModelRunner(
if is_first_rank: if is_first_rank:
intermediate_tensors = None intermediate_tensors = None
else: else:
assert intermediate_tensors is not None
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True num_input_tokens, intermediate_tensors, True
) )
@ -2765,14 +2778,14 @@ class GPUModelRunner(
uniform_decode = ( uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len max_num_scheduled_tokens == self.uniform_decode_query_len
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
batch_descriptor = BatchDescriptor( batch_desc = BatchDescriptor(
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
) )
cudagraph_runtime_mode, batch_descriptor = ( cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch( self.cudagraph_dispatcher.dispatch(
batch_descriptor, batch_desc,
use_cascade_attn=cascade_attn_prefix_lens is not None, use_cascade_attn=cascade_attn_prefix_lens is not None,
) )
) )
@ -2856,15 +2869,15 @@ class GPUModelRunner(
else: else:
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
model_output_broadcast_data = {} model_output_broadcast_data: dict[str, Any] = {}
if logits is not None: if logits is not None:
model_output_broadcast_data["logits"] = logits.contiguous() model_output_broadcast_data["logits"] = logits.contiguous()
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( broadcasted = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
) )
assert model_output_broadcast_data is not None assert broadcasted is not None
logits = model_output_broadcast_data["logits"] logits = broadcasted["logits"]
self.execute_model_state = ExecuteModelState( self.execute_model_state = ExecuteModelState(
scheduler_output, scheduler_output,
@ -2889,7 +2902,7 @@ class GPUModelRunner(
if self.execute_model_state is None: if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used. # Nothing to do (PP non-final rank case), output isn't used.
if not kv_connector_output: if not kv_connector_output:
return None # noqa return None # type: ignore[return-value]
# In case of PP with kv transfer, we need to pass through the # In case of PP with kv transfer, we need to pass through the
# kv_connector_output # kv_connector_output
@ -2941,33 +2954,37 @@ class GPUModelRunner(
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
) )
spec_config = self.speculative_config
use_padded_batch_for_eagle = ( use_padded_batch_for_eagle = (
self.speculative_config spec_config is not None
and self.speculative_config.use_eagle() and spec_config.use_eagle()
and not self.speculative_config.disable_padded_drafter_batch and not spec_config.disable_padded_drafter_batch
) )
effective_drafter_max_model_len = self.max_model_len effective_drafter_max_model_len = self.max_model_len
if effective_drafter_max_model_len is None: if effective_drafter_max_model_len is None:
effective_drafter_max_model_len = self.model_config.max_model_len effective_drafter_max_model_len = self.model_config.max_model_len
if ( if (
self.speculative_config spec_config is not None
and self.speculative_config.draft_model_config is not None and spec_config.draft_model_config is not None
and self.speculative_config.draft_model_config.max_model_len is not None and spec_config.draft_model_config.max_model_len is not None
): ):
effective_drafter_max_model_len = ( effective_drafter_max_model_len = (
self.speculative_config.draft_model_config.max_model_len spec_config.draft_model_config.max_model_len
) )
input_fits_in_drafter = spec_decode_common_attn_metadata and ( input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= effective_drafter_max_model_len <= effective_drafter_max_model_len
) )
if use_padded_batch_for_eagle: if use_padded_batch_for_eagle:
assert self.speculative_config is not None
assert isinstance(self.drafter, EagleProposer)
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter: if input_fits_in_drafter:
# EAGLE speculative decoding can use the GPU sampled tokens # EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish. # as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampled_token_ids) propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None: elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = ( next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded( self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
@ -3105,7 +3122,9 @@ class GPUModelRunner(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
) -> torch.Tensor | list[list[int]]: ) -> torch.Tensor | list[list[int]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram": spec_config = self.speculative_config
assert spec_config is not None
if spec_config.method == "ngram":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer) assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
@ -3115,11 +3134,11 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu, self.input_batch.token_ids_cpu,
self.input_batch.spec_decode_unsupported_reqs, self.input_batch.spec_decode_unsupported_reqs,
) )
elif self.speculative_config.method == "suffix": elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer) assert isinstance(self.drafter, SuffixDecodingProposer)
draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
elif self.speculative_config.method == "medusa": elif spec_config.method == "medusa":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer) assert isinstance(self.drafter, MedusaProposer)
@ -3144,10 +3163,10 @@ class GPUModelRunner(
target_hidden_states=hidden_states, target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
elif self.speculative_config.use_eagle(): elif spec_config.use_eagle():
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
if self.speculative_config.disable_padded_drafter_batch: if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be # When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each # the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists. # request, with invalid requests having empty lists.
@ -3197,7 +3216,7 @@ class GPUModelRunner(
else: else:
target_hidden_states = hidden_states[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens]
else: else:
if self.speculative_config.disable_padded_drafter_batch: if spec_config.disable_padded_drafter_batch:
token_indices_to_sample = None token_indices_to_sample = None
common_attn_metadata, token_indices = self.drafter.prepare_inputs( common_attn_metadata, token_indices = self.drafter.prepare_inputs(
common_attn_metadata, common_attn_metadata,
@ -3292,9 +3311,12 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model) and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb and self.parallel_config.enable_eplb
): ):
spec_config = self.vllm_config.speculative_config
assert spec_config is not None
assert spec_config.draft_model_config is not None
logger.info_once( logger.info_once(
"EPLB is enabled for drafter model %s.", "EPLB is enabled for drafter model %s.",
self.vllm_config.speculative_config.draft_model_config.model, spec_config.draft_model_config.model,
) )
global_expert_load = ( global_expert_load = (
@ -3311,7 +3333,7 @@ class GPUModelRunner(
self.eplb_state = EplbState(self.parallel_config, self.device) self.eplb_state = EplbState(self.parallel_config, self.device)
self.eplb_state.add_model( self.eplb_state.add_model(
self.drafter.model, self.drafter.model,
self.vllm_config.speculative_config.draft_model_config, spec_config.draft_model_config,
global_expert_load, global_expert_load,
old_global_expert_indices, old_global_expert_indices,
rank_mapping, rank_mapping,
@ -3346,9 +3368,11 @@ class GPUModelRunner(
scope="local", scope="local",
) )
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model()) supports_multimodal_pruning(self.get_model())
and self.model_config.multimodal_config.is_multimodal_pruning_enabled() and mm_config is not None
and mm_config.is_multimodal_pruning_enabled()
) )
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
@ -3383,15 +3407,14 @@ class GPUModelRunner(
# CudagraphWraper and CudagraphDispatcher of vllm. # CudagraphWraper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed. # wrap the model with full cudagraph wrapper if needed.
if ( cudagraph_mode = self.compilation_config.cudagraph_mode
self.compilation_config.cudagraph_mode.has_full_cudagraphs() assert cudagraph_mode is not None
and not self.parallel_config.enable_dbo if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
):
self.model = CUDAGraphWrapper( self.model = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
) )
elif self.parallel_config.enable_dbo: elif self.parallel_config.enable_dbo:
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper( self.model = UBatchWrapper(
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
) )
@ -4071,7 +4094,8 @@ class GPUModelRunner(
def profile_run(self) -> None: def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs: if self.supports_mm_inputs:
if self.model_config.multimodal_config.skip_mm_profiling: mm_config = self.model_config.multimodal_config
if mm_config is not None and mm_config.skip_mm_profiling:
logger.info( logger.info(
"Skipping memory profiling for multimodal encoder and " "Skipping memory profiling for multimodal encoder and "
"encoder cache." "encoder cache."
@ -4333,8 +4357,9 @@ class GPUModelRunner(
def get_attn_backends_for_group( def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec, kv_cache_group_spec: KVCacheGroupSpec,
) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]: ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config( layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names self.vllm_config, layer_type, kv_cache_group_spec.layer_names
) )
attn_backends = {} attn_backends = {}
attn_backend_layers = defaultdict(list) attn_backend_layers = defaultdict(list)
@ -4349,7 +4374,7 @@ class GPUModelRunner(
if layer_name in self.kv_sharing_fast_prefill_eligible_layers: if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend( attn_backend = create_fast_prefill_custom_backend(
"FastPrefill", "FastPrefill",
attn_backend, attn_backend, # type: ignore[arg-type]
) )
full_cls_name = attn_backend.full_cls_name() full_cls_name = attn_backend.full_cls_name()
@ -4448,6 +4473,7 @@ class GPUModelRunner(
min_cg_backend_name = attn_backend.__name__ min_cg_backend_name = attn_backend.__name__
# Flexible resolve the cudagraph mode # Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
# check cudagraph for mixed batch is supported # check cudagraph for mixed batch is supported
if ( if (
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
@ -4562,12 +4588,17 @@ class GPUModelRunner(
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
) )
self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes capture_sizes = self.compilation_config.cudagraph_capture_sizes
self.cudagraph_batch_sizes = (
capture_sizes if capture_sizes is not None else []
)
# Trigger cudagraph dispatching keys initialization after # Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode. # resolved cudagraph mode.
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
self.cudagraph_dispatcher.initialize_cudagraph_keys( self.cudagraph_dispatcher.initialize_cudagraph_keys(
self.compilation_config.cudagraph_mode, self.uniform_decode_query_len cudagraph_mode, self.uniform_decode_query_len
) )
def calculate_reorder_batch_threshold(self) -> None: def calculate_reorder_batch_threshold(self) -> None:
@ -4579,7 +4610,7 @@ class GPUModelRunner(
""" """
min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b) min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b)
reorder_batch_thresholds = [ reorder_batch_thresholds: list[int | None] = [
group.get_metadata_builder().reorder_batch_threshold group.get_metadata_builder().reorder_batch_threshold
for group in self._attn_group_iterator() for group in self._attn_group_iterator()
] ]
@ -4588,7 +4619,7 @@ class GPUModelRunner(
if len(reorder_batch_thresholds) == 0: if len(reorder_batch_thresholds) == 0:
self.reorder_batch_threshold = None self.reorder_batch_threshold = None
return return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
@staticmethod @staticmethod
def select_common_block_size( def select_common_block_size(
@ -5048,12 +5079,16 @@ class GPUModelRunner(
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
for layer in layers.values(): for layer in layers.values():
assert layer.impl.need_to_return_lse_for_decode, ( layer_impl = getattr(layer, "impl", None)
if layer_impl is None:
continue
assert layer_impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return" "DCP requires attention impls to return"
" the softmax lse for decode, but the impl " " the softmax lse for decode, but the impl "
f"{layer.impl.__class__.__name__} " f"{layer_impl.__class__.__name__} "
"does not return the softmax lse for decode." "does not return the softmax lse for decode."
) )
@ -5094,7 +5129,8 @@ class GPUModelRunner(
if has_ec_transfer() and get_ec_transfer().is_producer: if has_ec_transfer() and get_ec_transfer().is_producer:
return {} return {}
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items(): for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention) and ( if isinstance(attn_module, Attention) and (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name kv_tgt_layer := attn_module.kv_sharing_target_layer_name

View File

@ -121,18 +121,24 @@ class UBatchWrapper:
@staticmethod @staticmethod
def _create_sm_control_context(vllm_config: VllmConfig): def _create_sm_control_context(vllm_config: VllmConfig):
comm_sms = envs.VLLM_DBO_COMM_SMS comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel: if vllm_config.parallel_config.enable_expert_parallel:
# Currently only DeepEP highthroughput supports SM control so this # Currently only DeepEP highthroughput supports SM control so this
# only affects that case. # only affects that case.
all2all_manager = get_ep_group().device_communicator.all2all_manager ep_group = get_ep_group()
device_communicator = ep_group.device_communicator
all2all_manager = None
if device_communicator is not None:
all2all_manager = device_communicator.all2all_manager
if all2all_manager.max_sms_used() is not None: if all2all_manager is not None:
comm_sms = min(comm_sms, all2all_manager.max_sms_used()) max_sms_used = all2all_manager.max_sms_used()
if max_sms_used is not None:
comm_sms = min(comm_sms, max_sms_used)
if comm_sms > 0: if comm_sms > 0 and all2all_manager is not None:
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms) set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)
# TODO(lucas): support other kernels besides DeepGEMM # TODO(lucas): support other kernels besides DeepGEMM

View File

@ -6,7 +6,7 @@ import gc
import os import os
from contextlib import AbstractContextManager, nullcontext from contextlib import AbstractContextManager, nullcontext
from types import NoneType from types import NoneType
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, cast
import torch import torch
import torch.distributed import torch.distributed
@ -87,8 +87,10 @@ class Worker(WorkerBase):
# Buffers saved before sleep # Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Torch profiler. Enabled and configured through env vars: # Torch/CUDA profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
# VLLM_TORCH_CUDA_PROFILE=1
self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR: if envs.VLLM_TORCH_PROFILER_DIR:
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper( self.profiler = TorchProfilerWrapper(
@ -146,17 +148,17 @@ class Worker(WorkerBase):
assert allocator.get_current_usage() == 0, ( assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process." "Sleep mode can only be used for one instance per process."
) )
context = allocator.use_memory_pool(tag=tag) return allocator.use_memory_pool(tag=tag)
else: else:
context = nullcontext() return nullcontext()
return context
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self): def init_device(self):
if self.device_config.device.type == "cuda": device = self.device_config.device
if isinstance(device, torch.device) and device.type == "cuda":
# This env var set by Ray causes exceptions with graph building. # This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if ( if (
@ -375,23 +377,21 @@ class Worker(WorkerBase):
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache") with allocator.use_memory_pool(tag="kv_cache"):
self.model_runner.initialize_kv_cache(kv_cache_config)
else: else:
context = nullcontext()
with context:
self.model_runner.initialize_kv_cache(kv_cache_config) self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes, # warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance, # but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill. # e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() compile_sizes = self.vllm_config.compilation_config.compile_sizes
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
warmup_sizes = [ capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
x if capture_sizes is not None:
for x in warmup_sizes warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
]
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True): for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size) logger.info("Compile and warming up model for size %d", size)
@ -532,12 +532,12 @@ class Worker(WorkerBase):
) )
} }
if forward_pass and not get_pp_group().is_first_rank: if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors( tensor_dict = get_pp_group().recv_tensor_dict(
get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group(),
all_gather_group=get_tp_group(), all_gather_tensors=all_gather_tensors,
all_gather_tensors=all_gather_tensors,
)
) )
assert tensor_dict is not None
intermediate_tensors = IntermediateTensors(tensor_dict)
with self.annotate_profile(scheduler_output): with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
@ -605,7 +605,7 @@ class Worker(WorkerBase):
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange( self.model_runner.eplb_state.rearrange(
execute_shuffle=True, execute_shuffle=True,
global_expert_load=None, global_expert_loads=None,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -661,7 +661,7 @@ class Worker(WorkerBase):
def _reconfigure_moe( def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int self, old_ep_size: int, new_ep_size: int
) -> torch.Tensor | None: ) -> list[torch.Tensor] | None:
""" """
Reconfigure MoE modules with provided reconfig_request Reconfigure MoE modules with provided reconfig_request
@ -728,26 +728,29 @@ class Worker(WorkerBase):
num_local_physical_experts = num_local_experts num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
new_physical_experts = ( new_physical_experts = (
self.model_runner.eplb_state.physical_to_logical_map.shape[1] self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
) )
parallel_config.eplb_config.num_redundant_experts = ( parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] - self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
) )
global_expert_loads = None global_expert_loads = None
else: else:
num_local_physical_experts = torch.tensor( num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu" [num_local_experts], dtype=torch.int32, device="cpu"
) )
torch.distributed.broadcast( torch.distributed.broadcast(
num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 num_local_physical_experts_tensor,
group=get_ep_group().cpu_group,
group_src=0,
) )
num_local_physical_experts = num_local_physical_experts.item() num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
global_expert_loads = self.model_runner.eplb_state.rearrange( global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False execute_shuffle=False
) )
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = ( parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1] new_physical_experts - global_expert_loads[0].shape[1]
) )
@ -849,8 +852,9 @@ def init_worker_distributed_environment(
init_batch_invariance() init_batch_invariance()
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_method = distributed_init_method or "env://"
init_distributed_environment( init_distributed_environment(
parallel_config.world_size, rank, distributed_init_method, local_rank, backend parallel_config.world_size, rank, init_method, local_rank, backend
) )
ensure_model_parallel_initialized( ensure_model_parallel_initialized(

View File

@ -59,7 +59,7 @@ class KVConnectorModelRunnerMixin:
@staticmethod @staticmethod
def ensure_kv_transfer_shutdown() -> None: def ensure_kv_transfer_shutdown() -> None:
# has_kv_transfer_group can be None during interpreter shutdown. # has_kv_transfer_group can be None during interpreter shutdown.
if has_kv_transfer_group and has_kv_transfer_group(): if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
ensure_kv_transfer_shutdown() ensure_kv_transfer_shutdown()
@staticmethod @staticmethod

View File

@ -572,7 +572,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included. format. Layers that do not need KV cache are not included.
""" """
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
block_size = self.vllm_config.cache_config.block_size block_size = self.vllm_config.cache_config.block_size
cache_dtype_str = self.vllm_config.cache_config.cache_dtype cache_dtype_str = self.vllm_config.cache_config.cache_dtype
@ -725,7 +728,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id = self.input_batch.req_ids[i] req_id = self.input_batch.req_ids[i]
assert req_id is not None assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if not use_max_model_len and num_tokens > self.most_model_len: if (
not use_max_model_len
and self.most_model_len is not None
and num_tokens > self.most_model_len
):
use_max_model_len = True use_max_model_len = True
num_scheduled_tokens_per_req.append(num_tokens) num_scheduled_tokens_per_req.append(num_tokens)
if use_max_model_len: if use_max_model_len:
@ -737,6 +744,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
end_index = num_reqs end_index = num_reqs
else: else:
assert self.num_reqs_most_model_len is not None
if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len:
num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
: self.num_reqs_most_model_len : self.num_reqs_most_model_len
@ -829,6 +837,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
].to(self.device) ].to(self.device)
seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
else: else:
assert self.num_reqs_most_model_len is not None
block_tables = self.block_table_cpu[ block_tables = self.block_table_cpu[
: self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
] ]
@ -931,6 +940,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for mm_input_id in encoder_input_ids: for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id] mm_feature = req_state.mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data) mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
@ -1114,7 +1125,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> ModelRunnerOutput: ) -> ModelRunnerOutput:
if self.scheduler_output is None: if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used. # Nothing to do (PP non-final rank case), output isn't used.
return None # noqa return None # type: ignore[return-value]
scheduler_output = self.scheduler_output scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None self.scheduler_output = None
@ -1696,7 +1707,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> None: ) -> None:
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs: if self.supports_mm_inputs:
if self.model_config.multimodal_config.skip_mm_profiling: mm_config = self.model_config.multimodal_config
if mm_config is not None and mm_config.skip_mm_profiling:
logger.info( logger.info(
"Skipping memory profiling for multimodal encoder and " "Skipping memory profiling for multimodal encoder and "
"encoder cache." "encoder cache."
@ -2166,5 +2178,9 @@ def replace_set_lora(model):
if isinstance(module, BaseLayerWithLoRA): if isinstance(module, BaseLayerWithLoRA):
module._original_set_lora = module.set_lora module._original_set_lora = module.set_lora
module._original_reset_lora = module.reset_lora module._original_reset_lora = module.reset_lora
module.set_lora = _tpu_set_lora.__get__(module, module.__class__) module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign]
module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) module, module.__class__
)
module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign]
module, module.__class__
)

View File

@ -141,8 +141,7 @@ class TPUWorker:
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
if self.model_config.seed is not None: xm.set_rng_state(self.model_config.seed, self.device)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of # Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled. # dynamo graphs that can be compiled.
@ -332,7 +331,7 @@ class TPUWorker:
world_size=parallel_config.world_size, world_size=parallel_config.world_size,
rank=rank, rank=rank,
local_rank=local_rank, local_rank=local_rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method or "env://",
backend=current_platform.dist_backend, backend=current_platform.dist_backend,
) )
ensure_model_parallel_initialized( ensure_model_parallel_initialized(

View File

@ -280,7 +280,7 @@ def bind_kv_cache(
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"], forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor], runner_kv_caches: list[torch.Tensor],
num_attn_module: int | None = 1, num_attn_module: int = 1,
) -> None: ) -> None:
""" """
Bind the allocated KV cache to both ModelRunner and forward context so Bind the allocated KV cache to both ModelRunner and forward context so
@ -362,5 +362,7 @@ def is_residual_scattered_for_sp(
or vllm_config.compilation_config.use_inductor_graph_partition or vllm_config.compilation_config.use_inductor_graph_partition
): ):
return True return True
compile_sizes = vllm_config.compilation_config.compile_sizes
return num_input_tokens in vllm_config.compilation_config.compile_sizes if compile_sizes is None:
return False
return num_input_tokens in compile_sizes

View File

@ -315,10 +315,12 @@ class WorkerWrapperBase:
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank] kv_cache_config = kv_cache_configs[self.global_rank]
assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config): with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self): def init_device(self):
assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config): with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization # To make vLLM config available during device initialization
self.worker.init_device() # type: ignore self.worker.init_device() # type: ignore

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from typing import Any
import torch import torch
import torch.distributed import torch.distributed
@ -37,6 +38,7 @@ class XPUWorker(Worker):
# Torch profiler. Enabled and configured through env vars: # Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR: if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
@ -148,7 +150,12 @@ class XPUWorker(Worker):
return int(available_kv_cache_memory) return int(available_kv_cache_memory)
def init_device(self): def init_device(self):
if self.device_config.device.type == "xpu" and current_platform.is_xpu(): device = self.device_config.device
if (
isinstance(device, torch.device)
and device.type == "xpu"
and current_platform.is_xpu()
):
self.device = torch.device(f"xpu:{self.local_rank}") self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device) current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype) current_platform.check_if_supports_dtype(self.model_config.dtype)