[Bugfix][Core] Prevent token lengths exceeding max_model_len in V0 (#19348)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
22quinn 2025-06-09 08:01:29 -07:00 committed by GitHub
parent 5cf2daea9a
commit c1c7dbbeeb
2 changed files with 23 additions and 1 deletion


@@ -25,6 +25,12 @@ TOKEN_IDS = [
 ]
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    """We can run both engines for this test."""
+    pass
+
+
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to
@@ -104,3 +110,19 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
+
+
+def test_max_model_len():
+    max_model_len = 20
+    llm = LLM(
+        model=MODEL_NAME,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=0.10,
+        enforce_eager=True,  # reduce test time
+    )
+    sampling_params = SamplingParams(max_tokens=max_model_len + 10)
+    outputs = llm.generate(PROMPTS, sampling_params)
+    for output in outputs:
+        num_total_tokens = len(output.prompt_token_ids) + len(
+            output.outputs[0].token_ids)
+        assert num_total_tokens == max_model_len


@@ -82,7 +82,7 @@ class StopChecker:
             return
 
         # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self._get_max_model_len(lora_req):
+        if seq.get_len() >= self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return
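
The one-character change above is the actual fix: with the strict ">" comparison, the length check only fired after a sequence had already grown past max_model_len, so V0 could return more total tokens than the configured context length. With ">=", the sequence is marked FINISHED_LENGTH_CAPPED as soon as prompt plus generated tokens reach max_model_len, which is exactly what the new test asserts. A minimal standalone sketch of the boundary behavior (hypothetical helper for illustration only, not vLLM's actual StopChecker code):

# Hypothetical, simplified illustration of the comparison change; the real
# check lives in StopChecker and uses seq.get_len() / _get_max_model_len().
def should_cap(seq_len: int, max_model_len: int) -> bool:
    # New behavior: finish the sequence once it reaches the limit.
    return seq_len >= max_model_len


# With max_model_len = 20, a sequence holding exactly 20 tokens is now capped;
# the old ">" comparison would have let it grow to 21 before stopping.
assert should_cap(20, 20) is True
assert should_cap(19, 20) is False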