Wentao Ye 56669c1f29
[CI] Fix mypy for vllm/v1/worker (#29037)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-11-21 11:36:07 +08:00

95 lines
3.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utils for model executor."""
import copy
from typing import Any
import torch
from vllm.utils.torch_utils import is_torch_equal_or_newer
def set_random_seed(seed: int | None) -> None:
    """Seed random number generation through the current platform.

    The call is forwarded unchanged to ``current_platform.seed_everything``.

    Args:
        seed: The seed value, or ``None``; how ``None`` is treated is decided
            by the platform implementation.
    """
    # The platform object is imported at call time rather than at module
    # import time (presumably to avoid resolving the platform too early /
    # an import cycle -- TODO confirm).
    from vllm.platforms import current_platform as platform

    platform.seed_everything(seed)
def set_weight_attrs(
weight: torch.Tensor,
weight_attrs: dict[str, Any] | None,
):
"""Set attributes on a weight tensor.
This method is used to set attributes on a weight tensor. This method
will not overwrite existing attributes.
Args:
weight: The weight tensor.
weight_attrs: A dictionary of attributes to set on the weight tensor.
"""
if weight_attrs is None:
return
for key, value in weight_attrs.items():
assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
# NOTE(woosuk): During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
# narrowed_tensor.copy_(real_weight)
# expecting narrowed_tensor and param.data to share the same storage.
# However, on TPUs, narrowed_tensor will lazily propagate to the base
# tensor, which is param.data, leading to the redundant memory usage.
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.use_sync_weight_loader() and key == "weight_loader":
value = current_platform.make_synced_weight_loader(value)
setattr(weight, key, value)
def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
    """Collect the ``packed_modules_mapping`` that applies to ``model``.

    If the model itself carries a non-empty ``packed_modules_mapping``, a
    deep copy of it is returned untouched. Otherwise the mappings found on
    the model's direct children (``model.children()``) are merged into a
    single dict; deeper submodules are not inspected.

    Args:
        model: The module to inspect.

    Returns:
        A new dict (built from deep copies, so mutating it does not affect
        the attributes stored on the modules). Empty when neither the model
        nor any direct child defines a mapping.

    Raises:
        ValueError: If a child maps a key to a value that differs from the
            one already merged for that key.
    """
    own_mapping = getattr(model, "packed_modules_mapping", None)
    merged = {} if own_mapping is None else copy.deepcopy(own_mapping)

    # An explicitly defined (non-empty) mapping wins; nothing is inferred.
    if merged:
        return merged

    # Only the main components are consulted, not the whole module tree.
    for submodule in model.children():
        sub_mapping = getattr(submodule, "packed_modules_mapping", None)
        if sub_mapping is None:
            continue
        sub_mapping = copy.deepcopy(sub_mapping)

        # Validate the whole child mapping before merging any of it.
        for name, packed in sub_mapping.items():
            if name in merged and merged[name] != packed:
                raise ValueError(
                    f"Can't update {type(model).__name__}'s packed_modules_mapping "
                    f"safely because of conflicts from {type(submodule).__name__}."
                )
        merged.update(sub_mapping)

    return merged
def get_moe_expert_mapping(
    model: torch.nn.Module,
) -> list[tuple[str, str, int, str]]:
    """Return the expert mapping exposed by ``model`` or one of its children.

    The model's own ``get_expert_mapping`` attribute is called when it exists
    and is truthy. Otherwise the direct children (``model.children()``) are
    scanned in order and the first one whose ``get_expert_mapping`` attribute
    is not ``None`` is called. Deeper submodules are not inspected.

    Args:
        model: The module to inspect.

    Returns:
        Whatever the selected ``get_expert_mapping`` callable returns, or an
        empty list when no such attribute is found.
    """
    own_getter = getattr(model, "get_expert_mapping", None)
    if own_getter:
        return own_getter()

    # Only the main components are consulted, not the whole module tree.
    for submodule in model.children():
        sub_getter = getattr(submodule, "get_expert_mapping", None)
        if sub_getter is not None:
            return sub_getter()

    return []
def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]:
    """Return an option dict that turns ``graph_partition`` off when needed.

    Args:
        current_backend: Name of the backend in use; only the exact string
            ``"inductor"`` is treated specially.

    Returns:
        ``{"graph_partition": False}`` when the backend is ``"inductor"`` and
        ``is_torch_equal_or_newer("2.9.0.dev")`` holds; an empty dict in
        every other case. (Presumably callers merge the result into their
        compile options -- confirm against call sites.)
    """
    # The backend is checked first so the torch version is only queried for
    # the inductor backend.
    if current_backend != "inductor":
        return {}
    if not is_torch_equal_or_newer("2.9.0.dev"):
        return {}
    return {"graph_partition": False}