[Core] Add update_config RPC method (#20095)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
22quinn 2025-07-13 17:49:18 -07:00 committed by GitHub
parent 4bbfc36b16
commit 8632e831ba
GPG Key ID: B5690EEEBB952194
7 changed files with 97 additions and 9 deletions

View File

@@ -7,7 +7,7 @@ import pytest
 from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
-                         get_field)
+                         get_field, update_config)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
@@ -46,6 +46,34 @@ def test_get_field():
     assert c.default_factory is MISSING
 
 
+@dataclass
+class _TestNestedConfig:
+    a: _TestConfigFields = field(
+        default_factory=lambda: _TestConfigFields(a=0))
+
+
+def test_update_config():
+    # Simple update
+    config1 = _TestConfigFields(a=0)
+    new_config1 = update_config(config1, {"a": 42})
+    assert new_config1.a == 42
+    # Nonexistent field
+    with pytest.raises(AssertionError):
+        new_config1 = update_config(config1, {"nonexistent": 1})
+    # Nested update with dataclass
+    config2 = _TestNestedConfig()
+    new_inner_config = _TestConfigFields(a=1, c="new_value")
+    new_config2 = update_config(config2, {"a": new_inner_config})
+    assert new_config2.a == new_inner_config
+    # Nested update with dict
+    config3 = _TestNestedConfig()
+    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
+    assert new_config3.a.c == "new_value"
+    # Nested update with invalid type
+    with pytest.raises(AssertionError):
+        new_config3 = update_config(config3, {"a": "new_value"})
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [
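Note on semantics, since the tests above rely on it: update_config never mutates its argument; it rebuilds the dataclass via dataclasses.replace and returns the copy. A minimal sketch of that contract, reusing the _TestConfigFields dataclass defined earlier in this test module:

config = _TestConfigFields(a=0)
updated = update_config(config, {"a": 42})
assert config.a == 0    # the original instance is left untouched
assert updated.a == 42  # the override lands on the returned copy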

View File

@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
     assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
 
 
+def test_update_config(model_runner):
+    # Simple update
+    model_runner.update_config({"load_config": {"load_format": "dummy"}})
+    assert model_runner.load_config.load_format == "dummy"
+    # Raise error on non-existing config
+    with pytest.raises(AssertionError):
+        model_runner.update_config({"do_not_exist_config": "dummy"})
+
+
 def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     # In this test, model_runner loads model + weights in one go, while
     # model_runner_2 loads dummy weights first then load real weights inplace
     model_runner.load_model()
     original_load_format = model_runner_2.load_config.load_format
-    model_runner_2.load_config.load_format = "dummy"
+    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
     model_runner_2.load_model()  # Initial model loading with dummy weights
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
-    model_runner_2.load_config.load_format = original_load_format
+    model_runner_2.update_config(
+        {"load_config": {
+            "load_format": original_load_format
+        }})
     model_runner_2.load_model()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())

View File

@@ -71,6 +71,7 @@ if TYPE_CHECKING:
     ConfigType = type[DataclassInstance]
     HfOverrides = Union[dict, Callable[[type], type]]
 else:
+    DataclassInstance = Any
     PlacementGroup = Any
     PretrainedConfig = Any
     ExecutorBase = Any
@@ -87,7 +88,7 @@ else:
                                         "vllm.model_executor.models")
 
 logger = init_logger(__name__)
 
+DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
 ConfigT = TypeVar("ConfigT", bound=ConfigType)
 
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
     @property
     def allow_audio_chunking(self) -> bool:
         return self.min_energy_split_window_size is not None
+
+
+def update_config(config: DataclassInstanceT,
+                  overrides: dict[str, Any]) -> DataclassInstanceT:
+    processed_overrides = {}
+    for field_name, value in overrides.items():
+        assert hasattr(
+            config, field_name), f"{type(config)} has no field `{field_name}`"
+        current_value = getattr(config, field_name)
+        if is_dataclass(current_value) and not is_dataclass(value):
+            assert isinstance(value, dict), (
+                f"Overrides to {type(config)}.{field_name} must be a dict"
+                f" or {type(current_value)}, but got {type(value)}")
+            value = update_config(
+                current_value,  # type: ignore[type-var]
+                value)
+        processed_overrides[field_name] = value
+    return replace(config, **processed_overrides)
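Because the helper recurses whenever the current field value is a dataclass and the override is a plain dict, one nested dict rebuilds the parent and child configs in a single call. A hedged usage sketch (assuming a VllmConfig instance whose load_config field is a LoadConfig dataclass):

new_vllm_config = update_config(
    vllm_config, {"load_config": {"load_format": "dummy"}})
assert new_vllm_config.load_config.load_format == "dummy"
assert new_vllm_config is not vllm_config  # both levels are fresh copies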

View File

@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config)
+                         get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
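The runner accepts overrides only for load_config and model_config; any other key trips the assertion before anything is applied. A hedged sketch against an already-constructed runner (the runner variable is hypothetical; the flow mirrors test_load_model_weights_inplace above):

runner.update_config({"load_config": {"load_format": "dummy"}})
runner.load_model()  # reloads weights under the new load_format
try:
    runner.update_config({"cache_config": {"block_size": 32}})
except AssertionError:
    pass  # cache_config is outside the allowed set; nothing was changed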

View File

@@ -4,7 +4,7 @@
 import copy
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 import torch.distributed
@@ -193,6 +193,9 @@ class Worker(WorkerBase):
         with context:
             self.model_runner.load_model()
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much

View File

@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch
 
 import numpy as np
@@ -18,7 +18,8 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (ParallelConfig, VllmConfig,
+                         get_layers_from_vllm_config, update_config)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
 
         return model_runner_output
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        # TODO: TPU config may need extra validation
+        # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.distributed
@@ -260,6 +260,9 @@ class TPUWorker:
     def load_model(self) -> None:
         self.model_runner.load_model()
 
+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()