mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 22:25:01 +08:00
[Core] Add reload_weights RPC method (#20096)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
14bf19e39f
commit
5c9b807b34
@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
|
|||||||
{"load_config": {
|
{"load_config": {
|
||||||
"load_format": original_load_format
|
"load_format": original_load_format
|
||||||
}})
|
}})
|
||||||
model_runner_2.load_model() # Load real weights inplace
|
model_runner_2.reload_weights() # Load real weights inplace
|
||||||
assert str(model_runner.get_model().state_dict()) == str(
|
assert str(model_runner.get_model().state_dict()) == str(
|
||||||
model_runner_2.get_model().state_dict())
|
model_runner_2.get_model().state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
def test_reload_weights_before_load_model(model_runner):
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
model_runner.reload_weights()
|
||||||
|
|
||||||
|
|
||||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
layer_0 = "model.layers.0.self_attn.attn"
|
layer_0 = "model.layers.0.self_attn.attn"
|
||||||
|
|||||||
@ -1873,17 +1873,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with DeviceMemoryProfiler() as m:
|
with DeviceMemoryProfiler() as m:
|
||||||
time_before_load = time.perf_counter()
|
time_before_load = time.perf_counter()
|
||||||
model_loader = get_model_loader(self.load_config)
|
model_loader = get_model_loader(self.load_config)
|
||||||
if not hasattr(self, "model"):
|
|
||||||
logger.info("Loading model from scratch...")
|
logger.info("Loading model from scratch...")
|
||||||
self.model = model_loader.load_model(
|
self.model = model_loader.load_model(
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config, model_config=self.model_config)
|
||||||
model_config=self.model_config)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Model was already initialized. Loading weights inplace..."
|
|
||||||
)
|
|
||||||
model_loader.load_weights(self.model,
|
|
||||||
model_config=self.model_config)
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.model = self.load_lora_model(self.model,
|
self.model = self.load_lora_model(self.model,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
@ -1916,6 +1908,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
rank_mapping,
|
rank_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def reload_weights(self) -> None:
|
||||||
|
assert getattr(self, "model", None) is not None, \
|
||||||
|
"Cannot reload weights before model is loaded."
|
||||||
|
model_loader = get_model_loader(self.load_config)
|
||||||
|
logger.info("Reloading weights inplace...")
|
||||||
|
model_loader.load_weights(self.model, model_config=self.model_config)
|
||||||
|
|
||||||
def save_tensorized_model(
|
def save_tensorized_model(
|
||||||
self,
|
self,
|
||||||
tensorizer_config: "TensorizerConfig",
|
tensorizer_config: "TensorizerConfig",
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
from contextlib import AbstractContextManager, nullcontext
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -118,6 +119,21 @@ class Worker(WorkerBase):
|
|||||||
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
buffer.data.copy_(self._sleep_saved_buffers[name].data)
|
||||||
self._sleep_saved_buffers = {}
|
self._sleep_saved_buffers = {}
|
||||||
|
|
||||||
|
def _maybe_get_memory_pool_context(self,
|
||||||
|
tag: str) -> AbstractContextManager:
|
||||||
|
if self.vllm_config.model_config.enable_sleep_mode:
|
||||||
|
from vllm.device_allocator.cumem import CuMemAllocator
|
||||||
|
|
||||||
|
allocator = CuMemAllocator.get_instance()
|
||||||
|
if tag == "weights":
|
||||||
|
assert allocator.get_current_usage() == 0, (
|
||||||
|
"Sleep mode can only be "
|
||||||
|
"used for one instance per process.")
|
||||||
|
context = allocator.use_memory_pool(tag=tag)
|
||||||
|
else:
|
||||||
|
context = nullcontext()
|
||||||
|
return context
|
||||||
|
|
||||||
def initialize_cache(self, num_gpu_blocks: int,
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
num_cpu_blocks: int) -> None:
|
num_cpu_blocks: int) -> None:
|
||||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
@ -179,24 +195,17 @@ class Worker(WorkerBase):
|
|||||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||||
# to hijack tensor allocation.
|
# to hijack tensor allocation.
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
if self.vllm_config.model_config.enable_sleep_mode:
|
|
||||||
from vllm.device_allocator.cumem import CuMemAllocator
|
|
||||||
|
|
||||||
allocator = CuMemAllocator.get_instance()
|
|
||||||
assert allocator.get_current_usage() == 0, (
|
|
||||||
"Sleep mode can only be "
|
|
||||||
"used for one instance per process.")
|
|
||||||
context = allocator.use_memory_pool(tag="weights")
|
|
||||||
else:
|
|
||||||
from contextlib import nullcontext
|
|
||||||
context = nullcontext()
|
|
||||||
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||||
with context:
|
with self._maybe_get_memory_pool_context(tag="weights"):
|
||||||
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
||||||
|
|
||||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||||
self.model_runner.update_config(overrides)
|
self.model_runner.update_config(overrides)
|
||||||
|
|
||||||
|
def reload_weights(self) -> None:
|
||||||
|
with self._maybe_get_memory_pool_context(tag="weights"):
|
||||||
|
self.model_runner.reload_weights()
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def determine_available_memory(self) -> int:
|
def determine_available_memory(self) -> int:
|
||||||
"""Profiles the peak memory usage of the model to determine how much
|
"""Profiles the peak memory usage of the model to determine how much
|
||||||
|
|||||||
@ -1174,16 +1174,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
mesh=self.mesh)
|
mesh=self.mesh)
|
||||||
else:
|
else:
|
||||||
model_loader = get_model_loader(self.load_config)
|
model_loader = get_model_loader(self.load_config)
|
||||||
if not hasattr(self, "model"):
|
|
||||||
logger.info("Loading model from scratch...")
|
logger.info("Loading model from scratch...")
|
||||||
model = model_loader.load_model(
|
model = model_loader.load_model(
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
model_config=self.model_config)
|
model_config=self.model_config)
|
||||||
else:
|
|
||||||
logger.info("Model was already initialized. \
|
|
||||||
Loading weights inplace...")
|
|
||||||
model_loader.load_weights(
|
|
||||||
self.model, model_config=self.model_config)
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Unable to load model, a likely reason is the model is "
|
f"Unable to load model, a likely reason is the model is "
|
||||||
@ -1205,6 +1199,13 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.sampler = TPUSampler()
|
self.sampler = TPUSampler()
|
||||||
|
|
||||||
|
def reload_weights(self) -> None:
|
||||||
|
assert getattr(self, "model", None) is not None, \
|
||||||
|
"Cannot reload weights before model is loaded."
|
||||||
|
model_loader = get_model_loader(self.load_config)
|
||||||
|
logger.info("Reloading weights inplace...")
|
||||||
|
model_loader.load_weights(self.model, model_config=self.model_config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _dummy_run(self, num_tokens: int, num_reqs: int,
|
def _dummy_run(self, num_tokens: int, num_reqs: int,
|
||||||
num_blocks: int) -> None:
|
num_blocks: int) -> None:
|
||||||
|
|||||||
@ -265,6 +265,9 @@ class TPUWorker:
|
|||||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||||
self.model_runner.update_config(overrides)
|
self.model_runner.update_config(overrides)
|
||||||
|
|
||||||
|
def reload_weights(self) -> None:
|
||||||
|
self.model_runner.reload_weights()
|
||||||
|
|
||||||
def compile_or_warm_up_model(self) -> None:
|
def compile_or_warm_up_model(self) -> None:
|
||||||
if not self.model_config.enforce_eager:
|
if not self.model_config.enforce_eager:
|
||||||
self.model_runner.capture_model()
|
self.model_runner.capture_model()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user