[BugFix] fix speculative decoding memory leak when speculation is disabled (#15506)
Signed-off-by: Noah Yoshida <noahcy117@gmail.com>
parent 90d0a54c4d
commit 13cf6b6236
tests/spec_decode/test_memory_usage.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0

"""This docstring details important information on the testing methodology.

This test verifies that memory usage remains constant (or never grows) when
we enable / disable speculation via --speculative-disable-by-batch-size.

There are a lot of things we try to keep track of between batches of requests
and if certain tensors are not freed from memory, can result in CUDA ooms.

This is particularly relevant for production situations where speculation might
be enabled during off hours, but disabled once traffic peaks during the workday.
Since traffic will stay high for a long period of time, verifying we do not
increase our memory usage over time is essential to prevent possible CUDA ooms.
"""

import torch

import vllm
from tests.core.utils import create_dummy_prompt
from vllm.sequence import SequenceGroup

ITERATIONS = 100
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

BATCH_SIZE = 5
SPEC_DISABLE_BATCH_SIZE = 2


def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup):
    scheduler = engine.scheduler[0]
    scheduler.add_seq_group(seq_group)


"""
Since we are using a batch size greater than the disabled batch size,
we can ensure we go through the _no_spec codepath for most of our engine steps.
"""


def test_memory_usage_no_spec():
    previous_memory_allocated = None
    llm = vllm.LLM(
        model=MAIN_MODEL,
        speculative_model=SPEC_MODEL,
        num_speculative_tokens=3,
        speculative_disable_by_batch_size=SPEC_DISABLE_BATCH_SIZE,
    )

    batch_sequences = set()
    engine = llm.llm_engine

    for i in range(ITERATIONS):
        seq, seq_group = create_dummy_prompt(request_id=str(i),
                                             prompt_length=10,
                                             min_tokens=10,
                                             max_tokens=10)

        add_seq_group_to_engine(engine, seq_group)

        batch_sequences.add(seq)
        engine.step()
        for seq in list(batch_sequences):
            if seq.is_finished():
                batch_sequences.remove(seq)

        # If we aren't at our batch size yet, continue
        if len(batch_sequences) <= BATCH_SIZE:
            continue

        # Otherwise, loop until at least one request is done
        while not any(seq.is_finished() for seq in batch_sequences):
            engine.step()

        # Remove it from the set
        for seq in list(batch_sequences):
            if seq.is_finished():
                batch_sequences.remove(seq)

        # At this point, we are always at the case where we have finished
        # processing some number of requests from the batch after running
        # several _no_spec executions. The memory should not have
        # increased between the previous time this was recorded and the
        # current time.
        if previous_memory_allocated is None:
            previous_memory_allocated = torch.cuda.memory_allocated()
        else:
            assert previous_memory_allocated == torch.cuda.memory_allocated()
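For reference, the check the test relies on can be distilled into a small, vLLM-independent sketch: run one batch of work, record torch.cuda.memory_allocated(), and compare it against the first steady-state measurement. The run_one_batch callable below is a hypothetical stand-in for the engine-stepping loop above, and the comparison is relaxed to <= whereas the test itself asserts exact equality.

import torch


def assert_no_memory_growth(run_one_batch, iterations: int = 10) -> None:
    """Fail if allocated GPU memory grows across repeated batches."""
    baseline = None
    for _ in range(iterations):
        run_one_batch()
        allocated = torch.cuda.memory_allocated()
        if baseline is None:
            # The first measurement establishes the steady-state footprint.
            baseline = allocated
        else:
            # Growth after steady state means tensors are being retained
            # across batches (e.g. stale speculative-decoding state).
            assert allocated <= baseline, (
                f"GPU memory grew from {baseline} to {allocated} bytes")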
@@ -695,6 +695,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
                                        seq_group_meta_with_hidden):
                 self.previous_hidden_states.update(hidden_states,
                                                    seq_group_meta_with_hidden)
+                self.previous_hidden_states.prune(seq_group_meta_with_hidden)
 
         if not skip_proposer:
             # We prepare the prefill hidden states here so that there no
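The one-line change above is the fix: after updating the cached hidden states in the _no_spec path, they are also pruned against the sequence groups still being scheduled, so that state held for requests that have finished can be freed instead of accumulating while speculation is disabled. The sketch below is a simplified, hypothetical illustration of that pattern only; it is not vLLM's HiddenStates implementation, and the class and method names are invented for the example.

from typing import Dict, List

import torch


class PerRequestCache:
    """Toy per-request tensor cache that must be pruned to avoid leaking."""

    def __init__(self) -> None:
        self._states: Dict[str, torch.Tensor] = {}

    def update(self, request_ids: List[str], states: torch.Tensor) -> None:
        # Store the latest hidden-state row for every scheduled request.
        for i, request_id in enumerate(request_ids):
            self._states[request_id] = states[i]

    def prune(self, scheduled_request_ids: List[str]) -> None:
        # Drop entries for requests that are no longer scheduled so their
        # tensors can be freed; skipping this step is what leaks memory.
        keep = set(scheduled_request_ids)
        for request_id in list(self._states):
            if request_id not in keep:
                del self._states[request_id]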