[BugFix] Fix Memory Leak (#17567)

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>

commit c777df79f7
parent cc2a77d7f1
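For context before the diff: the leak comes from per-request scratch data being recycled into a dict keyed by request ID even after the request has finished, so those entries are never popped again and accumulate forever. A minimal, self-contained sketch of that pattern (the names CachedReqData, recycle, and simulate are hypothetical, not vLLM code):

from collections import defaultdict, deque
from dataclasses import dataclass


@dataclass
class CachedReqData:
    # Hypothetical stand-in for the per-request data the scheduler recycles.
    req_id: str


def recycle(cached, scheduled, finished_ids, fixed):
    # Re-queue per-request scratch data so the next step can reuse it.
    for req_data in scheduled:
        # The fix in the diff below: data for an already-finished request
        # will never be popped again, so re-appending it leaks.
        if fixed and req_data.req_id in finished_ids:
            continue
        cached[req_data.req_id].append(req_data)


def simulate(fixed, steps=1000):
    cached = defaultdict(deque)
    for step in range(steps):
        req = CachedReqData(req_id=f"req-{step}")
        # Each request finishes in the same step it was scheduled.
        recycle(cached, [req], finished_ids={req.req_id}, fixed=fixed)
    return sum(len(q) for q in cached.values())


assert simulate(fixed=True) == 0      # fixed: nothing retained
assert simulate(fixed=False) == 1000  # buggy: one leaked entry per request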
@@ -1165,3 +1165,80 @@ def test_kv_connector_handles_preemption():
     # All memory should be freed since nothing is running.
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
         == NUM_BLOCKS - 1
+
+
+def make_output(scheduler: Scheduler):
+    return ModelRunnerOutput(
+        req_ids=[req.request_id for req in scheduler.running],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(scheduler.running)
+        },
+        sampled_token_ids=[[1000]] * len(scheduler.running),
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+
+
+def assert_scheduler_empty(scheduler: Scheduler):
+    """Confirm the scheduler is "empty" - i.e. no leaks."""
+    # Scheduler Metadata.
+    assert len(scheduler.requests) == 0
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 0
+    assert len(scheduler.finished_req_ids) == 0
+    assert len(scheduler._cached_reqs_data) == 0
+
+    # EncoderCacheManager.
+    assert len(scheduler.encoder_cache_manager.freed) == 0
+    assert len(scheduler.encoder_cache_manager.cached) == 0
+
+    # KVCache Manager.
+    assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
+    assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
+    assert len(scheduler.kv_cache_manager.num_cached_block) == 0
+    num_free_blocks = (
+        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
+    assert num_free_blocks == (
+        scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
+
+    # NOTE(rob): just the ref count on blocks will be 0. The hash
+    # value, etc will remain since we lazily evict for prefix cache.
+    for block in scheduler.kv_cache_manager.block_pool.blocks:
+        assert block.ref_cnt == 0
+        # assert block._block_hash is None
+    # assert (
+    #     len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
+    #     ) == 0)
+
+
+def test_memory_leak():
+    """Test that we do not have a memory leak."""
+
+    scheduler = create_scheduler(enable_prefix_caching=True)
+
+    NUM_REQUESTS = 5
+    NUM_TOKENS = 10
+    MAX_TOKENS = 10
+    requests = create_requests(num_requests=NUM_REQUESTS,
+                               num_tokens=NUM_TOKENS,
+                               max_tokens=MAX_TOKENS)
+
+    # Add each request.
+    for request in requests:
+        scheduler.add_request(request)
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # Iterate until done.
+    while True:
+        scheduler_output = scheduler.schedule()
+        if len(scheduler.running) == 0:
+            break
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # Confirm no memory leak.
+    assert_scheduler_empty(scheduler)
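The NOTE(rob) in assert_scheduler_empty above is worth unpacking: after all requests finish, only the blocks' reference counts return to zero; their hashes stay registered so a later request with the same prefix can still hit the cache, and a hash is evicted only when its block is reallocated for different content. A simplified sketch of that lazy-eviction behavior (Block, BlockPool, and all fields here are illustrative stand-ins, not vLLM's actual classes):

from collections import deque
from dataclasses import dataclass
from typing import Optional


@dataclass
class Block:
    block_id: int
    ref_cnt: int = 0
    block_hash: Optional[int] = None  # kept after free for prefix reuse


class BlockPool:
    def __init__(self, num_blocks):
        self.blocks = [Block(i) for i in range(num_blocks)]
        self.free_queue = deque(self.blocks)
        self.hash_to_block = {}

    def free(self, block):
        block.ref_cnt -= 1
        if block.ref_cnt == 0:
            # Lazy: keep block_hash and the hash_to_block entry so the
            # freed block can still serve prefix-cache hits.
            self.free_queue.append(block)

    def allocate(self, block_hash):
        cached = self.hash_to_block.get(block_hash)
        if cached is not None:
            if cached.ref_cnt == 0:
                self.free_queue.remove(cached)  # hit on a freed block
            cached.ref_cnt += 1
            return cached
        block = self.free_queue.popleft()  # assumes a free block exists
        if block.block_hash is not None:
            del self.hash_to_block[block.block_hash]  # evict only now
        block.block_hash = block_hash
        block.ref_cnt = 1
        self.hash_to_block[block_hash] = block
        return block


pool = BlockPool(num_blocks=4)
b = pool.allocate(block_hash=42)
pool.free(b)
assert b.ref_cnt == 0 and b.block_hash == 42  # hash survives the free
assert pool.allocate(block_hash=42) is b      # later prefix hit reuses it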
@@ -761,7 +761,10 @@ class Scheduler(SchedulerInterface):
 
         # Return the cached request data to the queue so they can be reused.
         for req_data in scheduler_output.scheduled_cached_reqs:
-            self._cached_reqs_data[req_data.req_id].append(req_data)
+            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
+            # to _cached_reqs_data will cause a memory leak.
+            if req_data.req_id not in self.finished_req_ids:
+                self._cached_reqs_data[req_data.req_id].append(req_data)
 
         self.running = new_running
         engine_core_outputs = EngineCoreOutputs(
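The regression test above detects the leak by asserting that every scheduler-side container is empty once all requests finish. A complementary, more general probe (not part of this commit) is to hold a weakref to a per-request object and confirm it gets collected when the request completes; RequestState here is a hypothetical stand-in for per-request scheduler state:

import gc
import weakref


class RequestState:
    # Hypothetical stand-in for per-request scheduler state.
    def __init__(self, req_id):
        self.req_id = req_id


requests = {"req-0": RequestState("req-0")}
probe = weakref.ref(requests["req-0"])

# Simulate the scheduler finishing and freeing the request.
del requests["req-0"]
gc.collect()

# If any stale structure (e.g. a cached-reqs dict) still referenced the
# state, the weakref would stay live and this assert would fail.
assert probe() is None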