[V0 Deprecation] Remove V0 Engine tests (#25114)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 5963b98b46
commit 2c3c1bd07a
@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
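The conftest above pins every test in its module to the V0 engine through an autouse fixture. A minimal sketch of that pytest pattern follows; the fixture and test names here are hypothetical and not taken from this commit.

```python
# Minimal sketch of the autouse-fixture pattern (hypothetical names):
# the fixture pins an environment variable for every test in the module,
# and pytest's monkeypatch restores the previous value after each test.
import os

import pytest


@pytest.fixture(autouse=True)
def force_v0(monkeypatch):
    monkeypatch.setenv("VLLM_USE_V1", "0")


def test_env_is_pinned():
    # Inside any test in this module the variable is guaranteed to be "0".
    assert os.environ["VLLM_USE_V1"] == "0"
```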
@@ -1,37 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
@pytest.mark.parametrize("block_size", [16])
def test_computed_prefix_blocks(model: str, block_size: int):
    # This test checks if we are able to run the engine to completion
    # without triggering asserts.
    # We are in a scenario where all blocks from the second request's prompt
    # are full and already computed when the second request arrives.
    prompt = (
        "You are a helpful assistant. How do I build a car from cardboard and "
        "paper clips? Is there an easy to follow video tutorial available "
        "online for free?")
    prompt2 = (
        " Please recommend to me some resources where I can learn not only to "
        "handle technical difficulties of building a car, but also "
        "decoration.")

    engine_args = EngineArgs(model=model,
                             block_size=block_size,
                             enable_prefix_caching=True)

    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams()

    engine.add_request("0", prompt + prompt2, sampling_params)
    engine.step()
    engine.add_request("1", prompt, sampling_params)
    engine.step()
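The comments in this test describe a prefix-caching setup in which every block of the second request's prompt is already full and computed by the time it arrives. A rough sketch of that block arithmetic follows; the token count is an assumption for illustration, not a value from the test.

```python
# Illustrative arithmetic only (the 48-token prefix length is an assumption):
# with block_size=16, every complete block of the shared prefix can be reused
# once the first request has computed it.
block_size = 16
prefix_len = 48                     # assumed token count of the shared prompt
num_full_blocks = prefix_len // block_size
print(num_full_blocks)              # -> 3 blocks, all full and already computed
```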
@@ -1,111 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from typing import Any, Callable, Optional, Union

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams


class Mock:
    ...


class CustomUniExecutor(UniProcExecutor):

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None) -> list[Any]:
        # Drop marker to show that this was run
        with open(".marker", "w"):
            ...
        return super().collective_rpc(method, timeout, args, kwargs)


CustomUniExecutorAsync = CustomUniExecutor


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_custom_executor_type_checking(model):
    with pytest.raises(ValueError):
        engine_args = EngineArgs(model=model,
                                 distributed_executor_backend=Mock)
        LLMEngine.from_engine_args(engine_args)
    with pytest.raises(ValueError):
        engine_args = AsyncEngineArgs(model=model,
                                      distributed_executor_backend=Mock)
        AsyncLLMEngine.from_engine_args(engine_args)


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_custom_executor(model, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
            model=model,
            distributed_executor_backend=CustomUniExecutor,
            enforce_eager=True,  # reduce test time
        )
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        engine.add_request("0", "foo", sampling_params)
        engine.step()

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_custom_executor_async(model, tmp_path):
    cwd = os.path.abspath(".")
    os.chdir(tmp_path)
    try:
        assert not os.path.exists(".marker")

        engine_args = AsyncEngineArgs(
            model=model,
            distributed_executor_backend=CustomUniExecutorAsync,
            enforce_eager=True,  # reduce test time
        )
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

        async def t():
            stream = await engine.add_request("0", "foo", sampling_params)
            async for x in stream:
                ...

        asyncio.run(t())

        assert os.path.exists(".marker")
    finally:
        os.chdir(cwd)


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_respect_ray(model):
    # even for TP=1 and PP=1,
    # if users specify ray, we should use ray.
    # users might do this if they want to manage the
    # resources using ray.
    engine_args = EngineArgs(
        model=model,
        distributed_executor_backend="ray",
        enforce_eager=True,  # reduce test time
    )
    engine = LLMEngine.from_engine_args(engine_args)
    assert engine.model_executor.uses_ray
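The removed file above wires a custom executor class in through EngineArgs. As a hedged sketch, not part of this commit and with support that may vary across vLLM versions, the same class could in principle be passed through the offline LLM entrypoint, which forwards engine keyword arguments to EngineArgs.

```python
# Hedged sketch: plugging the CustomUniExecutor class from the removed test
# into the offline entrypoint instead of constructing EngineArgs directly.
# Whether this path behaves identically may vary across vLLM versions.
from vllm import LLM, SamplingParams


def run_with_custom_executor():
    llm = LLM(
        model="distilbert/distilgpt2",
        distributed_executor_backend=CustomUniExecutor,  # class defined above
        enforce_eager=True,
    )
    # A single short generation is enough for the executor's collective_rpc
    # override (and its ".marker" side effect) to be exercised.
    return llm.generate(["foo"], SamplingParams(max_tokens=1))
```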
@@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import sleep
from typing import Any

import pytest

from vllm.config import VllmConfig
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                  ResultHandler, WorkerMonitor)
from vllm.worker.worker_base import WorkerWrapperBase


class DummyWorkerWrapper(WorkerWrapperBase):
    """Dummy version of vllm.worker.worker.Worker"""

    def worker_method(self, worker_input: Any) -> tuple[int, Any]:
        sleep(0.05)

        if isinstance(worker_input, Exception):
            # simulate error case
            raise worker_input

        return self.rpc_rank, worker_input


def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]:
    result_handler = ResultHandler()
    vllm_config = VllmConfig()
    workers = [
        ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
                             rank) for rank in range(8)
    ]

    worker_monitor = WorkerMonitor(workers, result_handler)
    assert not worker_monitor.is_alive()

    result_handler.start()
    worker_monitor.start()
    assert worker_monitor.is_alive()

    return workers, worker_monitor


def test_local_workers() -> None:
    """Test workers with sync task submission"""

    workers, worker_monitor = _start_workers()

    def execute_workers(worker_input: str) -> None:
        worker_outputs = [
            worker.execute_method("worker_method", worker_input)
            for worker in workers
        ]

        for rank, output in enumerate(worker_outputs):
            assert output.get() == (rank, worker_input)

    executor = ThreadPoolExecutor(max_workers=4)

    # Test concurrent submission from different threads
    futures = [
        executor.submit(partial(execute_workers, f"thread {thread_num}"))
        for thread_num in range(4)
    ]

    for future in futures:
        future.result()

    # Test error case
    exception = ValueError("fake error")
    result = workers[0].execute_method("worker_method", exception)
    try:
        result.get()
        pytest.fail("task should have failed")
    except Exception as e:
        assert isinstance(e, ValueError)
        assert str(e) == "fake error"

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(20)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    try:
        _result = workers[0].execute_method("worker_method", "test")
        pytest.fail("task should fail once workers have been shut down")
    except Exception as e:
        assert isinstance(e, ChildProcessError)


def test_local_workers_clean_shutdown() -> None:
    """Test clean shutdown"""

    workers, worker_monitor = _start_workers()

    assert worker_monitor.is_alive()
    assert all(worker.process.is_alive() for worker in workers)

    # Clean shutdown
    worker_monitor.close()

    worker_monitor.join(20)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    try:
        _result = workers[0].execute_method("worker_method", "test")
        pytest.fail("task should fail once workers have been shut down")
    except Exception as e:
        assert isinstance(e, ChildProcessError)


@pytest.mark.asyncio
async def test_local_workers_async() -> None:
    """Test local workers with async task submission"""

    workers, worker_monitor = _start_workers()

    async def execute_workers(worker_input: str) -> None:
        worker_coros = [
            worker.execute_method_async("worker_method", worker_input)
            for worker in workers
        ]

        results = await asyncio.gather(*worker_coros)
        for rank, result in enumerate(results):
            assert result == (rank, worker_input)

    tasks = [
        asyncio.create_task(execute_workers(f"task {task_num}"))
        for task_num in range(4)
    ]

    for task in tasks:
        await task

    # Test error case
    exception = ValueError("fake error")
    try:
        _result = await workers[0].execute_method_async(
            "worker_method", exception)
        pytest.fail("task should have failed")
    except Exception as e:
        assert isinstance(e, ValueError)
        assert str(e) == "fake error"

    # Test cleanup when a worker fails
    assert worker_monitor.is_alive()
    workers[3].process.kill()

    # Other workers should get shut down here
    worker_monitor.join(20)

    # Ensure everything is stopped
    assert not worker_monitor.is_alive()
    assert all(not worker.process.is_alive() for worker in workers)

    # Further attempts to submit tasks should fail
    try:
        _result = await workers[0].execute_method_async(
            "worker_method", "test")
        pytest.fail("task should fail once workers have been shut down")
    except Exception as e:
        assert isinstance(e, ChildProcessError)
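These tests rely on the WorkerMonitor tearing down every worker once any one of them dies. A generic sketch of that monitoring pattern follows; it is purely illustrative and not vLLM's actual WorkerMonitor implementation.

```python
# Generic sketch of the pattern the tests above depend on: a background
# thread polls the worker processes and triggers a full shutdown as soon
# as any single worker exits. `processes` and `shutdown` are assumptions.
import threading
import time


def monitor_workers(processes, shutdown, poll_interval: float = 0.1):
    def _loop():
        while True:
            if any(not p.is_alive() for p in processes):
                shutdown()  # terminate remaining workers, close queues
                return
            time.sleep(poll_interval)

    monitor = threading.Thread(target=_loop, daemon=True)
    monitor.start()
    return monitor
```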
@@ -1,58 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import nullcontext

import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_skip_tokenizer_initialization(model: str):
    # This test checks if the flag skip_tokenizer_init skips the initialization
    # of tokenizer and detokenizer. The generated output is expected to contain
    # token ids.
    llm = LLM(
        model=model,
        skip_tokenizer_init=True,
        enforce_eager=True,
    )
    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)

    with pytest.raises(ValueError, match="cannot pass text prompts when"):
        llm.generate("abc", sampling_params)

    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                           sampling_params=sampling_params)
    assert len(outputs) > 0
    completions = outputs[0].outputs
    assert len(completions) > 0
    assert completions[0].text == ""
    assert completions[0].token_ids


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_enable_prompt_embeds(hf_runner, model: str,
                              enable_prompt_embeds: bool):
    prompt = "abc"

    with hf_runner(model) as hf_model:
        token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
        token_ids = token_ids.to(hf_model.model.device)

        embed_layer = hf_model.model.get_input_embeddings()
        prompt_embeds = embed_layer(token_ids).squeeze(0)

    ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
        ValueError, match="set `--enable-prompt-embeds`"))

    llm = LLM(
        model=model,
        enable_prompt_embeds=enable_prompt_embeds,
        enforce_eager=True,
    )

    with ctx:
        llm.generate({"prompt_embeds": prompt_embeds})
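With skip_tokenizer_init=True the engine only returns token ids and leaves the text field empty, so any decoding has to happen outside vLLM. A hedged usage sketch follows; the helper name is hypothetical and not part of this commit.

```python
# Hedged sketch: decoding the token ids returned by an engine started with
# skip_tokenizer_init=True, using an external tokenizer instance.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")


def decode_outputs(outputs):
    # `outputs` is assumed to be the list returned by llm.generate(...)
    # in the removed test above; each completion carries token ids only.
    return [
        tokenizer.decode(completion.token_ids)
        for request_output in outputs
        for completion in request_output.outputs
    ]
```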
@@ -25,6 +25,7 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
            model,
            max_model_len=128,  # LLaVA has a feature size of 576
            enforce_eager=True,
            load_format="dummy",
        )

    with vllm_model:
@@ -1,225 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus

REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


class MockReasoningParser(ReasoningParser):
    """Mock reasoning parser for testing purposes."""

    def __init__(self,
                 tokenizer: AutoTokenizer,
                 reasoning_active: bool = False):
        super().__init__(tokenizer)
        self.reasoning_active = reasoning_active

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return not self.reasoning_active

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids


class MockSequence(Sequence):
    """Mock sequence for testing purposes."""

    def __init__(self, token_ids, output_text="test_output", eos_token_id=0):
        self.token_ids = token_ids
        self.output_text = output_text
        self.eos_token_id = eos_token_id
        self.status = SequenceStatus.RUNNING
        self.stop_reason = None

    def get_token_ids(self):
        return self.token_ids

    def get_last_token_id(self):
        return self.token_ids[-1] if self.token_ids else None

    def get_len(self):
        return len(self.token_ids)

    def get_output_len(self):
        return len(self.token_ids) - 1  # Simulating prompt + outputs


@pytest.fixture
def deepseek_r1_qwen_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


@pytest.fixture
def stop_checker():
    return StopChecker(max_model_len=10)


@pytest.fixture
def stop_checker_with_reasoner(deepseek_r1_qwen_tokenizer):
    # Request the tokenizer fixture so the mock parser receives a tokenizer
    # instance rather than the fixture function itself.
    reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
    return StopChecker(max_model_len=10, reasoner=reasoner)


def test_eos_token_stopping(stop_checker):
    """Test sequence stopping when EOS token is encountered."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams()

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED


def test_ignore_eos(stop_checker):
    """Test sequence continuing when EOS token is ignored."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams(ignore_eos=True)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.RUNNING


def test_min_tokens(stop_checker):
    """Test min_tokens prevents early stopping."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams(min_tokens=3)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.RUNNING


def test_stop_token_ids(stop_checker):
    """Test sequence stopping with custom stop token IDs."""
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == 3


def test_stop_strings(stop_checker):
    """Test sequence stopping with stop strings."""
    seq = MockSequence(token_ids=[1, 2, 3],
                       output_text="test output with STOP",
                       eos_token_id=0)
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == "STOP"
    assert "STOP" not in seq.output_text  # Default behavior removes stop string


def test_include_stop_str_in_output(stop_checker):
    """Test keeping stop strings in output."""
    seq = MockSequence(token_ids=[1, 2, 3],
                       output_text="test output with STOP",
                       eos_token_id=0)
    sampling_params = SamplingParams(stop=["STOP"],
                                     include_stop_str_in_output=True)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=5,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert "STOP" in seq.output_text


def test_max_tokens(stop_checker):
    """Test sequence stopping at max_tokens."""
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(max_tokens=2)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_max_model_len(stop_checker):
    """Test sequence stopping at max_model_len."""
    seq = MockSequence(token_ids=list(range(11)),
                       eos_token_id=0)  # 11 tokens, max is 10
    sampling_params = SamplingParams()

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_reasoning_skip_stops(stop_checker_with_reasoner):
    """Test that stop tokens and strings are ignored during reasoning."""
    # Set reasoning_active to True to simulate being in reasoning mode
    stop_checker_with_reasoner.reasoner.reasoning_active = True

    # Test with stop token
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.RUNNING

    # Test with stop string
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=4, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.RUNNING

    # But EOS token still stops the sequence
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams()

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED


def test_reasoning_end_enables_stops(stop_checker_with_reasoner):
    """Test that stop tokens work after reasoning ends."""
    # Set reasoning_active to False to simulate being out of reasoning mode
    stop_checker_with_reasoner.reasoner.reasoning_active = False

    # Test with stop token
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED

    # Test with stop string
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=4, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED
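The stop-string tests above assert that the matched stop string is stripped from output_text unless include_stop_str_in_output is set. A standalone sketch of that truncation logic follows; it is illustrative only and not the actual StopChecker implementation.

```python
# Illustrative only, not vLLM's StopChecker: locate a stop-string match among
# the newly generated characters and truncate the output text accordingly.
def truncate_at_stop(output_text: str, stop: str, new_char_count: int,
                     include_stop_str_in_output: bool = False) -> str:
    # Only the newly appended characters (plus enough context to catch a
    # match straddling the boundary) need to be searched.
    search_start = max(0, len(output_text) - new_char_count - len(stop))
    idx = output_text.find(stop, search_start)
    if idx == -1:
        return output_text  # no stop string found; leave the text untouched
    end = idx + len(stop) if include_stop_str_in_output else idx
    return output_text[:end]


assert truncate_at_stop("test output with STOP", "STOP", 5) == "test output with "
assert truncate_at_stop("test output with STOP", "STOP", 5,
                        include_stop_str_in_output=True) == "test output with STOP"
```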