[test][RL] Add sleep level 2 test and fix reload with sleep mode (#23521)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 0ff902f3b4
commit 2a167b2eeb
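The new test below exercises the sleep-level-2 path end to end: deep sleep, partial wake-up, weight reload, then KV-cache wake-up. As orientation, the RL weight-sync pattern it targets looks roughly like this (a minimal sketch against vLLM's public LLM API; the elided training step and updated checkpoint are assumptions, not part of this commit):

    from vllm import LLM, SamplingParams

    llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True)

    # Level 2 discards weight memory outright (level 1 offloads it to
    # CPU RAM), freeing the GPU for a training step.
    llm.sleep(level=2)

    # ... trainer updates the model weights on disk (elided) ...

    # Wake only the weight buffers, then refill them: after level-2
    # sleep their contents are gone, so a reload is required before
    # generating again.
    llm.wake_up(tags=["weights"])
    llm.collective_rpc("reload_weights")

    # Re-allocate the KV cache before serving requests again.
    llm.wake_up(tags=["kv_cache"])
    print(llm.generate("How are you?", SamplingParams(temperature=0, max_tokens=10)))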
@@ -177,3 +177,34 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
 
     # cmp output
     assert output[0].outputs[0].text == output3[0].outputs[0].text
+
+
+@create_new_process_for_each_test()
+def test_deep_sleep():
+    model = "Qwen/Qwen3-0.6B"
+    free, total = torch.cuda.mem_get_info()
+    used_bytes_baseline = total - free  # in case other process is running
+    llm = LLM(model, enable_sleep_mode=True)
+    prompt = "How are you?"
+    sampling_params = SamplingParams(temperature=0, max_tokens=10)
+    output = llm.generate(prompt, sampling_params)
+
+    # Put the engine to deep sleep
+    llm.sleep(level=2)
+
+    free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
+    used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
+    assert used_bytes < 3 * GiB_bytes
+
+    llm.wake_up(tags=["weights"])
+    llm.collective_rpc("reload_weights")
+    free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
+    used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
+    assert used_bytes < 4 * GiB_bytes
+
+    # now allocate kv cache and cuda graph memory
+    llm.wake_up(tags=["kv_cache"])
+    output2 = llm.generate(prompt, sampling_params)
+
+    # cmp output
+    assert output[0].outputs[0].text == output2[0].outputs[0].text
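Note how each assertion measures a delta of torch.cuda.mem_get_info() against a baseline captured before the engine starts, so memory held by unrelated processes on the same GPU cannot trip the thresholds. Factored out, the accounting looks like this (GiB_bytes mirrors the 1 GiB constant the test imports from vllm.utils; the helper name is ours):

    import torch

    GiB_bytes = 1 << 30  # same constant the test imports from vllm.utils

    def used_gpu_bytes(baseline: int = 0) -> int:
        # mem_get_info() reports device-wide (free, total), so subtracting
        # a pre-engine baseline isolates what this process allocated.
        free, total = torch.cuda.mem_get_info()
        return total - free - baseline

    baseline = used_gpu_bytes()  # before the engine starts
    # ... start engine, llm.sleep(level=2) ...
    assert used_gpu_bytes(baseline) < 3 * GiB_bytes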
@@ -216,8 +216,7 @@ class Worker(WorkerBase):
         self.model_runner.update_config(overrides)
 
     def reload_weights(self) -> None:
-        with self._maybe_get_memory_pool_context(tag="weights"):
-            self.model_runner.reload_weights()
+        self.model_runner.reload_weights()
 
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
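For context, _maybe_get_memory_pool_context routes allocations through vLLM's CuMemAllocator only when sleep mode is enabled, roughly like this (a simplified sketch, not the exact source):

    from contextlib import nullcontext

    def _maybe_get_memory_pool_context(self, tag: str):
        if self.vllm_config.model_config.enable_sleep_mode:
            from vllm.device_allocator.cumem import CuMemAllocator
            # Allocations made inside this context are tagged so that
            # sleep()/wake_up() can release and restore them by tag.
            return CuMemAllocator.get_instance().use_memory_pool(tag=tag)
        return nullcontext()

Since wake_up(tags=["weights"]) already restores the tagged buffers and model_runner.reload_weights() copies new values into them in place, re-entering the pool context here was redundant; dropping it appears to be the reload fix the commit title refers to.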