[Core] Add update_config RPC method (#20095)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 4bbfc36b16
commit 8632e831ba
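In short: this change adds a generic update_config helper to vllm.config that applies (optionally nested) dict overrides to dataclass-based config objects via dataclasses.replace, and exposes a matching update_config(overrides) method on the V1 GPU and TPU model runners and their workers (the worker method simply forwards to its model runner), so configs can be patched over the worker RPC interface. The runner-side methods currently accept overrides for load_config and model_config only. The hunks below cover the config tests, the GPU model-runner tests, vllm.config, and the GPU/TPU model runner and worker classes.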
@@ -7,7 +7,7 @@ import pytest
 from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
-                         get_field)
+                         get_field, update_config)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
@@ -46,6 +46,34 @@ def test_get_field():
     assert c.default_factory is MISSING


+@dataclass
+class _TestNestedConfig:
+    a: _TestConfigFields = field(
+        default_factory=lambda: _TestConfigFields(a=0))
+
+
+def test_update_config():
+    # Simple update
+    config1 = _TestConfigFields(a=0)
+    new_config1 = update_config(config1, {"a": 42})
+    assert new_config1.a == 42
+    # Nonexistent field
+    with pytest.raises(AssertionError):
+        new_config1 = update_config(config1, {"nonexistent": 1})
+    # Nested update with dataclass
+    config2 = _TestNestedConfig()
+    new_inner_config = _TestConfigFields(a=1, c="new_value")
+    new_config2 = update_config(config2, {"a": new_inner_config})
+    assert new_config2.a == new_inner_config
+    # Nested update with dict
+    config3 = _TestNestedConfig()
+    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
+    assert new_config3.a.c == "new_value"
+    # Nested update with invalid type
+    with pytest.raises(AssertionError):
+        new_config3 = update_config(config3, {"a": "new_value"})
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [
@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
     assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)


+def test_update_config(model_runner):
+    # Simple update
+    model_runner.update_config({"load_config": {"load_format": "dummy"}})
+    assert model_runner.load_config.load_format == "dummy"
+    # Raise error on non-existing config
+    with pytest.raises(AssertionError):
+        model_runner.update_config({"do_not_exist_config": "dummy"})
+
+
 def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     # In this test, model_runner loads model + weights in one go, while
     # model_runner_2 loads dummy weights first then load real weights inplace
     model_runner.load_model()
     original_load_format = model_runner_2.load_config.load_format
-    model_runner_2.load_config.load_format = "dummy"
+    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
     model_runner_2.load_model()  # Initial model loading with dummy weights
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
-    model_runner_2.load_config.load_format = original_load_format
+    model_runner_2.update_config(
+        {"load_config": {
+            "load_format": original_load_format
+        }})
     model_runner_2.load_model()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())
@@ -71,6 +71,7 @@ if TYPE_CHECKING:
     ConfigType = type[DataclassInstance]
     HfOverrides = Union[dict, Callable[[type], type]]
 else:
+    DataclassInstance = Any
     PlacementGroup = Any
     PretrainedConfig = Any
     ExecutorBase = Any
@@ -87,7 +88,7 @@ else:
                            "vllm.model_executor.models")

 logger = init_logger(__name__)

+DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
 ConfigT = TypeVar("ConfigT", bound=ConfigType)

 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
     @property
     def allow_audio_chunking(self) -> bool:
         return self.min_energy_split_window_size is not None
+
+
+def update_config(config: DataclassInstanceT,
+                  overrides: dict[str, Any]) -> DataclassInstanceT:
+    processed_overrides = {}
+    for field_name, value in overrides.items():
+        assert hasattr(
+            config, field_name), f"{type(config)} has no field `{field_name}`"
+        current_value = getattr(config, field_name)
+        if is_dataclass(current_value) and not is_dataclass(value):
+            assert isinstance(value, dict), (
+                f"Overrides to {type(config)}.{field_name} must be a dict"
+                f" or {type(current_value)}, but got {type(value)}")
+            value = update_config(
+                current_value,  # type: ignore[type-var]
+                value)
+        processed_overrides[field_name] = value
+    return replace(config, **processed_overrides)
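For reference, a minimal standalone sketch of the helper's semantics (the Inner/Outer dataclasses are invented for illustration; the behavior mirrors the tests added above):

    from dataclasses import dataclass, field

    from vllm.config import update_config


    @dataclass
    class Inner:
        a: int = 0
        c: str = "old"


    @dataclass
    class Outer:
        inner: Inner = field(default_factory=Inner)


    outer = Outer()
    # A dataclass value replaces the field wholesale.
    replaced = update_config(outer, {"inner": Inner(a=1, c="new")})
    assert replaced.inner.a == 1
    # A dict value is applied recursively, so untouched fields keep their values.
    patched = update_config(outer, {"inner": {"c": "new"}})
    assert patched.inner.a == 0 and patched.inner.c == "new"
    # The original object is never mutated; unknown field names fail an assert.
    assert outer.inner.c == "old"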
@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config)
+                         get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             draft_token_ids.append(drafter_output.tolist())
         return draft_token_ids

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
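Note that the runner-level method whitelists the top-level keys: only load_config and model_config can be overridden, and any other key trips the assertion, which is exactly what the new model-runner test above checks. The update is also applied immutably: the runner's existing config object is left untouched and is swapped out for the new instance returned by update_config.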
@@ -4,7 +4,7 @@
 import copy
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 import torch.distributed
@@ -193,6 +193,9 @@ class Worker(WorkerBase):
         with context:
             self.model_runner.load_model()

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much
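The worker-level update_config is what the commit title calls the RPC method: it is a plain worker method, so it should be reachable through the same by-name dispatch used for other worker calls. A hypothetical end-to-end sketch (the collective_rpc entry point and the model/load-format values are assumptions for illustration, not part of this diff):

    from vllm import LLM

    # Start with dummy (randomly initialized) weights to skip the download.
    llm = LLM(model="facebook/opt-125m", load_format="dummy")

    # Patch load_config on every worker, then reload real weights in place,
    # mirroring what test_load_model_weights_inplace does directly on the runner.
    llm.collective_rpc("update_config",
                       args=({"load_config": {"load_format": "auto"}}, ))
    llm.collective_rpc("load_model")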
@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
 from unittest.mock import patch

 import numpy as np
@@ -18,7 +18,8 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (ParallelConfig, VllmConfig,
+                         get_layers_from_vllm_config, update_config)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):

         return model_runner_output

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        # TODO: TPU config may need extra validation
+        # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
+        allowed_config_names = {"load_config", "model_config"}
+        for config_name, config_overrides in overrides.items():
+            assert config_name in allowed_config_names, \
+                f"Config `{config_name}` not supported. " \
+                f"Allowed configs: {allowed_config_names}"
+            config = getattr(self, config_name)
+            new_config = update_config(config, config_overrides)
+            setattr(self, config_name, new_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
 import os
-from typing import Optional
+from typing import Any, Optional

 import torch
 import torch.distributed
@@ -260,6 +260,9 @@ class TPUWorker:
     def load_model(self) -> None:
         self.model_runner.load_model()

+    def update_config(self, overrides: dict[str, Any]) -> None:
+        self.model_runner.update_config(overrides)
+
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()