mirror of
https://git.datalinker.icu/ltdrdata/ComfyUI-Manager
synced 2025-12-08 21:54:26 +08:00
[tests] Add API test suite
This commit is contained in:
parent
416122d61d
commit
da87651e53
19
tests-api/.gitignore
vendored
Normal file
19
tests-api/.gitignore
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
# Python cache files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Pytest cache
|
||||
.pytest_cache/
|
||||
|
||||
# Coverage reports
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
# Test-specific resources
|
||||
resources/tmp/
|
||||
91
tests-api/README.md
Normal file
91
tests-api/README.md
Normal file
@ -0,0 +1,91 @@
|
||||
# ComfyUI-Manager API Tests
|
||||
|
||||
This directory contains tests for the ComfyUI-Manager API endpoints, validating the OpenAPI specification and ensuring API functionality.
|
||||
|
||||
## Setup
|
||||
|
||||
1. Install test dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements-test.txt
|
||||
```
|
||||
|
||||
2. Ensure ComfyUI is running with ComfyUI-Manager installed:
|
||||
|
||||
```bash
|
||||
# Start ComfyUI with the default server
|
||||
python main.py
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Run all tests
|
||||
|
||||
```bash
|
||||
pytest -xvs
|
||||
```
|
||||
|
||||
### Run specific test files
|
||||
|
||||
```bash
|
||||
# Run only the spec validation tests
|
||||
pytest -xvs test_spec_validation.py
|
||||
|
||||
# Run only the custom node API tests
|
||||
pytest -xvs test_customnode_api.py
|
||||
```
|
||||
|
||||
### Run specific test functions
|
||||
|
||||
```bash
|
||||
# Run a specific test
|
||||
pytest -xvs test_customnode_api.py::test_get_custom_node_list
|
||||
```
|
||||
|
||||
## Test Configuration
|
||||
|
||||
The tests use the following default configuration:
|
||||
|
||||
- Server URL: `http://localhost:8188`
|
||||
- Server timeout: 2 seconds
|
||||
- Wait between requests: 0.5 seconds
|
||||
- Maximum retries: 3
|
||||
|
||||
You can override these settings with environment variables:
|
||||
|
||||
```bash
|
||||
# Use a different server URL
|
||||
COMFYUI_SERVER_URL=http://localhost:8189 pytest -xvs
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
The tests are organized into the following categories:
|
||||
|
||||
1. **Spec Validation** (`test_spec_validation.py`): Validates that the OpenAPI specification is correct and complete.
|
||||
2. **Custom Node API** (`test_customnode_api.py`): Tests for custom node management endpoints.
|
||||
3. **Snapshot API** (`test_snapshot_api.py`): Tests for snapshot management endpoints.
|
||||
4. **Queue API** (`test_queue_api.py`): Tests for queue management endpoints.
|
||||
5. **Config API** (`test_config_api.py`): Tests for configuration endpoints.
|
||||
6. **Model API** (`test_model_api.py`): Tests for model management endpoints (minimal as these are being deprecated).
|
||||
|
||||
## Test Implementation Details
|
||||
|
||||
### Fixtures
|
||||
|
||||
- `test_config`: Provides the test configuration
|
||||
- `server_url`: Returns the server URL from the configuration
|
||||
- `openapi_spec`: Loads the OpenAPI specification
|
||||
- `api_client`: Creates a requests Session for API calls
|
||||
- `api_request`: Helper function for making consistent API requests
|
||||
|
||||
### Utilities
|
||||
|
||||
- `validation.py`: Functions for validating responses against the OpenAPI schema
|
||||
- `schema_utils.py`: Utilities for extracting and manipulating schemas
|
||||
|
||||
## Notes
|
||||
|
||||
- Some tests are skipped with `@pytest.mark.skip` to avoid modifying state in automated testing
|
||||
- Security-level restricted endpoints have minimal tests to avoid security issues
|
||||
- Tests focus on read operations rather than write operations where possible
|
||||
1
tests-api/__init__.py
Normal file
1
tests-api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Make tests-api directory a proper package
|
||||
237
tests-api/conftest.py
Normal file
237
tests-api/conftest.py
Normal file
@ -0,0 +1,237 @@
|
||||
"""
|
||||
PyTest configuration and fixtures for API tests.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import pytest
|
||||
import requests
|
||||
import tempfile
|
||||
import time
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, Optional, Tuple
|
||||
|
||||
# Import test utilities
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Get the absolute path to the current file (conftest.py)
|
||||
current_file = Path(os.path.abspath(__file__))
|
||||
|
||||
# Get the directory containing the current file (the tests-api directory)
|
||||
tests_api_dir = current_file.parent
|
||||
|
||||
# Add the tests-api directory to the Python path
|
||||
if str(tests_api_dir) not in sys.path:
|
||||
sys.path.insert(0, str(tests_api_dir))
|
||||
|
||||
# Apply mocks for ComfyUI imports
|
||||
from mocks.patch import apply_mocks
|
||||
apply_mocks()
|
||||
|
||||
# Now we can import from utils.validation
|
||||
from utils.validation import load_openapi_spec
|
||||
|
||||
|
||||
# Default test configuration
|
||||
DEFAULT_TEST_CONFIG = {
|
||||
"server_url": "http://localhost:8188",
|
||||
"server_timeout": 2, # seconds
|
||||
"wait_between_requests": 0.5, # seconds
|
||||
"max_retries": 3,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_config() -> Dict:
|
||||
"""
|
||||
Load test configuration from environment variables or use defaults.
|
||||
"""
|
||||
config = DEFAULT_TEST_CONFIG.copy()
|
||||
|
||||
# Override from environment variables if present
|
||||
if "COMFYUI_SERVER_URL" in os.environ:
|
||||
config["server_url"] = os.environ["COMFYUI_SERVER_URL"]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server_url(test_config: Dict) -> str:
|
||||
"""
|
||||
Get the server URL from the test configuration.
|
||||
"""
|
||||
return test_config["server_url"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def openapi_spec() -> Dict:
|
||||
"""
|
||||
Load the OpenAPI specification.
|
||||
"""
|
||||
return load_openapi_spec()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def api_client(server_url: str, test_config: Dict) -> requests.Session:
|
||||
"""
|
||||
Create a requests Session for API calls.
|
||||
"""
|
||||
session = requests.Session()
|
||||
|
||||
# Check if the server is running
|
||||
try:
|
||||
response = session.get(f"{server_url}/", timeout=test_config["server_timeout"])
|
||||
response.raise_for_status()
|
||||
except (requests.ConnectionError, requests.Timeout, requests.HTTPError):
|
||||
pytest.skip("ComfyUI server is not running or not accessible")
|
||||
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def temp_dir() -> Generator[Path, None, None]:
|
||||
"""
|
||||
Create a temporary directory for test files.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
class SecurityLevelContext:
|
||||
"""
|
||||
Context manager for setting and restoring security levels.
|
||||
"""
|
||||
def __init__(self, api_client: requests.Session, server_url: str, security_level: str):
|
||||
self.api_client = api_client
|
||||
self.server_url = server_url
|
||||
self.security_level = security_level
|
||||
self.original_level = None
|
||||
|
||||
async def __aenter__(self):
|
||||
# Get the current security level (not directly exposed in API, would require more setup)
|
||||
# For now, we'll just set the new level
|
||||
|
||||
# Set the new security level
|
||||
# Note: In a real implementation, we would need a way to set this
|
||||
# This is a placeholder - the actual implementation would depend on how
|
||||
# security levels are managed in ComfyUI-Manager
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# Restore the original security level if needed
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def security_level_context(api_client: requests.Session, server_url: str):
|
||||
"""
|
||||
Create a context manager for setting security levels.
|
||||
"""
|
||||
return lambda level: SecurityLevelContext(api_client, server_url, level)
|
||||
|
||||
|
||||
def make_api_url(server_url: str, path: str) -> str:
|
||||
"""
|
||||
Construct a full API URL from the server URL and path.
|
||||
"""
|
||||
# Ensure the path starts with a slash
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
# Remove trailing slash from server_url if present
|
||||
if server_url.endswith("/"):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
return f"{server_url}{path}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_request(api_client: requests.Session, server_url: str, test_config: Dict):
|
||||
"""
|
||||
Helper function for making API requests with consistent behavior.
|
||||
"""
|
||||
def _request(
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[Dict] = None,
|
||||
json_data: Optional[Dict] = None,
|
||||
headers: Optional[Dict] = None,
|
||||
expected_status: int = 200,
|
||||
retry_on_error: bool = True,
|
||||
) -> Tuple[requests.Response, Optional[Dict]]:
|
||||
"""
|
||||
Make an API request with automatic validation.
|
||||
|
||||
Args:
|
||||
method: HTTP method
|
||||
path: API path
|
||||
params: Query parameters
|
||||
json_data: JSON request body
|
||||
headers: HTTP headers
|
||||
expected_status: Expected HTTP status code
|
||||
retry_on_error: Whether to retry on connection errors
|
||||
|
||||
Returns:
|
||||
Tuple of (Response object, JSON response data or None)
|
||||
"""
|
||||
method = method.lower()
|
||||
url = make_api_url(server_url, path)
|
||||
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
# Add common headers
|
||||
headers.setdefault("Accept", "application/json")
|
||||
|
||||
# Sleep between requests to avoid overwhelming the server
|
||||
time.sleep(test_config["wait_between_requests"])
|
||||
|
||||
retries = test_config["max_retries"] if retry_on_error else 0
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
if method == "get":
|
||||
response = api_client.get(url, params=params, headers=headers)
|
||||
elif method == "post":
|
||||
response = api_client.post(url, params=params, json=json_data, headers=headers)
|
||||
elif method == "put":
|
||||
response = api_client.put(url, params=params, json=json_data, headers=headers)
|
||||
elif method == "delete":
|
||||
response = api_client.delete(url, params=params, headers=headers)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
# Check status code
|
||||
assert response.status_code == expected_status, (
|
||||
f"Expected status code {expected_status}, got {response.status_code}"
|
||||
)
|
||||
|
||||
# Parse JSON response if possible
|
||||
json_response = None
|
||||
if response.headers.get("Content-Type", "").startswith("application/json"):
|
||||
try:
|
||||
json_response = response.json()
|
||||
except json.JSONDecodeError:
|
||||
if expected_status == 200:
|
||||
raise ValueError("Response was not valid JSON")
|
||||
|
||||
return response, json_response
|
||||
|
||||
except (requests.ConnectionError, requests.Timeout) as e:
|
||||
last_exception = e
|
||||
if attempt < retries:
|
||||
# Wait before retrying
|
||||
time.sleep(1)
|
||||
continue
|
||||
break
|
||||
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
|
||||
raise RuntimeError("Failed to make API request")
|
||||
|
||||
return _request
|
||||
1
tests-api/mocks/__init__.py
Normal file
1
tests-api/mocks/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Make tests-api/mocks directory a proper package
|
||||
26
tests-api/mocks/custom_node_manager.py
Normal file
26
tests-api/mocks/custom_node_manager.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""
|
||||
Mock CustomNodeManager for testing purposes
|
||||
"""
|
||||
|
||||
class CustomNodeManager:
|
||||
"""
|
||||
Mock implementation of the CustomNodeManager class
|
||||
"""
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
self.custom_nodes = {}
|
||||
self.node_paths = []
|
||||
self.refresh_timeout = None
|
||||
|
||||
def get_node_path(self, node_class):
|
||||
"""
|
||||
Mock implementation to get the path for a node class
|
||||
"""
|
||||
return self.custom_nodes.get(node_class, None)
|
||||
|
||||
def update_node_paths(self):
|
||||
"""
|
||||
Mock implementation to update node paths
|
||||
"""
|
||||
pass
|
||||
116
tests-api/mocks/patch.py
Normal file
116
tests-api/mocks/patch.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""
|
||||
Patch module to mock imports for testing
|
||||
"""
|
||||
import sys
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Import mock modules
|
||||
from mocks.prompt_server import PromptServer
|
||||
from mocks.custom_node_manager import CustomNodeManager
|
||||
|
||||
# Current directory
|
||||
current_dir = Path(__file__).parent.parent # tests-api directory
|
||||
|
||||
# Define mocks
|
||||
class MockModule:
|
||||
"""Base class for mock modules"""
|
||||
pass
|
||||
|
||||
# Create server mock module with PromptServer
|
||||
server_mock = MockModule()
|
||||
server_mock.PromptServer = PromptServer
|
||||
prompt_server_instance = PromptServer()
|
||||
server_mock.PromptServer.instance = prompt_server_instance
|
||||
server_mock.PromptServer.inst = prompt_server_instance
|
||||
|
||||
# Create app mock module with custom_node_manager submodule
|
||||
app_mock = MockModule()
|
||||
app_custom_node_manager = MockModule()
|
||||
app_custom_node_manager.CustomNodeManager = CustomNodeManager
|
||||
app_custom_node_manager.CustomNodeManager.instance = CustomNodeManager()
|
||||
|
||||
# Create utils mock module with json_util submodule
|
||||
utils_mock = MockModule()
|
||||
utils_json_util = MockModule()
|
||||
|
||||
# Create utils.validation and utils.schema_utils submodules
|
||||
utils_validation = MockModule()
|
||||
utils_schema_utils = MockModule()
|
||||
|
||||
# Import actual modules (make sure path is set up correctly)
|
||||
sys.path.insert(0, str(current_dir))
|
||||
|
||||
try:
|
||||
# Import the validation module
|
||||
from utils.validation import load_openapi_spec
|
||||
utils_validation.load_openapi_spec = load_openapi_spec
|
||||
|
||||
# Import all schema_utils functions
|
||||
from utils.schema_utils import (
|
||||
get_all_paths,
|
||||
get_grouped_paths,
|
||||
get_methods_for_path,
|
||||
find_paths_with_security,
|
||||
get_content_types_for_response,
|
||||
get_required_parameters
|
||||
)
|
||||
|
||||
utils_schema_utils.get_all_paths = get_all_paths
|
||||
utils_schema_utils.get_grouped_paths = get_grouped_paths
|
||||
utils_schema_utils.get_methods_for_path = get_methods_for_path
|
||||
utils_schema_utils.find_paths_with_security = find_paths_with_security
|
||||
utils_schema_utils.get_content_types_for_response = get_content_types_for_response
|
||||
utils_schema_utils.get_required_parameters = get_required_parameters
|
||||
|
||||
except ImportError as e:
|
||||
print(f"Error importing test utilities: {e}")
|
||||
# Define dummy functions if imports fail
|
||||
def dummy_load_openapi_spec():
|
||||
"""Dummy function for testing"""
|
||||
return {"paths": {}}
|
||||
utils_validation.load_openapi_spec = dummy_load_openapi_spec
|
||||
|
||||
def dummy_get_all_paths(spec):
|
||||
return list(spec.get("paths", {}).keys())
|
||||
utils_schema_utils.get_all_paths = dummy_get_all_paths
|
||||
|
||||
def dummy_get_grouped_paths(spec):
|
||||
return {}
|
||||
utils_schema_utils.get_grouped_paths = dummy_get_grouped_paths
|
||||
|
||||
def dummy_get_methods_for_path(spec, path):
|
||||
return []
|
||||
utils_schema_utils.get_methods_for_path = dummy_get_methods_for_path
|
||||
|
||||
def dummy_find_paths_with_security(spec, security_scheme=None):
|
||||
return []
|
||||
utils_schema_utils.find_paths_with_security = dummy_find_paths_with_security
|
||||
|
||||
def dummy_get_content_types_for_response(spec, path, method, status_code="200"):
|
||||
return []
|
||||
utils_schema_utils.get_content_types_for_response = dummy_get_content_types_for_response
|
||||
|
||||
def dummy_get_required_parameters(spec, path, method):
|
||||
return []
|
||||
utils_schema_utils.get_required_parameters = dummy_get_required_parameters
|
||||
|
||||
# Add merge_json_recursive from our mock utils
|
||||
from mocks.utils import merge_json_recursive
|
||||
utils_json_util.merge_json_recursive = merge_json_recursive
|
||||
|
||||
# Apply the mocks to sys.modules
|
||||
def apply_mocks():
|
||||
"""Apply all mocks to sys.modules"""
|
||||
sys.modules['server'] = server_mock
|
||||
sys.modules['app'] = app_mock
|
||||
sys.modules['app.custom_node_manager'] = app_custom_node_manager
|
||||
sys.modules['utils'] = utils_mock
|
||||
sys.modules['utils.json_util'] = utils_json_util
|
||||
sys.modules['utils.validation'] = utils_validation
|
||||
sys.modules['utils.schema_utils'] = utils_schema_utils
|
||||
|
||||
# Make sure our actual utils module is importable
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, str(current_dir))
|
||||
71
tests-api/mocks/prompt_server.py
Normal file
71
tests-api/mocks/prompt_server.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Mock PromptServer for testing purposes
|
||||
"""
|
||||
|
||||
class MockRoutes:
|
||||
"""
|
||||
Mock routing class with method decorators
|
||||
"""
|
||||
def __init__(self):
|
||||
self.routes = {}
|
||||
|
||||
def get(self, path):
|
||||
"""Decorator for GET routes"""
|
||||
def decorator(f):
|
||||
self.routes[('GET', path)] = f
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def post(self, path):
|
||||
"""Decorator for POST routes"""
|
||||
def decorator(f):
|
||||
self.routes[('POST', path)] = f
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def put(self, path):
|
||||
"""Decorator for PUT routes"""
|
||||
def decorator(f):
|
||||
self.routes[('PUT', path)] = f
|
||||
return f
|
||||
return decorator
|
||||
|
||||
def delete(self, path):
|
||||
"""Decorator for DELETE routes"""
|
||||
def decorator(f):
|
||||
self.routes[('DELETE', path)] = f
|
||||
return f
|
||||
return decorator
|
||||
|
||||
|
||||
class PromptServer:
|
||||
"""
|
||||
Mock implementation of the PromptServer class
|
||||
"""
|
||||
instance = None
|
||||
inst = None
|
||||
|
||||
def __init__(self):
|
||||
self.routes = MockRoutes()
|
||||
self.registered_paths = set()
|
||||
self.base_url = "http://127.0.0.1:8188" # Assuming server is running on default port
|
||||
self.queue_lock = None
|
||||
|
||||
def add_route(self, method, path, handler, *args, **kwargs):
|
||||
"""
|
||||
Add a mock route to the server
|
||||
"""
|
||||
self.routes.routes[(method.upper(), path)] = handler
|
||||
self.registered_paths.add(path)
|
||||
|
||||
async def send_msg(self, message, data=None):
|
||||
"""
|
||||
Mock send_msg method (does nothing in the mock)
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_sync(self, message, data=None):
|
||||
"""
|
||||
Mock send_sync method (does nothing in the mock)
|
||||
"""
|
||||
pass
|
||||
20
tests-api/mocks/utils.py
Normal file
20
tests-api/mocks/utils.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""
|
||||
Mock utils module for testing purposes
|
||||
"""
|
||||
|
||||
def merge_json_recursive(a, b):
|
||||
"""
|
||||
Mock implementation of merge_json_recursive
|
||||
"""
|
||||
if isinstance(a, dict) and isinstance(b, dict):
|
||||
result = a.copy()
|
||||
for key, value in b.items():
|
||||
if key in result and isinstance(result[key], (dict, list)) and isinstance(value, (dict, list)):
|
||||
result[key] = merge_json_recursive(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
elif isinstance(a, list) and isinstance(b, list):
|
||||
return a + b
|
||||
else:
|
||||
return b
|
||||
382
tests-api/openapi.yaml
Normal file
382
tests-api/openapi.yaml
Normal file
@ -0,0 +1,382 @@
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: ComfyUI-Manager API
|
||||
description: API for managing ComfyUI extensions, custom nodes, and models
|
||||
version: 1.0.0
|
||||
contact:
|
||||
name: ComfyUI Community
|
||||
url: https://github.com/comfyanonymous/ComfyUI
|
||||
|
||||
servers:
|
||||
- url: http://localhost:8188
|
||||
description: Local ComfyUI server
|
||||
|
||||
paths:
|
||||
/customnode/getlist:
|
||||
get:
|
||||
summary: Get the list of custom nodes
|
||||
description: Returns the list of custom nodes from all configured channels
|
||||
parameters:
|
||||
- name: mode
|
||||
in: query
|
||||
description: "The mode to retrieve (local=installed nodes, remote=available nodes)"
|
||||
schema:
|
||||
type: string
|
||||
enum: [local, remote]
|
||||
default: remote
|
||||
responses:
|
||||
'200':
|
||||
description: List of custom nodes
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
nodes:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/CustomNode'
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/customnode/get_node_mappings:
|
||||
get:
|
||||
summary: Get mappings between node class names and their custom nodes
|
||||
description: Returns mappings that help identify which custom node package provides specific node classes
|
||||
parameters:
|
||||
- name: mode
|
||||
in: query
|
||||
description: "The mode for mappings (local=installed nodes, nickname=node nicknames)"
|
||||
schema:
|
||||
type: string
|
||||
enum: [local, nickname]
|
||||
default: local
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
description: Node mappings
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/customnode/get_node_alternatives:
|
||||
get:
|
||||
summary: Get alternative nodes for specific node classes
|
||||
description: Returns alternative implementations of node classes from different custom node packages
|
||||
parameters:
|
||||
- name: mode
|
||||
in: query
|
||||
description: "The mode to retrieve alternatives (local=installed nodes, remote=all available nodes)"
|
||||
schema:
|
||||
type: string
|
||||
enum: [local, remote]
|
||||
default: remote
|
||||
responses:
|
||||
'200':
|
||||
description: Node alternatives
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/externalmodel/getlist:
|
||||
get:
|
||||
summary: Get the list of external models
|
||||
description: Returns the list of models from all configured channels
|
||||
parameters:
|
||||
- name: mode
|
||||
in: query
|
||||
description: "The mode to retrieve (local=installed models, remote=available models)"
|
||||
schema:
|
||||
type: string
|
||||
enum: [local, remote]
|
||||
default: remote
|
||||
responses:
|
||||
'200':
|
||||
description: List of external models
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
models:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ExternalModel'
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/manager/get_config:
|
||||
get:
|
||||
summary: Get manager configuration
|
||||
description: Returns the current configuration of ComfyUI-Manager
|
||||
parameters:
|
||||
- name: key
|
||||
in: query
|
||||
description: "The configuration key to retrieve"
|
||||
schema:
|
||||
type: string
|
||||
required: true
|
||||
responses:
|
||||
'200':
|
||||
description: Configuration value
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
value:
|
||||
type: string
|
||||
'400':
|
||||
description: Invalid key or missing parameter
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/manager/set_config:
|
||||
post:
|
||||
summary: Set manager configuration
|
||||
description: Updates the configuration of ComfyUI-Manager
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
required:
|
||||
- key
|
||||
- value
|
||||
properties:
|
||||
key:
|
||||
type: string
|
||||
description: "The configuration key to update"
|
||||
value:
|
||||
type: string
|
||||
description: "The new value for the configuration key"
|
||||
responses:
|
||||
'200':
|
||||
description: Configuration updated successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
success:
|
||||
type: boolean
|
||||
'400':
|
||||
description: Invalid key or value
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/snapshot/getlist:
|
||||
get:
|
||||
summary: Get the list of snapshots
|
||||
description: Returns the list of saved snapshots
|
||||
responses:
|
||||
'200':
|
||||
description: List of snapshots
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
snapshots:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Snapshot'
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
/comfyui_manager/queue/status:
|
||||
get:
|
||||
summary: Get queue status
|
||||
description: Returns the current status of the operation queue
|
||||
responses:
|
||||
'200':
|
||||
description: Queue status
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/QueueStatus'
|
||||
'500':
|
||||
description: Server error
|
||||
|
||||
components:
|
||||
schemas:
|
||||
CustomNode:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- title
|
||||
- reference
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: "Internal name/ID of the custom node"
|
||||
title:
|
||||
type: string
|
||||
description: "Display title of the custom node"
|
||||
reference:
|
||||
type: string
|
||||
description: "Reference URL (usually GitHub repository URL)"
|
||||
description:
|
||||
type: string
|
||||
description: "Description of what the custom node does"
|
||||
install_type:
|
||||
type: string
|
||||
enum: [git, pip, copy]
|
||||
description: "Installation method for the custom node"
|
||||
files:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "List of files provided by this custom node"
|
||||
node_class_names:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "List of node class names provided by this custom node"
|
||||
installed:
|
||||
type: boolean
|
||||
description: "Whether the custom node is installed"
|
||||
version:
|
||||
type: string
|
||||
description: "Version of the custom node"
|
||||
tags:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "Tags associated with the custom node"
|
||||
|
||||
ExternalModel:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- url
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: "Name of the model"
|
||||
type:
|
||||
type: string
|
||||
description: "Type of the model (checkpoint, lora, embedding, etc.)"
|
||||
url:
|
||||
type: string
|
||||
description: "Download URL for the model"
|
||||
description:
|
||||
type: string
|
||||
description: "Description of the model"
|
||||
size:
|
||||
type: integer
|
||||
description: "Size of the model in bytes"
|
||||
installed:
|
||||
type: boolean
|
||||
description: "Whether the model is installed"
|
||||
version:
|
||||
type: string
|
||||
description: "Version of the model"
|
||||
tags:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "Tags associated with the model"
|
||||
|
||||
Snapshot:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- date
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: "Name of the snapshot"
|
||||
date:
|
||||
type: string
|
||||
format: date-time
|
||||
description: "Date when the snapshot was created"
|
||||
description:
|
||||
type: string
|
||||
description: "Description of the snapshot"
|
||||
nodes:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "List of custom nodes in the snapshot"
|
||||
models:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: "List of models in the snapshot"
|
||||
|
||||
QueueStatus:
|
||||
type: object
|
||||
properties:
|
||||
pending:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/QueueItem'
|
||||
description: "List of pending operations in the queue"
|
||||
completed:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/QueueItem'
|
||||
description: "List of completed operations in the queue"
|
||||
failed:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/QueueItem'
|
||||
description: "List of failed operations in the queue"
|
||||
running:
|
||||
type: boolean
|
||||
description: "Whether the queue is currently running"
|
||||
|
||||
QueueItem:
|
||||
type: object
|
||||
required:
|
||||
- id
|
||||
- type
|
||||
- target
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: "Unique ID of the queue item"
|
||||
type:
|
||||
type: string
|
||||
enum: [install, update, uninstall]
|
||||
description: "Type of operation"
|
||||
target:
|
||||
type: string
|
||||
description: "Target of the operation (e.g., custom node name, model name)"
|
||||
status:
|
||||
type: string
|
||||
enum: [pending, processing, completed, failed]
|
||||
description: "Current status of the operation"
|
||||
error:
|
||||
type: string
|
||||
description: "Error message if the operation failed"
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
description: "Time when the operation was added to the queue"
|
||||
completed_at:
|
||||
type: string
|
||||
format: date-time
|
||||
description: "Time when the operation was completed"
|
||||
|
||||
securitySchemes:
|
||||
ApiKeyAuth:
|
||||
type: apiKey
|
||||
in: header
|
||||
name: X-API-Key
|
||||
description: "API key for authentication"
|
||||
6
tests-api/requirements-test.txt
Normal file
6
tests-api/requirements-test.txt
Normal file
@ -0,0 +1,6 @@
|
||||
pytest>=7.3.1
|
||||
requests>=2.31.0
|
||||
openapi-spec-validator>=0.6.0
|
||||
jsonschema>=4.17.3
|
||||
pytest-asyncio>=0.21.0
|
||||
pyyaml>=6.0
|
||||
270
tests-api/test_config_api.py
Normal file
270
tests-api/test_config_api.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""
|
||||
Tests for configuration endpoints.
|
||||
"""
|
||||
import pytest
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
from utils.validation import validate_response
|
||||
|
||||
|
||||
def test_get_preview_method(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test getting the current preview method.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/preview_method"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify the response is one of the valid preview methods
|
||||
assert response.text in ["auto", "latent2rgb", "taesd", "none"]
|
||||
|
||||
|
||||
def test_get_db_mode(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test getting the current database mode.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/db_mode"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify the response is one of the valid database modes
|
||||
assert response.text in ["channel", "local", "remote"]
|
||||
|
||||
|
||||
def test_get_component_policy(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test getting the current component policy.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/policy/component"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Component policy could be any string
|
||||
assert response.text is not None
|
||||
|
||||
|
||||
def test_get_update_policy(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test getting the current update policy.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/policy/update"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify the response is one of the valid update policies
|
||||
assert response.text in ["stable", "nightly", "nightly-comfyui"]
|
||||
|
||||
|
||||
def test_get_channel_url_list(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict
|
||||
):
|
||||
"""
|
||||
Test getting the channel URL list.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/channel_url_list"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response contains the expected fields
|
||||
assert "selected" in json_data
|
||||
assert "list" in json_data
|
||||
assert isinstance(json_data["list"], list)
|
||||
|
||||
# Each channel should have a name and URL
|
||||
if json_data["list"]:
|
||||
first_channel = json_data["list"][0]
|
||||
assert "name" in first_channel
|
||||
assert "url" in first_channel
|
||||
|
||||
|
||||
def test_get_manager_version(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test getting the manager version.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/version"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify the response is a version string
|
||||
assert response.text.startswith("V") # Version strings start with V
|
||||
|
||||
|
||||
def test_get_manager_notice(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test getting the manager notice.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/notice"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify the response is HTML content
|
||||
assert response.headers.get("Content-Type", "").startswith("text/html") or "ComfyUI" in response.text
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operations")
|
||||
class TestConfigChanges:
|
||||
"""
|
||||
Tests for changing configuration settings.
|
||||
These are skipped to avoid modifying state in automated tests.
|
||||
"""
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def save_original_config(self, api_request: Callable):
|
||||
"""
|
||||
Save the original configuration to restore after tests.
|
||||
"""
|
||||
# Save original values
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/manager/preview_method",
|
||||
expected_status=200,
|
||||
)
|
||||
self.original_preview_method = response.text
|
||||
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/manager/db_mode",
|
||||
expected_status=200,
|
||||
)
|
||||
self.original_db_mode = response.text
|
||||
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/manager/policy/update",
|
||||
expected_status=200,
|
||||
)
|
||||
self.original_update_policy = response.text
|
||||
|
||||
yield
|
||||
|
||||
# Restore original values
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/preview_method",
|
||||
params={"value": self.original_preview_method},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/db_mode",
|
||||
params={"value": self.original_db_mode},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/policy/update",
|
||||
params={"value": self.original_update_policy},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
def test_set_preview_method(self, api_request: Callable):
|
||||
"""
|
||||
Test setting the preview method.
|
||||
"""
|
||||
# Set to a different value (taesd)
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/preview_method",
|
||||
params={"value": "taesd"},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify it was changed
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/manager/preview_method",
|
||||
expected_status=200,
|
||||
)
|
||||
assert response.text == "taesd"
|
||||
|
||||
def test_set_db_mode(self, api_request: Callable):
|
||||
"""
|
||||
Test setting the database mode.
|
||||
"""
|
||||
# Set to local mode
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/db_mode",
|
||||
params={"value": "local"},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify it was changed
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/manager/db_mode",
|
||||
expected_status=200,
|
||||
)
|
||||
assert response.text == "local"
|
||||
|
||||
def test_set_update_policy(self, api_request: Callable):
|
||||
"""
|
||||
Test setting the update policy.
|
||||
"""
|
||||
# Set to stable
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/policy/update",
|
||||
params={"value": "stable"},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify it was changed
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/manager/policy/update",
|
||||
expected_status=200,
|
||||
)
|
||||
assert response.text == "stable"
|
||||
200
tests-api/test_customnode_api.py
Normal file
200
tests-api/test_customnode_api.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for custom node management endpoints.
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
from utils.validation import validate_response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
["local", "remote"]
|
||||
)
|
||||
def test_get_custom_node_list(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict,
|
||||
mode: str
|
||||
):
|
||||
"""
|
||||
Test the endpoint for listing custom nodes.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/customnode/getlist"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
params={"mode": mode, "skip_update": "true"},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response contains the expected fields
|
||||
assert "channel" in json_data
|
||||
assert "node_packs" in json_data
|
||||
assert isinstance(json_data["node_packs"], dict)
|
||||
|
||||
# If there are any node packs, verify they have the expected structure
|
||||
if json_data["node_packs"]:
|
||||
# Take the first node pack to validate
|
||||
first_node_pack = next(iter(json_data["node_packs"].values()))
|
||||
assert "title" in first_node_pack
|
||||
assert "name" in first_node_pack
|
||||
|
||||
|
||||
def test_get_installed_nodes(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict
|
||||
):
|
||||
"""
|
||||
Test the endpoint for listing installed nodes.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/customnode/installed"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response is a dictionary of node packs
|
||||
assert isinstance(json_data, dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
["local", "nickname"]
|
||||
)
|
||||
def test_get_node_mappings(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict,
|
||||
mode: str
|
||||
):
|
||||
"""
|
||||
Test the endpoint for getting node-to-package mappings.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/customnode/getmappings"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
params={"mode": mode},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response is a dictionary mapping extension IDs to node info
|
||||
assert isinstance(json_data, dict)
|
||||
|
||||
# If there are any mappings, verify they have the expected structure
|
||||
if json_data:
|
||||
# Take the first mapping to validate
|
||||
first_mapping = next(iter(json_data.values()))
|
||||
assert isinstance(first_mapping, list)
|
||||
assert len(first_mapping) == 2
|
||||
assert isinstance(first_mapping[0], list) # List of node classes
|
||||
assert isinstance(first_mapping[1], dict) # Metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
["local", "remote"]
|
||||
)
|
||||
def test_get_node_alternatives(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict,
|
||||
mode: str
|
||||
):
|
||||
"""
|
||||
Test the endpoint for getting alternative node options.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/customnode/alternatives"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
params={"mode": mode},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response is a dictionary
|
||||
assert isinstance(json_data, dict)
|
||||
|
||||
|
||||
def test_fetch_updates(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for fetching updates.
|
||||
This might modify state, so we just check for a valid response.
|
||||
"""
|
||||
# Make the API request with skip_update=true to avoid actual updates
|
||||
path = "/customnode/fetch_updates"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
params={"mode": "local"},
|
||||
# Don't validate JSON since this endpoint doesn't return JSON
|
||||
expected_status=200,
|
||||
retry_on_error=False, # Don't retry as this might have side effects
|
||||
)
|
||||
|
||||
# Just check the status code is as expected (covered by api_request)
|
||||
assert response.status_code in [200, 201]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Queue endpoints are better tested with queue operations")
|
||||
def test_queue_update_all(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for queuing updates for all nodes.
|
||||
Skipping as this would actually modify the installation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Security-restricted endpoint")
|
||||
def test_install_node_via_git_url(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for installing a node via Git URL.
|
||||
Skipping as this requires high security level and would modify the installation.
|
||||
"""
|
||||
pass
|
||||
23
tests-api/test_import.py
Normal file
23
tests-api/test_import.py
Normal file
@ -0,0 +1,23 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Print current working directory
|
||||
print(f"Current directory: {os.getcwd()}")
|
||||
|
||||
# Print module search path
|
||||
print(f"System path: {sys.path}")
|
||||
|
||||
# Try to import
|
||||
try:
|
||||
from utils.validation import load_openapi_spec
|
||||
print("Import successful!")
|
||||
except ImportError as e:
|
||||
print(f"Import error: {e}")
|
||||
|
||||
# Try direct import
|
||||
try:
|
||||
sys.path.insert(0, os.path.join(os.getcwd(), "custom_nodes/ComfyUI-Manager/tests-api"))
|
||||
from utils.validation import load_openapi_spec
|
||||
print("Direct import successful!")
|
||||
except ImportError as e:
|
||||
print(f"Direct import error: {e}")
|
||||
62
tests-api/test_model_api.py
Normal file
62
tests-api/test_model_api.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""
|
||||
Tests for model management endpoints.
|
||||
These features are scheduled for deprecation, so tests are minimal.
|
||||
"""
|
||||
import pytest
|
||||
from typing import Callable, Dict
|
||||
|
||||
from utils.validation import validate_response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
["local", "remote"]
|
||||
)
|
||||
def test_get_external_model_list(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict,
|
||||
mode: str
|
||||
):
|
||||
"""
|
||||
Test the endpoint for listing external models.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/externalmodel/getlist"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
params={"mode": mode},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response contains the expected fields
|
||||
assert "models" in json_data
|
||||
assert isinstance(json_data["models"], list)
|
||||
|
||||
# If there are any models, verify they have the expected structure
|
||||
if json_data["models"]:
|
||||
first_model = json_data["models"][0]
|
||||
assert "name" in first_model
|
||||
assert "type" in first_model
|
||||
assert "url" in first_model
|
||||
assert "filename" in first_model
|
||||
assert "installed" in first_model
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation that requires auth")
|
||||
def test_install_model():
|
||||
"""
|
||||
Test queuing a model installation.
|
||||
Skipped to avoid modifying state and requires authentication.
|
||||
This feature is also scheduled for deprecation.
|
||||
"""
|
||||
pass
|
||||
213
tests-api/test_queue_api.py
Normal file
213
tests-api/test_queue_api.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""
|
||||
Tests for queue management endpoints.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
from utils.validation import validate_response
|
||||
|
||||
|
||||
def test_get_queue_status(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict
|
||||
):
|
||||
"""
|
||||
Test the endpoint for getting queue status.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/queue/status"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response contains the expected fields
|
||||
assert "total_count" in json_data
|
||||
assert "done_count" in json_data
|
||||
assert "in_progress_count" in json_data
|
||||
assert "is_processing" in json_data
|
||||
|
||||
# Type checks
|
||||
assert isinstance(json_data["total_count"], int)
|
||||
assert isinstance(json_data["done_count"], int)
|
||||
assert isinstance(json_data["in_progress_count"], int)
|
||||
assert isinstance(json_data["is_processing"], bool)
|
||||
|
||||
|
||||
def test_reset_queue(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for resetting the queue.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/manager/queue/reset"
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Now check the queue status to verify it was reset
|
||||
response2, json_data = api_request(
|
||||
method="get",
|
||||
path="/manager/queue/status",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Queue should be empty after reset
|
||||
assert json_data["total_count"] == json_data["done_count"] + json_data["in_progress_count"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation that requires auth")
|
||||
def test_queue_install_node():
|
||||
"""
|
||||
Test queuing a node installation.
|
||||
Skipped to avoid modifying state and requires authentication.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation that requires auth")
|
||||
def test_queue_update_node():
|
||||
"""
|
||||
Test queuing a node update.
|
||||
Skipped to avoid modifying state and requires authentication.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation that requires auth")
|
||||
def test_queue_uninstall_node():
|
||||
"""
|
||||
Test queuing a node uninstallation.
|
||||
Skipped to avoid modifying state and requires authentication.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation")
|
||||
def test_queue_start():
|
||||
"""
|
||||
Test starting the queue.
|
||||
Skipped to avoid modifying state.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TestQueueOperations:
|
||||
"""
|
||||
Test a complete queue workflow.
|
||||
These tests are grouped to ensure proper sequencing but are still skipped
|
||||
to avoid modifying state in automated tests.
|
||||
"""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def node_data(self) -> Dict:
|
||||
"""
|
||||
Create test data for a node operation.
|
||||
"""
|
||||
# This would be replaced with actual data for a known safe node
|
||||
return {
|
||||
"ui_id": "test_node_1",
|
||||
"id": "comfyui-manager", # Manager itself
|
||||
"version": "latest",
|
||||
"channel": "default",
|
||||
"mode": "local",
|
||||
}
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation")
|
||||
def test_queue_operation_sequence(
|
||||
self,
|
||||
api_request: Callable,
|
||||
node_data: Dict
|
||||
):
|
||||
"""
|
||||
Test the queue operation sequence.
|
||||
"""
|
||||
# 1. Reset the queue
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/queue/reset",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# 2. Queue a node operation (we'll use the manager itself)
|
||||
api_request(
|
||||
method="post",
|
||||
path="/manager/queue/update",
|
||||
json_data=node_data,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# 3. Check queue status - should have one operation
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path="/manager/queue/status",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
assert json_data["total_count"] > 0
|
||||
assert not json_data["is_processing"] # Queue hasn't started yet
|
||||
|
||||
# 4. Start the queue
|
||||
api_request(
|
||||
method="get",
|
||||
path="/manager/queue/start",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# 5. Check queue status again - should be processing
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path="/manager/queue/status",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Queue should be processing or already done
|
||||
assert json_data["is_processing"] or json_data["done_count"] == json_data["total_count"]
|
||||
|
||||
# 6. Wait for queue to complete (with timeout)
|
||||
max_wait_time = 60 # seconds
|
||||
start_time = time.time()
|
||||
completed = False
|
||||
|
||||
while time.time() - start_time < max_wait_time:
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path="/manager/queue/status",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
if json_data["done_count"] == json_data["total_count"] and not json_data["is_processing"]:
|
||||
completed = True
|
||||
break
|
||||
|
||||
time.sleep(2) # Wait before checking again
|
||||
|
||||
assert completed, "Queue did not complete within timeout period"
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying operation")
|
||||
def test_concurrent_queue_operations(
|
||||
self,
|
||||
api_request: Callable,
|
||||
node_data: Dict
|
||||
):
|
||||
"""
|
||||
Test concurrent queue operations.
|
||||
"""
|
||||
# This would test adding multiple operations to the queue
|
||||
# and verifying they all complete correctly
|
||||
pass
|
||||
198
tests-api/test_snapshot_api.py
Normal file
198
tests-api/test_snapshot_api.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""
|
||||
Tests for snapshot management endpoints.
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from utils.validation import validate_response
|
||||
|
||||
|
||||
def test_get_snapshot_list(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict
|
||||
):
|
||||
"""
|
||||
Test the endpoint for listing snapshots.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/snapshot/getlist"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Verify the response contains the expected fields
|
||||
assert "items" in json_data
|
||||
assert isinstance(json_data["items"], list)
|
||||
|
||||
|
||||
def test_get_current_snapshot(
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict
|
||||
):
|
||||
"""
|
||||
Test the endpoint for getting the current snapshot.
|
||||
"""
|
||||
# Make the API request
|
||||
path = "/snapshot/get_current"
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path=path,
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate response structure against the schema
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path=path,
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
# Check for basic snapshot structure
|
||||
assert "snapshot_date" in json_data
|
||||
assert "custom_nodes" in json_data
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test creates a snapshot which is a state-modifying operation")
|
||||
def test_save_snapshot(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for saving a new snapshot.
|
||||
Skipped to avoid modifying state in tests.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test removes a snapshot which is a destructive operation")
|
||||
def test_remove_snapshot(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for removing a snapshot.
|
||||
Skipped to avoid modifying state in tests.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This test restores a snapshot which is a state-modifying operation")
|
||||
def test_restore_snapshot(
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test the endpoint for restoring a snapshot.
|
||||
Skipped to avoid modifying state in tests.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TestSnapshotWorkflow:
|
||||
"""
|
||||
Test the complete snapshot workflow (create, list, get, remove).
|
||||
These tests are grouped to ensure proper sequencing but are still skipped
|
||||
to avoid modifying state in automated tests.
|
||||
"""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def snapshot_name(self) -> str:
|
||||
"""
|
||||
Generate a unique snapshot name for testing.
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"test_snapshot_{timestamp}"
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying test")
|
||||
def test_create_snapshot(
|
||||
self,
|
||||
api_request: Callable,
|
||||
snapshot_name: str
|
||||
):
|
||||
"""
|
||||
Test creating a snapshot.
|
||||
"""
|
||||
# Make the API request to save a snapshot
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/snapshot/save",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify a snapshot was created (would need to check the snapshot list)
|
||||
response2, json_data = api_request(
|
||||
method="get",
|
||||
path="/snapshot/getlist",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# The most recently created snapshot should be first in the list
|
||||
assert json_data["items"]
|
||||
|
||||
# Store the snapshot name for later tests
|
||||
self.actual_snapshot_name = json_data["items"][0]
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying test")
|
||||
def test_get_snapshot_details(
|
||||
self,
|
||||
api_request: Callable,
|
||||
openapi_spec: Dict
|
||||
):
|
||||
"""
|
||||
Test getting details of the created snapshot.
|
||||
"""
|
||||
# This would check the current snapshot, not a specific one
|
||||
# since there's no direct API to get a specific snapshot
|
||||
response, json_data = api_request(
|
||||
method="get",
|
||||
path="/snapshot/get_current",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Validate the snapshot data
|
||||
assert json_data is not None
|
||||
validate_response(
|
||||
response_data=json_data,
|
||||
path="/snapshot/get_current",
|
||||
method="get",
|
||||
spec=openapi_spec,
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="State-modifying test")
|
||||
def test_remove_test_snapshot(
|
||||
self,
|
||||
api_request: Callable
|
||||
):
|
||||
"""
|
||||
Test removing the test snapshot.
|
||||
"""
|
||||
# Make the API request to remove the snapshot
|
||||
response, _ = api_request(
|
||||
method="get",
|
||||
path="/snapshot/remove",
|
||||
params={"target": self.actual_snapshot_name},
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# Verify the snapshot was removed
|
||||
response2, json_data = api_request(
|
||||
method="get",
|
||||
path="/snapshot/getlist",
|
||||
expected_status=200,
|
||||
)
|
||||
|
||||
# The snapshot should no longer be in the list
|
||||
assert self.actual_snapshot_name not in json_data["items"]
|
||||
150
tests-api/test_spec_validation.py
Normal file
150
tests-api/test_spec_validation.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""
|
||||
Tests for validating the OpenAPI specification.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from pathlib import Path
|
||||
from openapi_spec_validator import validate_spec
|
||||
from utils.validation import load_openapi_spec
|
||||
from utils.schema_utils import (
|
||||
get_all_paths,
|
||||
get_methods_for_path,
|
||||
find_paths_with_security,
|
||||
get_required_parameters
|
||||
)
|
||||
|
||||
|
||||
def test_spec_is_valid():
|
||||
"""
|
||||
Test that the OpenAPI specification is valid according to the spec validator.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
validate_spec(spec)
|
||||
|
||||
|
||||
def test_spec_has_info():
|
||||
"""
|
||||
Test that the OpenAPI specification has basic info.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
|
||||
assert "info" in spec
|
||||
assert "title" in spec["info"]
|
||||
assert "version" in spec["info"]
|
||||
assert spec["info"]["title"] == "ComfyUI-Manager API"
|
||||
|
||||
|
||||
def test_spec_has_paths():
|
||||
"""
|
||||
Test that the OpenAPI specification has paths defined.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
|
||||
assert "paths" in spec
|
||||
assert len(spec["paths"]) > 0
|
||||
|
||||
|
||||
def test_paths_have_responses():
|
||||
"""
|
||||
Test that all paths have responses defined.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
|
||||
for path, path_item in spec["paths"].items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}:
|
||||
continue
|
||||
|
||||
assert "responses" in operation, f"Path {path} method {method} has no responses"
|
||||
assert len(operation["responses"]) > 0, f"Path {path} method {method} has empty responses"
|
||||
|
||||
|
||||
def test_responses_have_schemas():
|
||||
"""
|
||||
Test that responses with application/json content type have schemas.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
|
||||
for path, path_item in spec["paths"].items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}:
|
||||
continue
|
||||
|
||||
for status, response in operation["responses"].items():
|
||||
if "content" not in response:
|
||||
continue
|
||||
|
||||
if "application/json" in response["content"]:
|
||||
assert "schema" in response["content"]["application/json"], (
|
||||
f"Path {path} method {method} status {status} "
|
||||
f"application/json content has no schema"
|
||||
)
|
||||
|
||||
|
||||
def test_required_parameters_have_schemas():
|
||||
"""
|
||||
Test that all required parameters have schemas.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
|
||||
for path, path_item in spec["paths"].items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}:
|
||||
continue
|
||||
|
||||
if "parameters" not in operation:
|
||||
continue
|
||||
|
||||
for param in operation["parameters"]:
|
||||
if param.get("required", False):
|
||||
assert "schema" in param, (
|
||||
f"Path {path} method {method} required parameter {param.get('name')} has no schema"
|
||||
)
|
||||
|
||||
|
||||
def test_security_schemes_defined():
|
||||
"""
|
||||
Test that security schemes are properly defined.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
|
||||
# Get paths requiring security
|
||||
secure_paths = find_paths_with_security(spec)
|
||||
|
||||
if secure_paths:
|
||||
assert "components" in spec, "Spec has secure paths but no components"
|
||||
assert "securitySchemes" in spec["components"], "Spec has secure paths but no securitySchemes"
|
||||
|
||||
# Check each security reference is defined
|
||||
for path, method in secure_paths:
|
||||
operation = spec["paths"][path][method]
|
||||
for security_req in operation["security"]:
|
||||
for scheme_name in security_req:
|
||||
assert scheme_name in spec["components"]["securitySchemes"], (
|
||||
f"Security scheme {scheme_name} used by {method.upper()} {path} "
|
||||
f"is not defined in components.securitySchemes"
|
||||
)
|
||||
|
||||
|
||||
def test_common_endpoint_groups_present():
|
||||
"""
|
||||
Test that the spec includes the main endpoint groups.
|
||||
"""
|
||||
spec = load_openapi_spec()
|
||||
paths = get_all_paths(spec)
|
||||
|
||||
# Define the expected endpoint prefixes
|
||||
expected_prefixes = [
|
||||
"/customnode/",
|
||||
"/externalmodel/",
|
||||
"/manager/",
|
||||
"/snapshot/",
|
||||
"/comfyui_manager/",
|
||||
]
|
||||
|
||||
# Check that at least one path exists for each expected prefix
|
||||
for prefix in expected_prefixes:
|
||||
matching_paths = [p for p in paths if p.startswith(prefix)]
|
||||
assert matching_paths, f"No endpoints found with prefix {prefix}"
|
||||
1
tests-api/utils/__init__.py
Normal file
1
tests-api/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Make utils directory a proper package
|
||||
174
tests-api/utils/schema_utils.py
Normal file
174
tests-api/utils/schema_utils.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""
|
||||
Schema utilities for extracting and manipulating OpenAPI schemas.
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from .validation import load_openapi_spec
|
||||
|
||||
|
||||
def get_all_paths(spec: Dict[str, Any]) -> List[str]:
|
||||
"""
|
||||
Get all paths defined in the OpenAPI specification.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
|
||||
Returns:
|
||||
List of all paths
|
||||
"""
|
||||
return list(spec.get("paths", {}).keys())
|
||||
|
||||
|
||||
def get_grouped_paths(spec: Dict[str, Any]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Group paths by their top-level segment.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
|
||||
Returns:
|
||||
Dictionary mapping top-level segments to lists of paths
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for path in get_all_paths(spec):
|
||||
segments = path.strip("/").split("/")
|
||||
if not segments:
|
||||
continue
|
||||
|
||||
top_segment = segments[0]
|
||||
if top_segment not in result:
|
||||
result[top_segment] = []
|
||||
|
||||
result[top_segment].append(path)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_methods_for_path(spec: Dict[str, Any], path: str) -> List[str]:
|
||||
"""
|
||||
Get all HTTP methods defined for a path.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
path: The API path
|
||||
|
||||
Returns:
|
||||
List of HTTP methods (lowercase)
|
||||
"""
|
||||
if path not in spec.get("paths", {}):
|
||||
return []
|
||||
|
||||
return [
|
||||
method.lower()
|
||||
for method in spec["paths"][path].keys()
|
||||
if method.lower() in {"get", "post", "put", "delete", "patch", "options", "head"}
|
||||
]
|
||||
|
||||
|
||||
def find_paths_with_security(
|
||||
spec: Dict[str, Any],
|
||||
security_scheme: Optional[str] = None
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Find all paths that require security.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
security_scheme: Optional specific security scheme to filter by
|
||||
|
||||
Returns:
|
||||
List of (path, method) tuples that require security
|
||||
"""
|
||||
result = []
|
||||
|
||||
for path, path_item in spec.get("paths", {}).items():
|
||||
for method, operation in path_item.items():
|
||||
if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}:
|
||||
continue
|
||||
|
||||
if "security" in operation:
|
||||
if security_scheme is None:
|
||||
result.append((path, method.lower()))
|
||||
else:
|
||||
# Check if this security scheme is required
|
||||
for security_req in operation["security"]:
|
||||
if security_scheme in security_req:
|
||||
result.append((path, method.lower()))
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_content_types_for_response(
|
||||
spec: Dict[str, Any],
|
||||
path: str,
|
||||
method: str,
|
||||
status_code: str = "200"
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get content types defined for a response.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
path: The API path
|
||||
method: The HTTP method
|
||||
status_code: The HTTP status code
|
||||
|
||||
Returns:
|
||||
List of content types
|
||||
"""
|
||||
method = method.lower()
|
||||
|
||||
if path not in spec["paths"]:
|
||||
return []
|
||||
|
||||
if method not in spec["paths"][path]:
|
||||
return []
|
||||
|
||||
if "responses" not in spec["paths"][path][method]:
|
||||
return []
|
||||
|
||||
if status_code not in spec["paths"][path][method]["responses"]:
|
||||
return []
|
||||
|
||||
response_def = spec["paths"][path][method]["responses"][status_code]
|
||||
|
||||
if "content" not in response_def:
|
||||
return []
|
||||
|
||||
return list(response_def["content"].keys())
|
||||
|
||||
|
||||
def get_required_parameters(
|
||||
spec: Dict[str, Any],
|
||||
path: str,
|
||||
method: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all required parameters for a path/method.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
path: The API path
|
||||
method: The HTTP method
|
||||
|
||||
Returns:
|
||||
List of parameter objects that are required
|
||||
"""
|
||||
method = method.lower()
|
||||
|
||||
if path not in spec["paths"]:
|
||||
return []
|
||||
|
||||
if method not in spec["paths"][path]:
|
||||
return []
|
||||
|
||||
if "parameters" not in spec["paths"][path][method]:
|
||||
return []
|
||||
|
||||
return [
|
||||
param for param in spec["paths"][path][method]["parameters"]
|
||||
if param.get("required", False)
|
||||
]
|
||||
155
tests-api/utils/validation.py
Normal file
155
tests-api/utils/validation.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""
|
||||
Validation utilities for API tests.
|
||||
"""
|
||||
import json
|
||||
import jsonschema
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
|
||||
def load_openapi_spec(spec_path: Union[str, Path] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Load the OpenAPI specification document.
|
||||
|
||||
Args:
|
||||
spec_path: Path to the OpenAPI specification file
|
||||
|
||||
Returns:
|
||||
The OpenAPI specification as a dictionary
|
||||
"""
|
||||
if spec_path is None:
|
||||
# Default to the root openapi.yaml file
|
||||
spec_path = Path(__file__).parents[2] / "openapi.yaml"
|
||||
|
||||
with open(spec_path, "r") as f:
|
||||
if str(spec_path).endswith(".yaml") or str(spec_path).endswith(".yml"):
|
||||
return yaml.safe_load(f)
|
||||
else:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_schema_for_path(
|
||||
spec: Dict[str, Any],
|
||||
path: str,
|
||||
method: str,
|
||||
status_code: str = "200",
|
||||
content_type: str = "application/json"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract the response schema for a specific path, method, and status code.
|
||||
|
||||
Args:
|
||||
spec: The OpenAPI specification
|
||||
path: The API path (e.g., "/customnode/getlist")
|
||||
method: The HTTP method (e.g., "get", "post")
|
||||
status_code: The HTTP status code (default: "200")
|
||||
content_type: The response content type (default: "application/json")
|
||||
|
||||
Returns:
|
||||
The schema for the specified path and method, or None if not found
|
||||
"""
|
||||
method = method.lower()
|
||||
|
||||
if path not in spec["paths"]:
|
||||
return None
|
||||
|
||||
if method not in spec["paths"][path]:
|
||||
return None
|
||||
|
||||
if "responses" not in spec["paths"][path][method]:
|
||||
return None
|
||||
|
||||
if status_code not in spec["paths"][path][method]["responses"]:
|
||||
return None
|
||||
|
||||
response_def = spec["paths"][path][method]["responses"][status_code]
|
||||
|
||||
if "content" not in response_def:
|
||||
return None
|
||||
|
||||
if content_type not in response_def["content"]:
|
||||
return None
|
||||
|
||||
if "schema" not in response_def["content"][content_type]:
|
||||
return None
|
||||
|
||||
return response_def["content"][content_type]["schema"]
|
||||
|
||||
|
||||
def validate_response_schema(
|
||||
response_data: Any,
|
||||
schema: Dict[str, Any],
|
||||
spec: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Validate a response against a schema from the OpenAPI specification.
|
||||
|
||||
Args:
|
||||
response_data: The response data to validate
|
||||
schema: The schema to validate against
|
||||
spec: The complete OpenAPI specification (for resolving references)
|
||||
|
||||
Returns:
|
||||
True if validation succeeds, raises an exception otherwise
|
||||
"""
|
||||
if spec is None:
|
||||
spec = load_openapi_spec()
|
||||
|
||||
# Create a resolver for references within the schema
|
||||
resolver = jsonschema.RefResolver.from_schema(spec)
|
||||
|
||||
# Validate the response against the schema
|
||||
jsonschema.validate(
|
||||
instance=response_data,
|
||||
schema=schema,
|
||||
resolver=resolver
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_response(
|
||||
response_data: Any,
|
||||
path: str,
|
||||
method: str,
|
||||
status_code: str = "200",
|
||||
content_type: str = "application/json",
|
||||
spec: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Validate a response against the schema defined in the OpenAPI specification.
|
||||
|
||||
Args:
|
||||
response_data: The response data to validate
|
||||
path: The API path
|
||||
method: The HTTP method
|
||||
status_code: The HTTP status code (default: "200")
|
||||
content_type: The response content type (default: "application/json")
|
||||
spec: The OpenAPI specification (loaded from default location if None)
|
||||
|
||||
Returns:
|
||||
True if validation succeeds, raises an exception otherwise
|
||||
"""
|
||||
if spec is None:
|
||||
spec = load_openapi_spec()
|
||||
|
||||
schema = get_schema_for_path(
|
||||
spec=spec,
|
||||
path=path,
|
||||
method=method,
|
||||
status_code=status_code,
|
||||
content_type=content_type
|
||||
)
|
||||
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
f"No schema found for {method.upper()} {path} "
|
||||
f"with status {status_code} and content type {content_type}"
|
||||
)
|
||||
|
||||
return validate_response_schema(
|
||||
response_data=response_data,
|
||||
schema=schema,
|
||||
spec=spec
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user