[Core] Add update_config RPC method (#20095)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn 2025-07-13 17:49:18 -07:00 committed by GitHub
parent 4bbfc36b16
commit 8632e831ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 97 additions and 9 deletions

View File

@ -7,7 +7,7 @@ import pytest
from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
get_field)
get_field, update_config)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@ -46,6 +46,34 @@ def test_get_field():
assert c.default_factory is MISSING
@dataclass
class _TestNestedConfig:
a: _TestConfigFields = field(
default_factory=lambda: _TestConfigFields(a=0))
def test_update_config():
# Simple update
config1 = _TestConfigFields(a=0)
new_config1 = update_config(config1, {"a": 42})
assert new_config1.a == 42
# Nonexistent field
with pytest.raises(AssertionError):
new_config1 = update_config(config1, {"nonexistent": 1})
# Nested update with dataclass
config2 = _TestNestedConfig()
new_inner_config = _TestConfigFields(a=1, c="new_value")
new_config2 = update_config(config2, {"a": new_inner_config})
assert new_config2.a == new_inner_config
# Nested update with dict
config3 = _TestNestedConfig()
new_config3 = update_config(config3, {"a": {"c": "new_value"}})
assert new_config3.a.c == "new_value"
# Nested update with invalid type
with pytest.raises(AssertionError):
new_config3 = update_config(config3, {"a": "new_value"})
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[

View File

@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
def test_update_config(model_runner):
# Simple update
model_runner.update_config({"load_config": {"load_format": "dummy"}})
assert model_runner.load_config.load_format == "dummy"
# Raise error on non-existing config
with pytest.raises(AssertionError):
model_runner.update_config({"do_not_exist_config": "dummy"})
def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
# In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace
model_runner.load_model()
original_load_format = model_runner_2.load_config.load_format
model_runner_2.load_config.load_format = "dummy"
model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
model_runner_2.load_model() # Initial model loading with dummy weights
assert str(model_runner.get_model().state_dict()) != str(
model_runner_2.get_model().state_dict())
model_runner_2.load_config.load_format = original_load_format
model_runner_2.update_config(
{"load_config": {
"load_format": original_load_format
}})
model_runner_2.load_model() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict())

View File

@ -71,6 +71,7 @@ if TYPE_CHECKING:
ConfigType = type[DataclassInstance]
HfOverrides = Union[dict, Callable[[type], type]]
else:
DataclassInstance = Any
PlacementGroup = Any
PretrainedConfig = Any
ExecutorBase = Any
@ -87,7 +88,7 @@ else:
"vllm.model_executor.models")
logger = init_logger(__name__)
DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
@property
def allow_audio_chunking(self) -> bool:
return self.min_energy_split_window_size is not None
def update_config(config: DataclassInstanceT,
overrides: dict[str, Any]) -> DataclassInstanceT:
processed_overrides = {}
for field_name, value in overrides.items():
assert hasattr(
config, field_name), f"{type(config)} has no field `{field_name}`"
current_value = getattr(config, field_name)
if is_dataclass(current_value) and not is_dataclass(value):
assert isinstance(value, dict), (
f"Overrides to {type(config)}.{field_name} must be a dict"
f" or {type(current_value)}, but got {type(value)}")
value = update_config(
current_value, # type: ignore[type-var]
value)
processed_overrides[field_name] = value
return replace(config, **processed_overrides)

View File

@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items():
assert config_name in allowed_config_names, \
f"Config `{config_name}` not supported. " \
f"Allowed configs: {allowed_config_names}"
config = getattr(self, config_name)
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117

View File

@ -4,7 +4,7 @@
import copy
import gc
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
import torch.distributed
@ -193,6 +193,9 @@ class Worker(WorkerBase):
with context:
self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much

View File

@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from unittest.mock import patch
import numpy as np
@ -18,7 +18,8 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
from vllm.config import (ParallelConfig, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return model_runner_output
def update_config(self, overrides: dict[str, Any]) -> None:
# TODO: TPU config may need extra validation
# https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items():
assert config_name in allowed_config_names, \
f"Config `{config_name}` not supported. " \
f"Allowed configs: {allowed_config_names}"
config = getattr(self, config_name)
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
def load_model(self) -> None:
self.device = self.device_config.device

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from typing import Optional
from typing import Any, Optional
import torch
import torch.distributed
@ -260,6 +260,9 @@ class TPUWorker:
def load_model(self) -> None:
self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()