[CI][gpt-oss] Enable python tool tests in CI (#24315)

Signed-off-by: wuhang <wuhang6@huawei.com>
This commit is contained in:
wuhang 2025-10-06 12:20:06 +08:00 committed by GitHub
parent 4be7d7c1c9
commit 91ac7f764d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 26 deletions

View File

@@ -49,3 +49,4 @@ pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects cbor2 # Required for cross-language serialization of hashable objects
setproctitle # Used to set process names for better debugging and monitoring setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss openai-harmony >= 0.0.3 # Required for gpt-oss
gpt-oss >= 0.0.7

View File

@@ -15,22 +15,15 @@ MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def monkeypatch_module(): def server():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module")
def server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"] args = ["--enforce-eager", "--tool-server", "demo"]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
)
with monkeypatch_module.context() as m: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") yield remote_server
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -316,7 +309,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
# TODO: Add back when web search and code interpreter are available in CI # TODO: Add back when web search and code interpreter are available in CI
prompts = [ prompts = [
"tell me a story about a cat in 20 words", "tell me a story about a cat in 20 words",
# "What is 13 * 24? Use python to calculate the result.", "What is 13 * 24? Use python to calculate the result.",
# "When did Jensen found NVIDIA? Search it and answer the year only.", # "When did Jensen found NVIDIA? Search it and answer the year only.",
] ]
@@ -329,12 +322,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
# { # {
# "type": "web_search_preview" # "type": "web_search_preview"
# }, # },
# { {"type": "code_interpreter", "container": {"type": "auto"}},
# "type": "code_interpreter",
# "container": {
# "type": "auto"
# }
# },
], ],
stream=True, stream=True,
background=background, background=background,
@@ -412,6 +400,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
async for event in stream: async for event in stream:
counter += 1 counter += 1
assert event == events[counter] assert event == events[counter]
assert counter == len(events) - 1
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -429,7 +418,6 @@ async def test_web_search(client: OpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_code_interpreter(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str):
response = await client.responses.create( response = await client.responses.create(
model=model_name, model=model_name,
@@ -443,10 +431,16 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
"and you must print to see the output." "and you must print to see the output."
), ),
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
temperature=0.0, # More deterministic output in response
) )
assert response is not None assert response is not None
assert response.status == "completed" assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0 assert response.usage.output_tokens_details.tool_output_tokens > 0
for item in response.output:
if item.type == "message":
output_string = item.content[0].text
print("output_string: ", output_string, flush=True)
assert "5846" in output_string
def get_weather(latitude, longitude): def get_weather(latitude, longitude):

View File

@@ -14,10 +14,12 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
MIN_GPT_OSS_VERSION = "0.0.7"
def validate_gpt_oss_install(): def validate_gpt_oss_install():
""" """
Check if the gpt-oss is installed and its version is at least 0.0.3. Check if the gpt-oss is installed and its version is at least 0.0.7.
If not, raise an ImportError. If not, raise an ImportError.
""" """
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
@@ -25,16 +27,17 @@ def validate_gpt_oss_install():
from packaging.version import InvalidVersion, Version from packaging.version import InvalidVersion, Version
try: try:
pkg_version_str = version("gpt_oss") # e.g., "0.0.5" pkg_version_str = version("gpt_oss")
pkg_version = Version(pkg_version_str) pkg_version = Version(pkg_version_str)
except PackageNotFoundError: except PackageNotFoundError:
raise ImportError("Package 'gpt_oss' is not installed.") from None raise ImportError("Package 'gpt_oss' is not installed.") from None
except InvalidVersion as e: except InvalidVersion as e:
raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None
if pkg_version < Version("0.0.3"): if pkg_version < Version(MIN_GPT_OSS_VERSION):
raise ImportError( raise ImportError(
f"gpt_oss >= 0.0.3 is required, but {pkg_version} is installed." f"gpt_oss >= {MIN_GPT_OSS_VERSION} is required, "
f"but {pkg_version} is installed."
) from None ) from None