diff --git a/tests/test_config.py b/tests/test_config.py
index a160b08f28aa5..015baef918110 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -7,7 +7,7 @@ import pytest
 from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
-                         get_field)
+                         get_field, update_config)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
@@ -46,6 +46,34 @@ def test_get_field():
     assert c.default_factory is MISSING


+@dataclass
+class _TestNestedConfig:
+    a: _TestConfigFields = field(
+        default_factory=lambda: _TestConfigFields(a=0))
+
+
+def test_update_config():
+    # Simple update
+    config1 = _TestConfigFields(a=0)
+    new_config1 = update_config(config1, {"a": 42})
+    assert new_config1.a == 42
+    # Nonexistent field
+    with pytest.raises(AssertionError):
+        new_config1 = update_config(config1, {"nonexistent": 1})
+    # Nested update with dataclass
+    config2 = _TestNestedConfig()
+    new_inner_config = _TestConfigFields(a=1, c="new_value")
+    new_config2 = update_config(config2, {"a": new_inner_config})
+    assert new_config2.a == new_inner_config
+    # Nested update with dict
+    config3 = _TestNestedConfig()
+    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
+    assert new_config3.a.c == "new_value"
+    # Nested update with invalid type
+    with pytest.raises(AssertionError):
+        new_config3 = update_config(config3, {"a": "new_value"})
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index d13df553db623..0bdf1f9820d34 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
     assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)


+def test_update_config(model_runner):
+    # Simple update
+    model_runner.update_config({"load_config": {"load_format": "dummy"}})
+    assert model_runner.load_config.load_format == "dummy"
+    # Raise error on non-existing config
+    with pytest.raises(AssertionError):
+        model_runner.update_config({"do_not_exist_config": "dummy"})
+
+
 def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     # In this test, model_runner loads model + weights in one go, while
     # model_runner_2 loads dummy weights first then load real weights inplace
     model_runner.load_model()
     original_load_format = model_runner_2.load_config.load_format
-    model_runner_2.load_config.load_format = "dummy"
+    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
     model_runner_2.load_model()  # Initial model loading with dummy weights
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
-    model_runner_2.load_config.load_format = original_load_format
+    model_runner_2.update_config(
+        {"load_config": {
+            "load_format": original_load_format
+        }})
     model_runner_2.load_model()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())
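The tests above pin down the `update_config` contract: overrides are plain dicts keyed by field name, nested dicts recurse into dataclass-valued fields, and unknown fields or ill-typed overrides fail an assertion. A minimal sketch of the same semantics outside the test suite, assuming a vLLM build that includes this change (`_Inner` and `_Outer` are illustrative stand-ins, not vLLM classes):

```python
from dataclasses import dataclass, field

from vllm.config import update_config


@dataclass
class _Inner:
    a: int = 0
    c: str = "old_value"


@dataclass
class _Outer:
    inner: _Inner = field(default_factory=_Inner)


outer = _Outer()
# A nested dict override touches only the named fields...
updated = update_config(outer, {"inner": {"c": "new_value"}})
assert updated.inner.c == "new_value" and updated.inner.a == 0
# ...and the original is untouched: update_config builds new instances
# via dataclasses.replace instead of mutating.
assert outer.inner.c == "old_value"
```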
"vllm.model_executor.models") logger = init_logger(__name__) - +DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance) ConfigT = TypeVar("ConfigT", bound=ConfigType) TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", @@ -5049,3 +5050,21 @@ class SpeechToTextConfig: @property def allow_audio_chunking(self) -> bool: return self.min_energy_split_window_size is not None + + +def update_config(config: DataclassInstanceT, + overrides: dict[str, Any]) -> DataclassInstanceT: + processed_overrides = {} + for field_name, value in overrides.items(): + assert hasattr( + config, field_name), f"{type(config)} has no field `{field_name}`" + current_value = getattr(config, field_name) + if is_dataclass(current_value) and not is_dataclass(value): + assert isinstance(value, dict), ( + f"Overrides to {type(config)}.{field_name} must be a dict" + f" or {type(current_value)}, but got {type(value)}") + value = update_config( + current_value, # type: ignore[type-var] + value) + processed_overrides[field_name] = value + return replace(config, **processed_overrides) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 44de1469d1b10..4551cb2df98ac 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) + get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) @@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): draft_token_ids.append(drafter_output.tolist()) return draft_token_ids + def update_config(self, overrides: dict[str, Any]) -> None: + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. 
" \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 3c764bcdcb21c..6458b55777a4d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -4,7 +4,7 @@ import copy import gc import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed @@ -193,6 +193,9 @@ class Worker(WorkerBase): with context: self.model_runner.load_model() + def update_config(self, overrides: dict[str, Any]) -> None: + self.model_runner.update_config(overrides) + @torch.inference_mode() def determine_available_memory(self) -> int: """Profiles the peak memory usage of the model to determine how much diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5af052e685117..eb96e56f495f1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from unittest.mock import patch import numpy as np @@ -18,7 +18,8 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config +from vllm.config import (ParallelConfig, VllmConfig, + get_layers_from_vllm_config, update_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA @@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): return model_runner_output + def update_config(self, overrides: dict[str, Any]) -> None: + # TODO: TPU config may need extra validation + # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + def load_model(self) -> None: self.device = self.device_config.device diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ade4d08211683..c5336e9ad519e 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" import os -from typing import Optional +from typing import Any, Optional import torch import torch.distributed @@ -260,6 +260,9 @@ class TPUWorker: def load_model(self) -> None: self.model_runner.load_model() + def update_config(self, overrides: dict[str, Any]) -> None: + self.model_runner.update_config(overrides) + def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model()