# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch.cuda

from vllm import LLM, SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore

MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"


def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
    """Test that preprocessing errors are handled gracefully."""

    assert not torch.cuda.is_initialized(), (
        "fork needs to be used for the engine "
        "core process and this isn't possible if cuda is already initialized"
    )
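    # (The engine core runs in a separate process; a CUDA context created in
    # the parent cannot be inherited across a fork, which is why CUDA must not
    # be initialized before the LLM is constructed.)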

    # Keep a reference to the original method so non-failing requests are
    # still preprocessed normally.
    original_preprocess = EngineCore.preprocess_add_request

    # Monkeypatch preprocess_add_request to raise an exception only for
    # requests whose first prompt token id is the sentinel value 333.
    def conditional_failing_preprocess(self, request: EngineCoreRequest):
        # Fail if the first token id is 333.
        if request.prompt_token_ids and request.prompt_token_ids[0] == 333:
            raise ValueError("Simulated preprocessing error!")
        return original_preprocess(self, request)

    monkeypatch.setattr(
        EngineCore, "preprocess_add_request", conditional_failing_preprocess
    )
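    # (pytest's monkeypatch fixture restores the original method automatically
    # at test teardown.)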

    llm = LLM(model=MODEL_NAME)

    # Craft the failing request directly from token ids, since LLM.generate
    # would otherwise tokenize a text prompt for us.
    from vllm.inputs import TokensPrompt

    # The patched preprocess raises internally; the failure should surface as
    # an errored request rather than propagating as an exception. Token id 333
    # is the special value that triggers it.
    failing_prompt = TokensPrompt(prompt_token_ids=[333])
    outputs = llm.generate(failing_prompt, SamplingParams(max_tokens=10))  # type: ignore
    assert len(outputs) == 1
    assert len(outputs[0].outputs[0].token_ids) == 0
    assert outputs[0].finished
    assert outputs[0].outputs[0].finish_reason == "error"

    # Verify the engine is still functional with a normal request.
    outputs = llm.generate("Hello, my name is", SamplingParams(max_tokens=10))
    assert len(outputs) == 1
    assert len(outputs[0].outputs[0].token_ids) > 0
    assert outputs[0].outputs[0].finish_reason in ("stop", "length")
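

# A minimal companion sketch, not part of the original test: assuming the
# monkeypatched failure path above is still active, a batch mixing a failing
# prompt with a valid one should error only the failing request while the
# other completes. The helper name and the mixed-batch behaviour asserted
# here are assumptions, not verified behaviour.
def _sketch_mixed_batch_error_isolation(llm: LLM):
    from vllm.inputs import TokensPrompt

    prompts = [
        # First token id 333 triggers the simulated preprocessing failure.
        TokensPrompt(prompt_token_ids=[333]),
        # A normal text prompt that should complete as usual.
        "Hello, my name is",
    ]
    outputs = llm.generate(prompts, SamplingParams(max_tokens=10))  # type: ignore
    assert len(outputs) == 2
    assert outputs[0].outputs[0].finish_reason == "error"
    assert outputs[1].outputs[0].finish_reason in ("stop", "length")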