mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:15:26 +08:00
[ROCm][CI] Fix v1/logits_processors failure on ROCm (#29927)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
parent
9ae3c55b10
commit
d1f7392c5f
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
|
|||||||
from tests.v1.logits_processors.utils import (
|
from tests.v1.logits_processors.utils import (
|
||||||
DUMMY_LOGITPROC_ARG,
|
DUMMY_LOGITPROC_ARG,
|
||||||
DUMMY_LOGITPROC_FQCN,
|
DUMMY_LOGITPROC_FQCN,
|
||||||
DUMMY_LOGITPROC_MODULE,
|
|
||||||
MAX_TOKENS,
|
MAX_TOKENS,
|
||||||
MODEL_NAME,
|
MODEL_NAME,
|
||||||
POOLING_MODEL_NAME,
|
POOLING_MODEL_NAME,
|
||||||
@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
|
|||||||
CustomLogitprocSource,
|
CustomLogitprocSource,
|
||||||
DummyLogitsProcessor,
|
DummyLogitsProcessor,
|
||||||
WrappedPerReqLogitsProcessor,
|
WrappedPerReqLogitsProcessor,
|
||||||
dummy_module,
|
|
||||||
prompts,
|
prompts,
|
||||||
)
|
)
|
||||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||||
@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
|
|||||||
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
|
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
|
||||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
|
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
|
||||||
# Scenario: load logitproc based on fully-qualified class name (FQCN)
|
# Scenario: load logitproc based on fully-qualified class name (FQCN)
|
||||||
# Inject dummy module which defines logitproc
|
|
||||||
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
|
|
||||||
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
||||||
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
|
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
|
||||||
# Scenario: load logitproc from provided class object
|
# Scenario: load logitproc from provided class object
|
||||||
|
|||||||
@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
|
|||||||
from tests.v1.logits_processors.utils import (
|
from tests.v1.logits_processors.utils import (
|
||||||
DUMMY_LOGITPROC_ARG,
|
DUMMY_LOGITPROC_ARG,
|
||||||
DUMMY_LOGITPROC_FQCN,
|
DUMMY_LOGITPROC_FQCN,
|
||||||
DUMMY_LOGITPROC_MODULE,
|
|
||||||
MAX_TOKENS,
|
MAX_TOKENS,
|
||||||
MODEL_NAME,
|
MODEL_NAME,
|
||||||
TEMP_GREEDY,
|
TEMP_GREEDY,
|
||||||
dummy_module,
|
|
||||||
prompts,
|
prompts,
|
||||||
)
|
)
|
||||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||||
@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
|
|||||||
main.main()
|
main.main()
|
||||||
|
|
||||||
|
|
||||||
def _server_with_logitproc_module(
|
def _server_with_logitproc_fqcn(
|
||||||
env_dict: dict[str, str] | None,
|
env_dict: dict[str, str] | None,
|
||||||
model: str,
|
model: str,
|
||||||
vllm_serve_args: list[str],
|
vllm_serve_args: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start vLLM server, inject module with dummy logitproc"""
|
"""Start vLLM server, inject module with dummy logitproc"""
|
||||||
|
|
||||||
# Patch `modules` to inject dummy logitproc module
|
|
||||||
from vllm.entrypoints.cli import main
|
from vllm.entrypoints.cli import main
|
||||||
|
|
||||||
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
|
|
||||||
|
|
||||||
# fork is required for workers to see entrypoint patch
|
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
|
|
||||||
if env_dict is not None:
|
if env_dict is not None:
|
||||||
os.environ.update(env_dict)
|
os.environ.update(env_dict)
|
||||||
|
|
||||||
@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
|
|||||||
if request.param:
|
if request.param:
|
||||||
# Launch server, append FQCN argument, inject dummy logitproc module
|
# Launch server, append FQCN argument, inject dummy logitproc module
|
||||||
args = default_server_args + request.param
|
args = default_server_args + request.param
|
||||||
_server_fxn = _server_with_logitproc_module
|
_server_fxn = _server_with_logitproc_fqcn
|
||||||
else:
|
else:
|
||||||
# Launch server, inject dummy logitproc entrypoint
|
# Launch server, inject dummy logitproc entrypoint
|
||||||
args = default_server_args
|
args = default_server_args
|
||||||
|
|||||||
@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
|
|||||||
TEMP_GREEDY = 0.0
|
TEMP_GREEDY = 0.0
|
||||||
MAX_TOKENS = 20
|
MAX_TOKENS = 20
|
||||||
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
|
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
|
||||||
DUMMY_LOGITPROC_MODULE = "DummyModule"
|
DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
|
||||||
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
|
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user