[Bugfix][V0 Deprecation][CI] use async mock and await for async method (#25325)

Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Yang Liu 2025-09-21 16:06:16 -07:00 committed by yewentao256
parent a815d820ee
commit e348e1027c

View File

@ -5,7 +5,7 @@ from contextlib import suppress
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import Optional
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@ -83,20 +83,31 @@ def register_mock_resolver():
def mock_serving_setup(): def mock_serving_setup():
"""Provides a mocked engine and serving completion instance.""" """Provides a mocked engine and serving completion instance."""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
def mock_add_lora_side_effect(lora_request: LoRARequest): tokenizer = get_tokenizer(MODEL_NAME)
mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)
async def mock_add_lora_side_effect(lora_request: LoRARequest):
"""Simulate engine behavior when adding LoRAs.""" """Simulate engine behavior when adding LoRAs."""
if lora_request.lora_name == "test-lora": if lora_request.lora_name == "test-lora":
# Simulate successful addition # Simulate successful addition
return return True
elif lora_request.lora_name == "invalid-lora": if lora_request.lora_name == "invalid-lora":
# Simulate failure during addition (e.g. invalid format) # Simulate failure during addition (e.g. invalid format)
raise ValueError(f"Simulated failure adding LoRA: " raise ValueError(f"Simulated failure adding LoRA: "
f"{lora_request.lora_name}") f"{lora_request.lora_name}")
return True
mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect)
async def mock_generate(*args, **kwargs):
for _ in []:
yield _
mock_engine.generate = MagicMock(spec=AsyncLLM.generate,
side_effect=mock_generate)
mock_engine.add_lora.side_effect = mock_add_lora_side_effect
mock_engine.generate.reset_mock() mock_engine.generate.reset_mock()
mock_engine.add_lora.reset_mock() mock_engine.add_lora.reset_mock()
@ -131,7 +142,7 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup,
with suppress(Exception): with suppress(Exception):
await serving_completion.create_completion(req_found) await serving_completion.create_completion(req_found)
mock_engine.add_lora.assert_called_once() mock_engine.add_lora.assert_awaited_once()
called_lora_request = mock_engine.add_lora.call_args[0][0] called_lora_request = mock_engine.add_lora.call_args[0][0]
assert isinstance(called_lora_request, LoRARequest) assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == lora_model_name assert called_lora_request.lora_name == lora_model_name
@ -157,7 +168,7 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
response = await serving_completion.create_completion(req) response = await serving_completion.create_completion(req)
mock_engine.add_lora.assert_not_called() mock_engine.add_lora.assert_not_awaited()
mock_engine.generate.assert_not_called() mock_engine.generate.assert_not_called()
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
@ -181,7 +192,7 @@ async def test_serving_completion_resolver_add_lora_fails(
response = await serving_completion.create_completion(req) response = await serving_completion.create_completion(req)
# Assert add_lora was called before the failure # Assert add_lora was called before the failure
mock_engine.add_lora.assert_called_once() mock_engine.add_lora.assert_awaited_once()
called_lora_request = mock_engine.add_lora.call_args[0][0] called_lora_request = mock_engine.add_lora.call_args[0][0]
assert isinstance(called_lora_request, LoRARequest) assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == invalid_model assert called_lora_request.lora_name == invalid_model