Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-25 09:14:29 +08:00)
[CI] Fix mypy for vllm/v1/worker (#29037)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 3f5f36da3f
commit 56669c1f29
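Most of the edits below follow one recurring pattern: bind a possibly-None attribute to a local name, assert that it is not None, and let mypy narrow the type from there; the remaining changes add explicit annotations or targeted `# type: ignore` comments. A minimal, self-contained sketch of that narrowing pattern follows (the class names are made-up stand-ins for illustration, not types from this diff):

# Illustrative sketch only: `SpecConfig` and `Runner` are hypothetical
# stand-ins, not classes touched by this commit.
from dataclasses import dataclass


@dataclass
class SpecConfig:
    method: str


class Runner:
    def __init__(self, speculative_config: SpecConfig | None) -> None:
        self.speculative_config = speculative_config

    def drafter_method(self) -> str:
        # Accessing self.speculative_config.method directly fails mypy with a
        # union-attr error because the attribute may be None. Binding it to a
        # local and asserting narrows the type for the rest of the function,
        # which is the same pattern applied in the GPUModelRunner changes below.
        spec_config = self.speculative_config
        assert spec_config is not None
        return spec_config.method


if __name__ == "__main__":
    print(Runner(SpecConfig("eagle")).drafter_method())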
@@ -38,6 +38,7 @@ FILES = [
"vllm/usage",
"vllm/v1/core",
"vllm/v1/engine",
+ "vllm/v1/worker",
]

# After fixing errors resulting from changing follow_imports
@@ -62,7 +63,6 @@ SEPARATE_GROUPS = [
"vllm/v1/sample",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",
- "vllm/v1/worker",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
@@ -10,7 +10,7 @@ import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer


- def set_random_seed(seed: int) -> None:
+ def set_random_seed(seed: int | None) -> None:
from vllm.platforms import current_platform

current_platform.seed_everything(seed)
@@ -3,7 +3,7 @@

import asyncio
import atexit
- from collections.abc import Iterable, Set
+ from collections.abc import Generator, Set
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from pathlib import Path
@@ -403,7 +403,7 @@ def group_mm_kwargs_by_modality(
pin_memory: bool = False,
merge_by_field_config: bool | None = None,
multimodal_cpu_fields: Set[str] = frozenset(),
- ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
+ ) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance.
@@ -3,6 +3,7 @@
import os
import platform
from collections.abc import Callable
+ from typing import Any

import torch
@@ -37,6 +38,9 @@ class CPUWorker(Worker):

self.parallel_config.disable_custom_all_reduce = True

+ # Torch profiler. Enabled and configured through env vars:
+ # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
@@ -80,13 +84,13 @@ class CPUWorker(Worker):
self.local_omp_cpuid = "nobind"
else:
local_dp_rank = self.parallel_config.data_parallel_rank_local
- omp_cpuids = omp_cpuids.split("|")
+ omp_cpuids_list = omp_cpuids.split("|")
if local_dp_rank is not None:
world_size = self.parallel_config.world_size
- omp_cpuids = omp_cpuids[
+ omp_cpuids_list = omp_cpuids_list[
local_dp_rank * world_size : (local_dp_rank + 1) * world_size
]
- self.local_omp_cpuid = omp_cpuids[self.rank]
+ self.local_omp_cpuid = omp_cpuids_list[self.rank]

if self.local_omp_cpuid != "nobind":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
@@ -120,7 +124,7 @@ class CPUWorker(Worker):
pass

def determine_available_memory(self) -> int:
- return self.cache_config.cpu_kvcache_space_bytes # type: ignore
+ return self.cache_config.cpu_kvcache_space_bytes or 0

def compile_or_warm_up_model(self) -> None:
# Reset the seed to ensure that the random state is not affected by
@@ -5,7 +5,7 @@ import gc
import itertools
import time
from collections import defaultdict
- from collections.abc import Iterator
+ from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import reduce
@@ -53,6 +53,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
+ SupportsMRoPE,
SupportsMultiModal,
is_mixture_of_experts,
supports_eagle3,
@@ -126,6 +127,7 @@ from vllm.v1.outputs import (
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
+ from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@@ -404,7 +406,10 @@ class GPUModelRunner(
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
- custom_logitsprocs = model_config.logits_processors
+ logits_processors = model_config.logits_processors
+ custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
+ tuple(logits_processors) if logits_processors is not None else ()
+ )
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoer
@@ -959,9 +964,13 @@ class GPUModelRunner(
def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
assert supports_mrope(model), "M-RoPE support is not implemented."
+ assert req_state.prompt_token_ids is not None, (
+ "M-RoPE requires prompt_token_ids to be available."
+ )
+ mrope_model = cast(SupportsMRoPE, model)

req_state.mrope_positions, req_state.mrope_position_delta = (
- model.get_mrope_input_positions(
+ mrope_model.get_mrope_input_positions(
req_state.prompt_token_ids,
req_state.mm_features,
)
@@ -1762,6 +1771,7 @@ class GPUModelRunner(
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len

+ assert req.mrope_position_delta is not None
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions.np,
out_offset=dst_start,
@@ -1907,6 +1917,8 @@ class GPUModelRunner(

for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id]
+ if mm_feature.data is None:
+ continue
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
@@ -1930,7 +1942,7 @@ class GPUModelRunner(
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
- encoder_outputs = []
+ encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
@@ -1938,7 +1950,7 @@ class GPUModelRunner(
merge_by_field_config=model.merge_by_field_config,
multimodal_cpu_fields=model.multimodal_cpu_fields,
):
- curr_group_outputs = []
+ curr_group_outputs: list[torch.Tensor] = []

# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
@@ -1980,7 +1992,7 @@ class GPUModelRunner(
# 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items.
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
+ curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment]

sanity_check_mm_encoder_outputs(
curr_group_outputs,
@@ -2180,7 +2192,7 @@ class GPUModelRunner(
def sync_and_slice_intermediate_tensors(
self,
num_tokens: int,
- intermediate_tensors: IntermediateTensors,
+ intermediate_tensors: IntermediateTensors | None,
sync_self: bool,
) -> IntermediateTensors:
assert self.intermediate_tensors is not None
@@ -2397,6 +2409,7 @@ class GPUModelRunner(
if is_first_rank:
intermediate_tensors = None
else:
+ assert intermediate_tensors is not None
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True
)
@@ -2765,14 +2778,14 @@ class GPUModelRunner(
uniform_decode = (
max_num_scheduled_tokens == self.uniform_decode_query_len
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
- batch_descriptor = BatchDescriptor(
+ batch_desc = BatchDescriptor(
num_tokens=num_input_tokens,
uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
)
cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
- batch_descriptor,
+ batch_desc,
use_cascade_attn=cascade_attn_prefix_lens is not None,
)
)
@@ -2856,15 +2869,15 @@ class GPUModelRunner(
else:
logits = self.model.compute_logits(sample_hidden_states)

- model_output_broadcast_data = {}
+ model_output_broadcast_data: dict[str, Any] = {}
if logits is not None:
model_output_broadcast_data["logits"] = logits.contiguous()

- model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+ broadcasted = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
)
- assert model_output_broadcast_data is not None
- logits = model_output_broadcast_data["logits"]
+ assert broadcasted is not None
+ logits = broadcasted["logits"]

self.execute_model_state = ExecuteModelState(
scheduler_output,
@@ -2889,7 +2902,7 @@ class GPUModelRunner(
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
if not kv_connector_output:
- return None # noqa
+ return None # type: ignore[return-value]

# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
@@ -2941,33 +2954,37 @@ class GPUModelRunner(
spec_decode_common_attn_metadata,
)

+ spec_config = self.speculative_config
use_padded_batch_for_eagle = (
- self.speculative_config
- and self.speculative_config.use_eagle()
- and not self.speculative_config.disable_padded_drafter_batch
+ spec_config is not None
+ and spec_config.use_eagle()
+ and not spec_config.disable_padded_drafter_batch
)
effective_drafter_max_model_len = self.max_model_len
if effective_drafter_max_model_len is None:
effective_drafter_max_model_len = self.model_config.max_model_len
if (
- self.speculative_config
- and self.speculative_config.draft_model_config is not None
- and self.speculative_config.draft_model_config.max_model_len is not None
+ spec_config is not None
+ and spec_config.draft_model_config is not None
+ and spec_config.draft_model_config.max_model_len is not None
):
effective_drafter_max_model_len = (
- self.speculative_config.draft_model_config.max_model_len
+ spec_config.draft_model_config.max_model_len
)
input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= effective_drafter_max_model_len
)
if use_padded_batch_for_eagle:
assert self.speculative_config is not None
assert isinstance(self.drafter, EagleProposer)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
@@ -3105,7 +3122,9 @@ class GPUModelRunner(
common_attn_metadata: CommonAttentionMetadata,
) -> torch.Tensor | list[list[int]]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- if self.speculative_config.method == "ngram":
+ spec_config = self.speculative_config
+ assert spec_config is not None
+ if spec_config.method == "ngram":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose(
@@ -3115,11 +3134,11 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu,
self.input_batch.spec_decode_unsupported_reqs,
)
- elif self.speculative_config.method == "suffix":
+ elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)
draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
- elif self.speculative_config.method == "medusa":
+ elif spec_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)

@@ -3144,10 +3163,10 @@ class GPUModelRunner(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
- elif self.speculative_config.use_eagle():
+ elif spec_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)

- if self.speculative_config.disable_padded_drafter_batch:
+ if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
@@ -3197,7 +3216,7 @@ class GPUModelRunner(
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
- if self.speculative_config.disable_padded_drafter_batch:
+ if spec_config.disable_padded_drafter_batch:
token_indices_to_sample = None
common_attn_metadata, token_indices = self.drafter.prepare_inputs(
common_attn_metadata,
@@ -3292,9 +3311,12 @@ class GPUModelRunner(
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
+ spec_config = self.vllm_config.speculative_config
+ assert spec_config is not None
+ assert spec_config.draft_model_config is not None
logger.info_once(
"EPLB is enabled for drafter model %s.",
- self.vllm_config.speculative_config.draft_model_config.model,
+ spec_config.draft_model_config.model,
)

global_expert_load = (
@@ -3311,7 +3333,7 @@ class GPUModelRunner(
self.eplb_state = EplbState(self.parallel_config, self.device)
self.eplb_state.add_model(
self.drafter.model,
- self.vllm_config.speculative_config.draft_model_config,
+ spec_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
@@ -3346,9 +3368,11 @@ class GPUModelRunner(
scope="local",
)
prepare_communication_buffer_for_model(self.model)
+ mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())
- and self.model_config.multimodal_config.is_multimodal_pruning_enabled()
+ and mm_config is not None
+ and mm_config.is_multimodal_pruning_enabled()
)

if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
@@ -3383,15 +3407,14 @@ class GPUModelRunner(
# CudagraphWraper and CudagraphDispatcher of vllm.

# wrap the model with full cudagraph wrapper if needed.
- if (
- self.compilation_config.cudagraph_mode.has_full_cudagraphs()
- and not self.parallel_config.enable_dbo
- ):
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
+ if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
self.model = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
elif self.parallel_config.enable_dbo:
- if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+ if cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper(
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
)
@@ -4071,7 +4094,8 @@ class GPUModelRunner(
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs:
- if self.model_config.multimodal_config.skip_mm_profiling:
+ mm_config = self.model_config.multimodal_config
+ if mm_config is not None and mm_config.skip_mm_profiling:
logger.info(
"Skipping memory profiling for multimodal encoder and "
"encoder cache."
@@ -4333,8 +4357,9 @@ class GPUModelRunner(
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
+ layer_type = cast(type[Any], AttentionLayerBase)
layers = get_layers_from_vllm_config(
- self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
+ self.vllm_config, layer_type, kv_cache_group_spec.layer_names
)
attn_backends = {}
attn_backend_layers = defaultdict(list)
@@ -4349,7 +4374,7 @@ class GPUModelRunner(
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
attn_backend = create_fast_prefill_custom_backend(
"FastPrefill",
- attn_backend,
+ attn_backend, # type: ignore[arg-type]
)

full_cls_name = attn_backend.full_cls_name()
@@ -4448,6 +4473,7 @@ class GPUModelRunner(
min_cg_backend_name = attn_backend.__name__
# Flexible resolve the cudagraph mode
cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
# check cudagraph for mixed batch is supported
if (
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
@@ -4562,12 +4588,17 @@ class GPUModelRunner(
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
)
- self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
+ capture_sizes = self.compilation_config.cudagraph_capture_sizes
+ self.cudagraph_batch_sizes = (
+ capture_sizes if capture_sizes is not None else []
+ )

# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
self.cudagraph_dispatcher.initialize_cudagraph_keys(
- self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
+ cudagraph_mode, self.uniform_decode_query_len
)

def calculate_reorder_batch_threshold(self) -> None:
||||
@ -4579,7 +4610,7 @@ class GPUModelRunner(
|
||||
"""
|
||||
min_none_high = lambda a, b: a if b is None else b if a is None else min(a, b)
|
||||
|
||||
reorder_batch_thresholds = [
|
||||
reorder_batch_thresholds: list[int | None] = [
|
||||
group.get_metadata_builder().reorder_batch_threshold
|
||||
for group in self._attn_group_iterator()
|
||||
]
|
||||
@ -4588,7 +4619,7 @@ class GPUModelRunner(
|
||||
if len(reorder_batch_thresholds) == 0:
|
||||
self.reorder_batch_threshold = None
|
||||
return
|
||||
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds)
|
||||
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
|
||||
|
||||
@staticmethod
|
||||
def select_common_block_size(
|
||||
@ -5048,12 +5079,16 @@ class GPUModelRunner(
|
||||
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
|
||||
layer_type = cast(type[Any], AttentionLayerBase)
|
||||
layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
|
||||
for layer in layers.values():
|
||||
assert layer.impl.need_to_return_lse_for_decode, (
|
||||
layer_impl = getattr(layer, "impl", None)
|
||||
if layer_impl is None:
|
||||
continue
|
||||
assert layer_impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer.impl.__class__.__name__} "
|
||||
f"{layer_impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode."
|
||||
)
|
||||
|
||||
@@ -5094,7 +5129,8 @@ class GPUModelRunner(
if has_ec_transfer() and get_ec_transfer().is_producer:
return {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
- attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+ layer_type = cast(type[Any], AttentionLayerBase)
+ attn_layers = get_layers_from_vllm_config(self.vllm_config, layer_type)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention) and (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
@@ -121,18 +121,24 @@ class UBatchWrapper:

@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
- comm_sms = envs.VLLM_DBO_COMM_SMS
+ comm_sms: int = envs.VLLM_DBO_COMM_SMS

set_comm_sms = lambda sms: None
if vllm_config.parallel_config.enable_expert_parallel:
# Currently only DeepEP highthroughput supports SM control so this
# only affects that case.
- all2all_manager = get_ep_group().device_communicator.all2all_manager
+ ep_group = get_ep_group()
+ device_communicator = ep_group.device_communicator
+ all2all_manager = None
+ if device_communicator is not None:
+ all2all_manager = device_communicator.all2all_manager

- if all2all_manager.max_sms_used() is not None:
- comm_sms = min(comm_sms, all2all_manager.max_sms_used())
+ if all2all_manager is not None:
+ max_sms_used = all2all_manager.max_sms_used()
+ if max_sms_used is not None:
+ comm_sms = min(comm_sms, max_sms_used)

- if comm_sms > 0:
+ if comm_sms > 0 and all2all_manager is not None:
set_comm_sms = lambda sms: all2all_manager.set_num_sms(sms)

# TODO(lucas): support other kernels besides DeepGEMM
@@ -6,7 +6,7 @@ import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
- from typing import TYPE_CHECKING, Any
+ from typing import TYPE_CHECKING, Any, cast

import torch
import torch.distributed
@@ -87,8 +87,10 @@ class Worker(WorkerBase):
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

- # Torch profiler. Enabled and configured through env vars:
+ # Torch/CUDA profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ # VLLM_TORCH_CUDA_PROFILE=1
+ self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
@@ -146,17 +148,17 @@ class Worker(WorkerBase):
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process."
)
- context = allocator.use_memory_pool(tag=tag)
+ return allocator.use_memory_pool(tag=tag)
else:
- context = nullcontext()
- return context
+ return nullcontext()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
- if self.device_config.device.type == "cuda":
+ device = self.device_config.device
+ if isinstance(device, torch.device) and device.type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if (
@@ -375,23 +377,21 @@ class Worker(WorkerBase):
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()
- context = allocator.use_memory_pool(tag="kv_cache")
+ with allocator.use_memory_pool(tag="kv_cache"):
+ self.model_runner.initialize_kv_cache(kv_cache_config)
else:
- context = nullcontext()
- with context:
self.model_runner.initialize_kv_cache(kv_cache_config)

def compile_or_warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
- warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
+ compile_sizes = self.vllm_config.compilation_config.compile_sizes
+ warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
if not self.model_config.enforce_eager:
- warmup_sizes = [
- x
- for x in warmup_sizes
- if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
- ]
+ capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+ if capture_sizes is not None:
+ warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes]
# We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
@@ -532,12 +532,12 @@ class Worker(WorkerBase):
)
}
if forward_pass and not get_pp_group().is_first_rank:
- intermediate_tensors = IntermediateTensors(
- get_pp_group().recv_tensor_dict(
- all_gather_group=get_tp_group(),
- all_gather_tensors=all_gather_tensors,
- )
+ tensor_dict = get_pp_group().recv_tensor_dict(
+ all_gather_group=get_tp_group(),
+ all_gather_tensors=all_gather_tensors,
)
+ assert tensor_dict is not None
+ intermediate_tensors = IntermediateTensors(tensor_dict)

with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model(
@@ -605,7 +605,7 @@ class Worker(WorkerBase):
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
- global_expert_load=None,
+ global_expert_loads=None,
rank_mapping=rank_mapping,
)
torch.cuda.synchronize()
@@ -661,7 +661,7 @@ class Worker(WorkerBase):

def _reconfigure_moe(
self, old_ep_size: int, new_ep_size: int
- ) -> torch.Tensor | None:
+ ) -> list[torch.Tensor] | None:
"""
Reconfigure MoE modules with provided reconfig_request

@@ -728,26 +728,29 @@ class Worker(WorkerBase):
num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None
new_physical_experts = (
- self.model_runner.eplb_state.physical_to_logical_map.shape[1]
+ self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts
- - self.model_runner.eplb_state.logical_replica_count.shape[1]
+ - self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
)
global_expert_loads = None
else:
- num_local_physical_experts = torch.tensor(
+ num_local_physical_experts_tensor = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
- num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
+ num_local_physical_experts_tensor,
+ group=get_ep_group().cpu_group,
+ group_src=0,
)
- num_local_physical_experts = num_local_physical_experts.item()
+ num_local_physical_experts = int(num_local_physical_experts_tensor.item())
new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None
- global_expert_loads = self.model_runner.eplb_state.rearrange(
+ global_expert_loads_any = self.model_runner.eplb_state.rearrange(
execute_shuffle=False
)
+ global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_loads[0].shape[1]
)
@@ -849,8 +852,9 @@ def init_worker_distributed_environment(
init_batch_invariance()
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

+ init_method = distributed_init_method or "env://"
init_distributed_environment(
- parallel_config.world_size, rank, distributed_init_method, local_rank, backend
+ parallel_config.world_size, rank, init_method, local_rank, backend
)

ensure_model_parallel_initialized(
@@ -59,7 +59,7 @@ class KVConnectorModelRunnerMixin:
@staticmethod
def ensure_kv_transfer_shutdown() -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
- if has_kv_transfer_group and has_kv_transfer_group():
+ if has_kv_transfer_group and has_kv_transfer_group(): # type: ignore[truthy-function]
ensure_kv_transfer_shutdown()

@staticmethod
@@ -572,7 +572,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""

- layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
+ layers = get_layers_from_vllm_config(
+ self.vllm_config,
+ AttentionLayerBase, # type: ignore[type-abstract]
+ )
block_size = self.vllm_config.cache_config.block_size
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
@@ -725,7 +728,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id = self.input_batch.req_ids[i]
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
- if not use_max_model_len and num_tokens > self.most_model_len:
+ if (
+ not use_max_model_len
+ and self.most_model_len is not None
+ and num_tokens > self.most_model_len
+ ):
use_max_model_len = True
num_scheduled_tokens_per_req.append(num_tokens)
if use_max_model_len:
@@ -737,6 +744,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
end_index = num_reqs
else:
+ assert self.num_reqs_most_model_len is not None
if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len:
num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
: self.num_reqs_most_model_len
@@ -829,6 +837,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
].to(self.device)
seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
else:
+ assert self.num_reqs_most_model_len is not None
block_tables = self.block_table_cpu[
: self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
]
@@ -931,6 +940,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id]
+ if mm_feature.data is None:
+ continue
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
@@ -1114,7 +1125,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
- return None # noqa
+ return None # type: ignore[return-value]
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
|
||||
) -> None:
|
||||
# Profile with multimodal encoder & encoder cache.
|
||||
if self.supports_mm_inputs:
|
||||
if self.model_config.multimodal_config.skip_mm_profiling:
|
||||
mm_config = self.model_config.multimodal_config
|
||||
if mm_config is not None and mm_config.skip_mm_profiling:
|
||||
logger.info(
|
||||
"Skipping memory profiling for multimodal encoder and "
|
||||
"encoder cache."
|
||||
@ -2166,5 +2178,9 @@ def replace_set_lora(model):
|
||||
if isinstance(module, BaseLayerWithLoRA):
|
||||
module._original_set_lora = module.set_lora
|
||||
module._original_reset_lora = module.reset_lora
|
||||
module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
|
||||
module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__)
|
||||
module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign]
|
||||
module, module.__class__
|
||||
)
|
||||
module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign]
|
||||
module, module.__class__
|
||||
)
|
||||
|
||||
@ -141,8 +141,7 @@ class TPUWorker:
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
if self.model_config.seed is not None:
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
@ -332,7 +331,7 @@ class TPUWorker:
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
distributed_init_method=distributed_init_method or "env://",
|
||||
backend=current_platform.dist_backend,
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
|
||||
@@ -280,7 +280,7 @@ def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
runner_kv_caches: list[torch.Tensor],
- num_attn_module: int | None = 1,
+ num_attn_module: int = 1,
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
@@ -362,5 +362,7 @@ def is_residual_scattered_for_sp(
or vllm_config.compilation_config.use_inductor_graph_partition
):
return True

- return num_input_tokens in vllm_config.compilation_config.compile_sizes
+ compile_sizes = vllm_config.compilation_config.compile_sizes
+ if compile_sizes is None:
+ return False
+ return num_input_tokens in compile_sizes
@@ -315,10 +315,12 @@ class WorkerWrapperBase:

def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank]
+ assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore

def init_device(self):
+ assert self.vllm_config is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
+ from typing import Any

import torch
import torch.distributed
@@ -37,6 +38,7 @@ class XPUWorker(Worker):

# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ self.profiler: Any | None = None
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
@@ -148,7 +150,12 @@ class XPUWorker(Worker):
return int(available_kv_cache_memory)

def init_device(self):
- if self.device_config.device.type == "xpu" and current_platform.is_xpu():
+ device = self.device_config.device
+ if (
+ isinstance(device, torch.device)
+ and device.type == "xpu"
+ and current_platform.is_xpu()
+ ):
self.device = torch.device(f"xpu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)