diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 6ddcbfea24ad3..7fec4782517cf 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): {"load_config": { "load_format": original_load_format }}) - model_runner_2.load_model() # Load real weights inplace + model_runner_2.reload_weights() # Load real weights inplace assert str(model_runner.get_model().state_dict()) == str( model_runner_2.get_model().state_dict()) +def test_reload_weights_before_load_model(model_runner): + with pytest.raises(AssertionError): + model_runner.reload_weights() + + def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): torch.set_default_dtype(torch.float16) layer_0 = "model.layers.0.self_attn.attn" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 864cf91e78508..1ee379d34277d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1873,17 +1873,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() model_loader = get_model_loader(self.load_config) - if not hasattr(self, "model"): - logger.info("Loading model from scratch...") - self.model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.model_config) - else: - logger.info( - "Model was already initialized. Loading weights inplace..." - ) - model_loader.load_weights(self.model, - model_config=self.model_config) + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config) if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -1916,6 +1908,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): rank_mapping, ) + def reload_weights(self) -> None: + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model_loader.load_weights(self.model, model_config=self.model_config) + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6411874883ef2..1c180322e12d1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -4,6 +4,7 @@ import copy import gc import os +from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Optional import torch @@ -118,6 +119,21 @@ class Worker(WorkerBase): buffer.data.copy_(self._sleep_saved_buffers[name].data) self._sleep_saved_buffers = {} + def _maybe_get_memory_pool_context(self, + tag: str) -> AbstractContextManager: + if self.vllm_config.model_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag=tag) + else: + context = nullcontext() + return context + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -179,24 +195,17 @@ class Worker(WorkerBase): # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. def load_model(self) -> None: - if self.vllm_config.model_config.enable_sleep_mode: - from vllm.device_allocator.cumem import CuMemAllocator - - allocator = CuMemAllocator.get_instance() - assert allocator.get_current_usage() == 0, ( - "Sleep mode can only be " - "used for one instance per process.") - context = allocator.use_memory_pool(tag="weights") - else: - from contextlib import nullcontext - context = nullcontext() eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" - with context: + with self._maybe_get_memory_pool_context(tag="weights"): self.model_runner.load_model(eep_scale_up=eep_scale_up) def update_config(self, overrides: dict[str, Any]) -> None: self.model_runner.update_config(overrides) + def reload_weights(self) -> None: + with self._maybe_get_memory_pool_context(tag="weights"): + self.model_runner.reload_weights() + @torch.inference_mode() def determine_available_memory(self) -> int: """Profiles the peak memory usage of the model to determine how much diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 31e9cff91247c..f160384f8f67a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1174,16 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin): mesh=self.mesh) else: model_loader = get_model_loader(self.load_config) - if not hasattr(self, "model"): - logger.info("Loading model from scratch...") - model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.model_config) - else: - logger.info("Model was already initialized. \ - Loading weights inplace...") - model_loader.load_weights( - self.model, model_config=self.model_config) + logger.info("Loading model from scratch...") + model = model_loader.load_model( + vllm_config=self.vllm_config, + model_config=self.model_config) except RuntimeError as e: raise RuntimeError( f"Unable to load model, a likely reason is the model is " @@ -1205,6 +1199,13 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.model = model self.sampler = TPUSampler() + def reload_weights(self) -> None: + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model_loader.load_weights(self.model, model_config=self.model_config) + @torch.no_grad() def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 592d9fc17c9e4..1d61878ca0803 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -265,6 +265,9 @@ class TPUWorker: def update_config(self, overrides: dict[str, Any]) -> None: self.model_runner.update_config(overrides) + def reload_weights(self) -> None: + self.model_runner.reload_weights() + def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model()