Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-10 06:35:00 +08:00
[Bugfix][Core] Prevent token lengths exceeding max_model_len in V0 (#19348)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent 5cf2daea9a
commit c1c7dbbeeb
@@ -25,6 +25,12 @@ TOKEN_IDS = [
 ]
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    """We can run both engines for this test."""
+    pass
+
+
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to
@@ -104,3 +110,19 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
+
+
+def test_max_model_len():
+    max_model_len = 20
+    llm = LLM(
+        model=MODEL_NAME,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=0.10,
+        enforce_eager=True,  # reduce test time
+    )
+    sampling_params = SamplingParams(max_tokens=max_model_len + 10)
+    outputs = llm.generate(PROMPTS, sampling_params)
+    for output in outputs:
+        num_total_tokens = len(output.prompt_token_ids) + len(
+            output.outputs[0].token_ids)
+        assert num_total_tokens == max_model_len
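For readers skimming the diff, the new test exercises the user-visible contract of this fix: even when SamplingParams.max_tokens asks for more tokens than the model window allows, the number of prompt tokens plus generated tokens never exceeds max_model_len (the test asserts equality because max_tokens deliberately overshoots the remaining budget, so generation always runs up to the cap). Below is a minimal standalone sketch of that invariant, not part of the commit; the helper name assert_length_capped is made up for illustration.

# Minimal sketch (not part of this commit) of the invariant test_max_model_len checks.
def assert_length_capped(prompt_token_ids, completion_token_ids, max_model_len):
    """Hypothetical helper: the length cap counts prompt and completion together."""
    total = len(prompt_token_ids) + len(completion_token_ids)
    assert total <= max_model_len, (
        f"sequence length {total} exceeds max_model_len={max_model_len}")

# A 5-token prompt with max_model_len=20 leaves room for at most 15 completion
# tokens, even if SamplingParams(max_tokens=30) requests more.
assert_length_capped(list(range(5)), list(range(15)), max_model_len=20)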
@@ -82,7 +82,7 @@ class StopChecker:
             return
 
         # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self._get_max_model_len(lora_req):
+        if seq.get_len() >= self._get_max_model_len(lora_req):
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
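The one-character change above is the actual bugfix: the V0 stop checker runs after each newly appended token, and with the old strict comparison a sequence whose length had just reached max_model_len was not marked FINISHED_LENGTH_CAPPED, so the engine could schedule one more decode step and end up one token over the limit. The sketch below is a standalone illustration of that off-by-one, using made-up names rather than vLLM's Sequence/StopChecker classes.

# Standalone sketch of the off-by-one fixed in this hunk (names are illustrative,
# not vLLM's real API). The check must cap a sequence as soon as it *reaches*
# max_model_len, not only once it has already exceeded it.
def should_cap(seq_len: int, max_model_len: int, fixed: bool = True) -> bool:
    return seq_len >= max_model_len if fixed else seq_len > max_model_len

MAX_MODEL_LEN = 20
# Old behavior: a 20-token sequence was not capped, so one more decode step
# could run and leave it at 21 tokens (> max_model_len).
assert should_cap(20, MAX_MODEL_LEN, fixed=False) is False
# New behavior: the sequence is finished (length-capped) at exactly 20 tokens.
assert should_cap(20, MAX_MODEL_LEN, fixed=True) is True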