From da87651e539495f186480f27952ed88d260c0676 Mon Sep 17 00:00:00 2001 From: bymyself Date: Tue, 20 May 2025 16:35:40 -0700 Subject: [PATCH] [tests] Add API test suite --- tests-api/.gitignore | 19 ++ tests-api/README.md | 91 ++++++ tests-api/__init__.py | 1 + tests-api/conftest.py | 237 +++++++++++++++ tests-api/mocks/__init__.py | 1 + tests-api/mocks/custom_node_manager.py | 26 ++ tests-api/mocks/patch.py | 116 ++++++++ tests-api/mocks/prompt_server.py | 71 +++++ tests-api/mocks/utils.py | 20 ++ tests-api/openapi.yaml | 382 +++++++++++++++++++++++++ tests-api/requirements-test.txt | 6 + tests-api/test_config_api.py | 270 +++++++++++++++++ tests-api/test_customnode_api.py | 200 +++++++++++++ tests-api/test_import.py | 23 ++ tests-api/test_model_api.py | 62 ++++ tests-api/test_queue_api.py | 213 ++++++++++++++ tests-api/test_snapshot_api.py | 198 +++++++++++++ tests-api/test_spec_validation.py | 150 ++++++++++ tests-api/utils/__init__.py | 1 + tests-api/utils/schema_utils.py | 174 +++++++++++ tests-api/utils/validation.py | 155 ++++++++++ 21 files changed, 2416 insertions(+) create mode 100644 tests-api/.gitignore create mode 100644 tests-api/README.md create mode 100644 tests-api/__init__.py create mode 100644 tests-api/conftest.py create mode 100644 tests-api/mocks/__init__.py create mode 100644 tests-api/mocks/custom_node_manager.py create mode 100644 tests-api/mocks/patch.py create mode 100644 tests-api/mocks/prompt_server.py create mode 100644 tests-api/mocks/utils.py create mode 100644 tests-api/openapi.yaml create mode 100644 tests-api/requirements-test.txt create mode 100644 tests-api/test_config_api.py create mode 100644 tests-api/test_customnode_api.py create mode 100644 tests-api/test_import.py create mode 100644 tests-api/test_model_api.py create mode 100644 tests-api/test_queue_api.py create mode 100644 tests-api/test_snapshot_api.py create mode 100644 tests-api/test_spec_validation.py create mode 100644 tests-api/utils/__init__.py create mode 100644 tests-api/utils/schema_utils.py create mode 100644 tests-api/utils/validation.py diff --git a/tests-api/.gitignore b/tests-api/.gitignore new file mode 100644 index 00000000..8ae3aad9 --- /dev/null +++ b/tests-api/.gitignore @@ -0,0 +1,19 @@ +# Python cache files +__pycache__/ +*.py[cod] +*$py.class + +# Pytest cache +.pytest_cache/ + +# Coverage reports +.coverage +htmlcov/ + +# Virtual environments +venv/ +env/ +ENV/ + +# Test-specific resources +resources/tmp/ \ No newline at end of file diff --git a/tests-api/README.md b/tests-api/README.md new file mode 100644 index 00000000..b55e1800 --- /dev/null +++ b/tests-api/README.md @@ -0,0 +1,91 @@ +# ComfyUI-Manager API Tests + +This directory contains tests for the ComfyUI-Manager API endpoints, validating the OpenAPI specification and ensuring API functionality. + +## Setup + +1. Install test dependencies: + +```bash +pip install -r requirements-test.txt +``` + +2. Ensure ComfyUI is running with ComfyUI-Manager installed: + +```bash +# Start ComfyUI with the default server +python main.py +``` + +## Running Tests + +### Run all tests + +```bash +pytest -xvs +``` + +### Run specific test files + +```bash +# Run only the spec validation tests +pytest -xvs test_spec_validation.py + +# Run only the custom node API tests +pytest -xvs test_customnode_api.py +``` + +### Run specific test functions + +```bash +# Run a specific test +pytest -xvs test_customnode_api.py::test_get_custom_node_list +``` + +## Test Configuration + +The tests use the following default configuration: + +- Server URL: `http://localhost:8188` +- Server timeout: 2 seconds +- Wait between requests: 0.5 seconds +- Maximum retries: 3 + +You can override these settings with environment variables: + +```bash +# Use a different server URL +COMFYUI_SERVER_URL=http://localhost:8189 pytest -xvs +``` + +## Test Categories + +The tests are organized into the following categories: + +1. **Spec Validation** (`test_spec_validation.py`): Validates that the OpenAPI specification is correct and complete. +2. **Custom Node API** (`test_customnode_api.py`): Tests for custom node management endpoints. +3. **Snapshot API** (`test_snapshot_api.py`): Tests for snapshot management endpoints. +4. **Queue API** (`test_queue_api.py`): Tests for queue management endpoints. +5. **Config API** (`test_config_api.py`): Tests for configuration endpoints. +6. **Model API** (`test_model_api.py`): Tests for model management endpoints (minimal as these are being deprecated). + +## Test Implementation Details + +### Fixtures + +- `test_config`: Provides the test configuration +- `server_url`: Returns the server URL from the configuration +- `openapi_spec`: Loads the OpenAPI specification +- `api_client`: Creates a requests Session for API calls +- `api_request`: Helper function for making consistent API requests + +### Utilities + +- `validation.py`: Functions for validating responses against the OpenAPI schema +- `schema_utils.py`: Utilities for extracting and manipulating schemas + +## Notes + +- Some tests are skipped with `@pytest.mark.skip` to avoid modifying state in automated testing +- Security-level restricted endpoints have minimal tests to avoid security issues +- Tests focus on read operations rather than write operations where possible \ No newline at end of file diff --git a/tests-api/__init__.py b/tests-api/__init__.py new file mode 100644 index 00000000..8277b84f --- /dev/null +++ b/tests-api/__init__.py @@ -0,0 +1 @@ +# Make tests-api directory a proper package \ No newline at end of file diff --git a/tests-api/conftest.py b/tests-api/conftest.py new file mode 100644 index 00000000..170d79e8 --- /dev/null +++ b/tests-api/conftest.py @@ -0,0 +1,237 @@ +""" +PyTest configuration and fixtures for API tests. +""" +import os +import sys +import json +import pytest +import requests +import tempfile +import time +import yaml +from pathlib import Path +from typing import Dict, Generator, Optional, Tuple + +# Import test utilities +import sys +import os +from pathlib import Path + +# Get the absolute path to the current file (conftest.py) +current_file = Path(os.path.abspath(__file__)) + +# Get the directory containing the current file (the tests-api directory) +tests_api_dir = current_file.parent + +# Add the tests-api directory to the Python path +if str(tests_api_dir) not in sys.path: + sys.path.insert(0, str(tests_api_dir)) + +# Apply mocks for ComfyUI imports +from mocks.patch import apply_mocks +apply_mocks() + +# Now we can import from utils.validation +from utils.validation import load_openapi_spec + + +# Default test configuration +DEFAULT_TEST_CONFIG = { + "server_url": "http://localhost:8188", + "server_timeout": 2, # seconds + "wait_between_requests": 0.5, # seconds + "max_retries": 3, +} + + +@pytest.fixture(scope="session") +def test_config() -> Dict: + """ + Load test configuration from environment variables or use defaults. + """ + config = DEFAULT_TEST_CONFIG.copy() + + # Override from environment variables if present + if "COMFYUI_SERVER_URL" in os.environ: + config["server_url"] = os.environ["COMFYUI_SERVER_URL"] + + return config + + +@pytest.fixture(scope="session") +def server_url(test_config: Dict) -> str: + """ + Get the server URL from the test configuration. + """ + return test_config["server_url"] + + +@pytest.fixture(scope="session") +def openapi_spec() -> Dict: + """ + Load the OpenAPI specification. + """ + return load_openapi_spec() + + +@pytest.fixture(scope="session") +def api_client(server_url: str, test_config: Dict) -> requests.Session: + """ + Create a requests Session for API calls. + """ + session = requests.Session() + + # Check if the server is running + try: + response = session.get(f"{server_url}/", timeout=test_config["server_timeout"]) + response.raise_for_status() + except (requests.ConnectionError, requests.Timeout, requests.HTTPError): + pytest.skip("ComfyUI server is not running or not accessible") + + return session + + +@pytest.fixture(scope="function") +def temp_dir() -> Generator[Path, None, None]: + """ + Create a temporary directory for test files. + """ + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +class SecurityLevelContext: + """ + Context manager for setting and restoring security levels. + """ + def __init__(self, api_client: requests.Session, server_url: str, security_level: str): + self.api_client = api_client + self.server_url = server_url + self.security_level = security_level + self.original_level = None + + async def __aenter__(self): + # Get the current security level (not directly exposed in API, would require more setup) + # For now, we'll just set the new level + + # Set the new security level + # Note: In a real implementation, we would need a way to set this + # This is a placeholder - the actual implementation would depend on how + # security levels are managed in ComfyUI-Manager + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Restore the original security level if needed + pass + + +@pytest.fixture +def security_level_context(api_client: requests.Session, server_url: str): + """ + Create a context manager for setting security levels. + """ + return lambda level: SecurityLevelContext(api_client, server_url, level) + + +def make_api_url(server_url: str, path: str) -> str: + """ + Construct a full API URL from the server URL and path. + """ + # Ensure the path starts with a slash + if not path.startswith("/"): + path = f"/{path}" + + # Remove trailing slash from server_url if present + if server_url.endswith("/"): + server_url = server_url[:-1] + + return f"{server_url}{path}" + + +@pytest.fixture +def api_request(api_client: requests.Session, server_url: str, test_config: Dict): + """ + Helper function for making API requests with consistent behavior. + """ + def _request( + method: str, + path: str, + params: Optional[Dict] = None, + json_data: Optional[Dict] = None, + headers: Optional[Dict] = None, + expected_status: int = 200, + retry_on_error: bool = True, + ) -> Tuple[requests.Response, Optional[Dict]]: + """ + Make an API request with automatic validation. + + Args: + method: HTTP method + path: API path + params: Query parameters + json_data: JSON request body + headers: HTTP headers + expected_status: Expected HTTP status code + retry_on_error: Whether to retry on connection errors + + Returns: + Tuple of (Response object, JSON response data or None) + """ + method = method.lower() + url = make_api_url(server_url, path) + + if headers is None: + headers = {} + + # Add common headers + headers.setdefault("Accept", "application/json") + + # Sleep between requests to avoid overwhelming the server + time.sleep(test_config["wait_between_requests"]) + + retries = test_config["max_retries"] if retry_on_error else 0 + last_exception = None + + for attempt in range(retries + 1): + try: + if method == "get": + response = api_client.get(url, params=params, headers=headers) + elif method == "post": + response = api_client.post(url, params=params, json=json_data, headers=headers) + elif method == "put": + response = api_client.put(url, params=params, json=json_data, headers=headers) + elif method == "delete": + response = api_client.delete(url, params=params, headers=headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + # Check status code + assert response.status_code == expected_status, ( + f"Expected status code {expected_status}, got {response.status_code}" + ) + + # Parse JSON response if possible + json_response = None + if response.headers.get("Content-Type", "").startswith("application/json"): + try: + json_response = response.json() + except json.JSONDecodeError: + if expected_status == 200: + raise ValueError("Response was not valid JSON") + + return response, json_response + + except (requests.ConnectionError, requests.Timeout) as e: + last_exception = e + if attempt < retries: + # Wait before retrying + time.sleep(1) + continue + break + + if last_exception: + raise last_exception + + raise RuntimeError("Failed to make API request") + + return _request \ No newline at end of file diff --git a/tests-api/mocks/__init__.py b/tests-api/mocks/__init__.py new file mode 100644 index 00000000..8cd8805c --- /dev/null +++ b/tests-api/mocks/__init__.py @@ -0,0 +1 @@ +# Make tests-api/mocks directory a proper package \ No newline at end of file diff --git a/tests-api/mocks/custom_node_manager.py b/tests-api/mocks/custom_node_manager.py new file mode 100644 index 00000000..c0d8aa77 --- /dev/null +++ b/tests-api/mocks/custom_node_manager.py @@ -0,0 +1,26 @@ +""" +Mock CustomNodeManager for testing purposes +""" + +class CustomNodeManager: + """ + Mock implementation of the CustomNodeManager class + """ + instance = None + + def __init__(self): + self.custom_nodes = {} + self.node_paths = [] + self.refresh_timeout = None + + def get_node_path(self, node_class): + """ + Mock implementation to get the path for a node class + """ + return self.custom_nodes.get(node_class, None) + + def update_node_paths(self): + """ + Mock implementation to update node paths + """ + pass \ No newline at end of file diff --git a/tests-api/mocks/patch.py b/tests-api/mocks/patch.py new file mode 100644 index 00000000..91d87812 --- /dev/null +++ b/tests-api/mocks/patch.py @@ -0,0 +1,116 @@ +""" +Patch module to mock imports for testing +""" +import sys +import importlib.util +import os +from pathlib import Path + +# Import mock modules +from mocks.prompt_server import PromptServer +from mocks.custom_node_manager import CustomNodeManager + +# Current directory +current_dir = Path(__file__).parent.parent # tests-api directory + +# Define mocks +class MockModule: + """Base class for mock modules""" + pass + +# Create server mock module with PromptServer +server_mock = MockModule() +server_mock.PromptServer = PromptServer +prompt_server_instance = PromptServer() +server_mock.PromptServer.instance = prompt_server_instance +server_mock.PromptServer.inst = prompt_server_instance + +# Create app mock module with custom_node_manager submodule +app_mock = MockModule() +app_custom_node_manager = MockModule() +app_custom_node_manager.CustomNodeManager = CustomNodeManager +app_custom_node_manager.CustomNodeManager.instance = CustomNodeManager() + +# Create utils mock module with json_util submodule +utils_mock = MockModule() +utils_json_util = MockModule() + +# Create utils.validation and utils.schema_utils submodules +utils_validation = MockModule() +utils_schema_utils = MockModule() + +# Import actual modules (make sure path is set up correctly) +sys.path.insert(0, str(current_dir)) + +try: + # Import the validation module + from utils.validation import load_openapi_spec + utils_validation.load_openapi_spec = load_openapi_spec + + # Import all schema_utils functions + from utils.schema_utils import ( + get_all_paths, + get_grouped_paths, + get_methods_for_path, + find_paths_with_security, + get_content_types_for_response, + get_required_parameters + ) + + utils_schema_utils.get_all_paths = get_all_paths + utils_schema_utils.get_grouped_paths = get_grouped_paths + utils_schema_utils.get_methods_for_path = get_methods_for_path + utils_schema_utils.find_paths_with_security = find_paths_with_security + utils_schema_utils.get_content_types_for_response = get_content_types_for_response + utils_schema_utils.get_required_parameters = get_required_parameters + +except ImportError as e: + print(f"Error importing test utilities: {e}") + # Define dummy functions if imports fail + def dummy_load_openapi_spec(): + """Dummy function for testing""" + return {"paths": {}} + utils_validation.load_openapi_spec = dummy_load_openapi_spec + + def dummy_get_all_paths(spec): + return list(spec.get("paths", {}).keys()) + utils_schema_utils.get_all_paths = dummy_get_all_paths + + def dummy_get_grouped_paths(spec): + return {} + utils_schema_utils.get_grouped_paths = dummy_get_grouped_paths + + def dummy_get_methods_for_path(spec, path): + return [] + utils_schema_utils.get_methods_for_path = dummy_get_methods_for_path + + def dummy_find_paths_with_security(spec, security_scheme=None): + return [] + utils_schema_utils.find_paths_with_security = dummy_find_paths_with_security + + def dummy_get_content_types_for_response(spec, path, method, status_code="200"): + return [] + utils_schema_utils.get_content_types_for_response = dummy_get_content_types_for_response + + def dummy_get_required_parameters(spec, path, method): + return [] + utils_schema_utils.get_required_parameters = dummy_get_required_parameters + +# Add merge_json_recursive from our mock utils +from mocks.utils import merge_json_recursive +utils_json_util.merge_json_recursive = merge_json_recursive + +# Apply the mocks to sys.modules +def apply_mocks(): + """Apply all mocks to sys.modules""" + sys.modules['server'] = server_mock + sys.modules['app'] = app_mock + sys.modules['app.custom_node_manager'] = app_custom_node_manager + sys.modules['utils'] = utils_mock + sys.modules['utils.json_util'] = utils_json_util + sys.modules['utils.validation'] = utils_validation + sys.modules['utils.schema_utils'] = utils_schema_utils + + # Make sure our actual utils module is importable + if current_dir not in sys.path: + sys.path.insert(0, str(current_dir)) \ No newline at end of file diff --git a/tests-api/mocks/prompt_server.py b/tests-api/mocks/prompt_server.py new file mode 100644 index 00000000..3276b36f --- /dev/null +++ b/tests-api/mocks/prompt_server.py @@ -0,0 +1,71 @@ +""" +Mock PromptServer for testing purposes +""" + +class MockRoutes: + """ + Mock routing class with method decorators + """ + def __init__(self): + self.routes = {} + + def get(self, path): + """Decorator for GET routes""" + def decorator(f): + self.routes[('GET', path)] = f + return f + return decorator + + def post(self, path): + """Decorator for POST routes""" + def decorator(f): + self.routes[('POST', path)] = f + return f + return decorator + + def put(self, path): + """Decorator for PUT routes""" + def decorator(f): + self.routes[('PUT', path)] = f + return f + return decorator + + def delete(self, path): + """Decorator for DELETE routes""" + def decorator(f): + self.routes[('DELETE', path)] = f + return f + return decorator + + +class PromptServer: + """ + Mock implementation of the PromptServer class + """ + instance = None + inst = None + + def __init__(self): + self.routes = MockRoutes() + self.registered_paths = set() + self.base_url = "http://127.0.0.1:8188" # Assuming server is running on default port + self.queue_lock = None + + def add_route(self, method, path, handler, *args, **kwargs): + """ + Add a mock route to the server + """ + self.routes.routes[(method.upper(), path)] = handler + self.registered_paths.add(path) + + async def send_msg(self, message, data=None): + """ + Mock send_msg method (does nothing in the mock) + """ + pass + + def send_sync(self, message, data=None): + """ + Mock send_sync method (does nothing in the mock) + """ + pass \ No newline at end of file diff --git a/tests-api/mocks/utils.py b/tests-api/mocks/utils.py new file mode 100644 index 00000000..52614f8e --- /dev/null +++ b/tests-api/mocks/utils.py @@ -0,0 +1,20 @@ +""" +Mock utils module for testing purposes +""" + +def merge_json_recursive(a, b): + """ + Mock implementation of merge_json_recursive + """ + if isinstance(a, dict) and isinstance(b, dict): + result = a.copy() + for key, value in b.items(): + if key in result and isinstance(result[key], (dict, list)) and isinstance(value, (dict, list)): + result[key] = merge_json_recursive(result[key], value) + else: + result[key] = value + return result + elif isinstance(a, list) and isinstance(b, list): + return a + b + else: + return b \ No newline at end of file diff --git a/tests-api/openapi.yaml b/tests-api/openapi.yaml new file mode 100644 index 00000000..e16f9397 --- /dev/null +++ b/tests-api/openapi.yaml @@ -0,0 +1,382 @@ +openapi: 3.0.3 +info: + title: ComfyUI-Manager API + description: API for managing ComfyUI extensions, custom nodes, and models + version: 1.0.0 + contact: + name: ComfyUI Community + url: https://github.com/comfyanonymous/ComfyUI + +servers: + - url: http://localhost:8188 + description: Local ComfyUI server + +paths: + /customnode/getlist: + get: + summary: Get the list of custom nodes + description: Returns the list of custom nodes from all configured channels + parameters: + - name: mode + in: query + description: "The mode to retrieve (local=installed nodes, remote=available nodes)" + schema: + type: string + enum: [local, remote] + default: remote + responses: + '200': + description: List of custom nodes + content: + application/json: + schema: + type: object + properties: + nodes: + type: array + items: + $ref: '#/components/schemas/CustomNode' + '500': + description: Server error + + /customnode/get_node_mappings: + get: + summary: Get mappings between node class names and their custom nodes + description: Returns mappings that help identify which custom node package provides specific node classes + parameters: + - name: mode + in: query + description: "The mode for mappings (local=installed nodes, nickname=node nicknames)" + schema: + type: string + enum: [local, nickname] + default: local + required: true + responses: + '200': + description: Node mappings + content: + application/json: + schema: + type: object + additionalProperties: + type: string + '500': + description: Server error + + /customnode/get_node_alternatives: + get: + summary: Get alternative nodes for specific node classes + description: Returns alternative implementations of node classes from different custom node packages + parameters: + - name: mode + in: query + description: "The mode to retrieve alternatives (local=installed nodes, remote=all available nodes)" + schema: + type: string + enum: [local, remote] + default: remote + responses: + '200': + description: Node alternatives + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: + type: string + '500': + description: Server error + + /externalmodel/getlist: + get: + summary: Get the list of external models + description: Returns the list of models from all configured channels + parameters: + - name: mode + in: query + description: "The mode to retrieve (local=installed models, remote=available models)" + schema: + type: string + enum: [local, remote] + default: remote + responses: + '200': + description: List of external models + content: + application/json: + schema: + type: object + properties: + models: + type: array + items: + $ref: '#/components/schemas/ExternalModel' + '500': + description: Server error + + /manager/get_config: + get: + summary: Get manager configuration + description: Returns the current configuration of ComfyUI-Manager + parameters: + - name: key + in: query + description: "The configuration key to retrieve" + schema: + type: string + required: true + responses: + '200': + description: Configuration value + content: + application/json: + schema: + type: object + properties: + value: + type: string + '400': + description: Invalid key or missing parameter + '500': + description: Server error + + /manager/set_config: + post: + summary: Set manager configuration + description: Updates the configuration of ComfyUI-Manager + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - key + - value + properties: + key: + type: string + description: "The configuration key to update" + value: + type: string + description: "The new value for the configuration key" + responses: + '200': + description: Configuration updated successfully + content: + application/json: + schema: + type: object + properties: + success: + type: boolean + '400': + description: Invalid key or value + '500': + description: Server error + + /snapshot/getlist: + get: + summary: Get the list of snapshots + description: Returns the list of saved snapshots + responses: + '200': + description: List of snapshots + content: + application/json: + schema: + type: object + properties: + snapshots: + type: array + items: + $ref: '#/components/schemas/Snapshot' + '500': + description: Server error + + /comfyui_manager/queue/status: + get: + summary: Get queue status + description: Returns the current status of the operation queue + responses: + '200': + description: Queue status + content: + application/json: + schema: + $ref: '#/components/schemas/QueueStatus' + '500': + description: Server error + +components: + schemas: + CustomNode: + type: object + required: + - name + - title + - reference + properties: + name: + type: string + description: "Internal name/ID of the custom node" + title: + type: string + description: "Display title of the custom node" + reference: + type: string + description: "Reference URL (usually GitHub repository URL)" + description: + type: string + description: "Description of what the custom node does" + install_type: + type: string + enum: [git, pip, copy] + description: "Installation method for the custom node" + files: + type: array + items: + type: string + description: "List of files provided by this custom node" + node_class_names: + type: array + items: + type: string + description: "List of node class names provided by this custom node" + installed: + type: boolean + description: "Whether the custom node is installed" + version: + type: string + description: "Version of the custom node" + tags: + type: array + items: + type: string + description: "Tags associated with the custom node" + + ExternalModel: + type: object + required: + - name + - type + - url + properties: + name: + type: string + description: "Name of the model" + type: + type: string + description: "Type of the model (checkpoint, lora, embedding, etc.)" + url: + type: string + description: "Download URL for the model" + description: + type: string + description: "Description of the model" + size: + type: integer + description: "Size of the model in bytes" + installed: + type: boolean + description: "Whether the model is installed" + version: + type: string + description: "Version of the model" + tags: + type: array + items: + type: string + description: "Tags associated with the model" + + Snapshot: + type: object + required: + - name + - date + properties: + name: + type: string + description: "Name of the snapshot" + date: + type: string + format: date-time + description: "Date when the snapshot was created" + description: + type: string + description: "Description of the snapshot" + nodes: + type: array + items: + type: string + description: "List of custom nodes in the snapshot" + models: + type: array + items: + type: string + description: "List of models in the snapshot" + + QueueStatus: + type: object + properties: + pending: + type: array + items: + $ref: '#/components/schemas/QueueItem' + description: "List of pending operations in the queue" + completed: + type: array + items: + $ref: '#/components/schemas/QueueItem' + description: "List of completed operations in the queue" + failed: + type: array + items: + $ref: '#/components/schemas/QueueItem' + description: "List of failed operations in the queue" + running: + type: boolean + description: "Whether the queue is currently running" + + QueueItem: + type: object + required: + - id + - type + - target + properties: + id: + type: string + description: "Unique ID of the queue item" + type: + type: string + enum: [install, update, uninstall] + description: "Type of operation" + target: + type: string + description: "Target of the operation (e.g., custom node name, model name)" + status: + type: string + enum: [pending, processing, completed, failed] + description: "Current status of the operation" + error: + type: string + description: "Error message if the operation failed" + created_at: + type: string + format: date-time + description: "Time when the operation was added to the queue" + completed_at: + type: string + format: date-time + description: "Time when the operation was completed" + + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key + description: "API key for authentication" \ No newline at end of file diff --git a/tests-api/requirements-test.txt b/tests-api/requirements-test.txt new file mode 100644 index 00000000..f596c26c --- /dev/null +++ b/tests-api/requirements-test.txt @@ -0,0 +1,6 @@ +pytest>=7.3.1 +requests>=2.31.0 +openapi-spec-validator>=0.6.0 +jsonschema>=4.17.3 +pytest-asyncio>=0.21.0 +pyyaml>=6.0 \ No newline at end of file diff --git a/tests-api/test_config_api.py b/tests-api/test_config_api.py new file mode 100644 index 00000000..3ee2d4f9 --- /dev/null +++ b/tests-api/test_config_api.py @@ -0,0 +1,270 @@ +""" +Tests for configuration endpoints. +""" +import pytest +from typing import Callable, Dict, List, Tuple + +from utils.validation import validate_response + + +def test_get_preview_method( + api_request: Callable +): + """ + Test getting the current preview method. + """ + # Make the API request + path = "/manager/preview_method" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Verify the response is one of the valid preview methods + assert response.text in ["auto", "latent2rgb", "taesd", "none"] + + +def test_get_db_mode( + api_request: Callable +): + """ + Test getting the current database mode. + """ + # Make the API request + path = "/manager/db_mode" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Verify the response is one of the valid database modes + assert response.text in ["channel", "local", "remote"] + + +def test_get_component_policy( + api_request: Callable +): + """ + Test getting the current component policy. + """ + # Make the API request + path = "/manager/policy/component" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Component policy could be any string + assert response.text is not None + + +def test_get_update_policy( + api_request: Callable +): + """ + Test getting the current update policy. + """ + # Make the API request + path = "/manager/policy/update" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Verify the response is one of the valid update policies + assert response.text in ["stable", "nightly", "nightly-comfyui"] + + +def test_get_channel_url_list( + api_request: Callable, + openapi_spec: Dict +): + """ + Test getting the channel URL list. + """ + # Make the API request + path = "/manager/channel_url_list" + response, json_data = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response contains the expected fields + assert "selected" in json_data + assert "list" in json_data + assert isinstance(json_data["list"], list) + + # Each channel should have a name and URL + if json_data["list"]: + first_channel = json_data["list"][0] + assert "name" in first_channel + assert "url" in first_channel + + +def test_get_manager_version( + api_request: Callable +): + """ + Test getting the manager version. + """ + # Make the API request + path = "/manager/version" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Verify the response is a version string + assert response.text.startswith("V") # Version strings start with V + + +def test_get_manager_notice( + api_request: Callable +): + """ + Test getting the manager notice. + """ + # Make the API request + path = "/manager/notice" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Verify the response is HTML content + assert response.headers.get("Content-Type", "").startswith("text/html") or "ComfyUI" in response.text + + +@pytest.mark.skip(reason="State-modifying operations") +class TestConfigChanges: + """ + Tests for changing configuration settings. + These are skipped to avoid modifying state in automated tests. + """ + + @pytest.fixture(scope="class", autouse=True) + def save_original_config(self, api_request: Callable): + """ + Save the original configuration to restore after tests. + """ + # Save original values + response, _ = api_request( + method="get", + path="/manager/preview_method", + expected_status=200, + ) + self.original_preview_method = response.text + + response, _ = api_request( + method="get", + path="/manager/db_mode", + expected_status=200, + ) + self.original_db_mode = response.text + + response, _ = api_request( + method="get", + path="/manager/policy/update", + expected_status=200, + ) + self.original_update_policy = response.text + + yield + + # Restore original values + api_request( + method="get", + path="/manager/preview_method", + params={"value": self.original_preview_method}, + expected_status=200, + ) + + api_request( + method="get", + path="/manager/db_mode", + params={"value": self.original_db_mode}, + expected_status=200, + ) + + api_request( + method="get", + path="/manager/policy/update", + params={"value": self.original_update_policy}, + expected_status=200, + ) + + def test_set_preview_method(self, api_request: Callable): + """ + Test setting the preview method. + """ + # Set to a different value (taesd) + api_request( + method="get", + path="/manager/preview_method", + params={"value": "taesd"}, + expected_status=200, + ) + + # Verify it was changed + response, _ = api_request( + method="get", + path="/manager/preview_method", + expected_status=200, + ) + assert response.text == "taesd" + + def test_set_db_mode(self, api_request: Callable): + """ + Test setting the database mode. + """ + # Set to local mode + api_request( + method="get", + path="/manager/db_mode", + params={"value": "local"}, + expected_status=200, + ) + + # Verify it was changed + response, _ = api_request( + method="get", + path="/manager/db_mode", + expected_status=200, + ) + assert response.text == "local" + + def test_set_update_policy(self, api_request: Callable): + """ + Test setting the update policy. + """ + # Set to stable + api_request( + method="get", + path="/manager/policy/update", + params={"value": "stable"}, + expected_status=200, + ) + + # Verify it was changed + response, _ = api_request( + method="get", + path="/manager/policy/update", + expected_status=200, + ) + assert response.text == "stable" \ No newline at end of file diff --git a/tests-api/test_customnode_api.py b/tests-api/test_customnode_api.py new file mode 100644 index 00000000..756a4b5b --- /dev/null +++ b/tests-api/test_customnode_api.py @@ -0,0 +1,200 @@ +""" +Tests for custom node management endpoints. +""" +import pytest +from pathlib import Path +from typing import Callable, Dict, Tuple + +from utils.validation import validate_response + + +@pytest.mark.parametrize( + "mode", + ["local", "remote"] +) +def test_get_custom_node_list( + api_request: Callable, + openapi_spec: Dict, + mode: str +): + """ + Test the endpoint for listing custom nodes. + """ + # Make the API request + path = "/customnode/getlist" + response, json_data = api_request( + method="get", + path=path, + params={"mode": mode, "skip_update": "true"}, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response contains the expected fields + assert "channel" in json_data + assert "node_packs" in json_data + assert isinstance(json_data["node_packs"], dict) + + # If there are any node packs, verify they have the expected structure + if json_data["node_packs"]: + # Take the first node pack to validate + first_node_pack = next(iter(json_data["node_packs"].values())) + assert "title" in first_node_pack + assert "name" in first_node_pack + + +def test_get_installed_nodes( + api_request: Callable, + openapi_spec: Dict +): + """ + Test the endpoint for listing installed nodes. + """ + # Make the API request + path = "/customnode/installed" + response, json_data = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response is a dictionary of node packs + assert isinstance(json_data, dict) + + +@pytest.mark.parametrize( + "mode", + ["local", "nickname"] +) +def test_get_node_mappings( + api_request: Callable, + openapi_spec: Dict, + mode: str +): + """ + Test the endpoint for getting node-to-package mappings. + """ + # Make the API request + path = "/customnode/getmappings" + response, json_data = api_request( + method="get", + path=path, + params={"mode": mode}, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response is a dictionary mapping extension IDs to node info + assert isinstance(json_data, dict) + + # If there are any mappings, verify they have the expected structure + if json_data: + # Take the first mapping to validate + first_mapping = next(iter(json_data.values())) + assert isinstance(first_mapping, list) + assert len(first_mapping) == 2 + assert isinstance(first_mapping[0], list) # List of node classes + assert isinstance(first_mapping[1], dict) # Metadata + + +@pytest.mark.parametrize( + "mode", + ["local", "remote"] +) +def test_get_node_alternatives( + api_request: Callable, + openapi_spec: Dict, + mode: str +): + """ + Test the endpoint for getting alternative node options. + """ + # Make the API request + path = "/customnode/alternatives" + response, json_data = api_request( + method="get", + path=path, + params={"mode": mode}, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response is a dictionary + assert isinstance(json_data, dict) + + +def test_fetch_updates( + api_request: Callable +): + """ + Test the endpoint for fetching updates. + This might modify state, so we just check for a valid response. + """ + # Make the API request with skip_update=true to avoid actual updates + path = "/customnode/fetch_updates" + response, _ = api_request( + method="get", + path=path, + params={"mode": "local"}, + # Don't validate JSON since this endpoint doesn't return JSON + expected_status=200, + retry_on_error=False, # Don't retry as this might have side effects + ) + + # Just check the status code is as expected (covered by api_request) + assert response.status_code in [200, 201] + + +@pytest.mark.skip(reason="Queue endpoints are better tested with queue operations") +def test_queue_update_all( + api_request: Callable +): + """ + Test the endpoint for queuing updates for all nodes. + Skipping as this would actually modify the installation. + """ + pass + + +@pytest.mark.skip(reason="Security-restricted endpoint") +def test_install_node_via_git_url( + api_request: Callable +): + """ + Test the endpoint for installing a node via Git URL. + Skipping as this requires high security level and would modify the installation. + """ + pass \ No newline at end of file diff --git a/tests-api/test_import.py b/tests-api/test_import.py new file mode 100644 index 00000000..a5976bf2 --- /dev/null +++ b/tests-api/test_import.py @@ -0,0 +1,23 @@ +import os +import sys + +# Print current working directory +print(f"Current directory: {os.getcwd()}") + +# Print module search path +print(f"System path: {sys.path}") + +# Try to import +try: + from utils.validation import load_openapi_spec + print("Import successful!") +except ImportError as e: + print(f"Import error: {e}") + + # Try direct import + try: + sys.path.insert(0, os.path.join(os.getcwd(), "custom_nodes/ComfyUI-Manager/tests-api")) + from utils.validation import load_openapi_spec + print("Direct import successful!") + except ImportError as e: + print(f"Direct import error: {e}") \ No newline at end of file diff --git a/tests-api/test_model_api.py b/tests-api/test_model_api.py new file mode 100644 index 00000000..c1a88033 --- /dev/null +++ b/tests-api/test_model_api.py @@ -0,0 +1,62 @@ +""" +Tests for model management endpoints. +These features are scheduled for deprecation, so tests are minimal. +""" +import pytest +from typing import Callable, Dict + +from utils.validation import validate_response + + +@pytest.mark.parametrize( + "mode", + ["local", "remote"] +) +def test_get_external_model_list( + api_request: Callable, + openapi_spec: Dict, + mode: str +): + """ + Test the endpoint for listing external models. + """ + # Make the API request + path = "/externalmodel/getlist" + response, json_data = api_request( + method="get", + path=path, + params={"mode": mode}, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response contains the expected fields + assert "models" in json_data + assert isinstance(json_data["models"], list) + + # If there are any models, verify they have the expected structure + if json_data["models"]: + first_model = json_data["models"][0] + assert "name" in first_model + assert "type" in first_model + assert "url" in first_model + assert "filename" in first_model + assert "installed" in first_model + + +@pytest.mark.skip(reason="State-modifying operation that requires auth") +def test_install_model(): + """ + Test queuing a model installation. + Skipped to avoid modifying state and requires authentication. + This feature is also scheduled for deprecation. + """ + pass \ No newline at end of file diff --git a/tests-api/test_queue_api.py b/tests-api/test_queue_api.py new file mode 100644 index 00000000..e26f5231 --- /dev/null +++ b/tests-api/test_queue_api.py @@ -0,0 +1,213 @@ +""" +Tests for queue management endpoints. +""" +import pytest +import time +from pathlib import Path +from typing import Callable, Dict, Tuple + +from utils.validation import validate_response + + +def test_get_queue_status( + api_request: Callable, + openapi_spec: Dict +): + """ + Test the endpoint for getting queue status. + """ + # Make the API request + path = "/manager/queue/status" + response, json_data = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response contains the expected fields + assert "total_count" in json_data + assert "done_count" in json_data + assert "in_progress_count" in json_data + assert "is_processing" in json_data + + # Type checks + assert isinstance(json_data["total_count"], int) + assert isinstance(json_data["done_count"], int) + assert isinstance(json_data["in_progress_count"], int) + assert isinstance(json_data["is_processing"], bool) + + +def test_reset_queue( + api_request: Callable +): + """ + Test the endpoint for resetting the queue. + """ + # Make the API request + path = "/manager/queue/reset" + response, _ = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Now check the queue status to verify it was reset + response2, json_data = api_request( + method="get", + path="/manager/queue/status", + expected_status=200, + ) + + # Queue should be empty after reset + assert json_data["total_count"] == json_data["done_count"] + json_data["in_progress_count"] + + +@pytest.mark.skip(reason="State-modifying operation that requires auth") +def test_queue_install_node(): + """ + Test queuing a node installation. + Skipped to avoid modifying state and requires authentication. + """ + pass + + +@pytest.mark.skip(reason="State-modifying operation that requires auth") +def test_queue_update_node(): + """ + Test queuing a node update. + Skipped to avoid modifying state and requires authentication. + """ + pass + + +@pytest.mark.skip(reason="State-modifying operation that requires auth") +def test_queue_uninstall_node(): + """ + Test queuing a node uninstallation. + Skipped to avoid modifying state and requires authentication. + """ + pass + + +@pytest.mark.skip(reason="State-modifying operation") +def test_queue_start(): + """ + Test starting the queue. + Skipped to avoid modifying state. + """ + pass + + +class TestQueueOperations: + """ + Test a complete queue workflow. + These tests are grouped to ensure proper sequencing but are still skipped + to avoid modifying state in automated tests. + """ + + @pytest.fixture(scope="class") + def node_data(self) -> Dict: + """ + Create test data for a node operation. + """ + # This would be replaced with actual data for a known safe node + return { + "ui_id": "test_node_1", + "id": "comfyui-manager", # Manager itself + "version": "latest", + "channel": "default", + "mode": "local", + } + + @pytest.mark.skip(reason="State-modifying operation") + def test_queue_operation_sequence( + self, + api_request: Callable, + node_data: Dict + ): + """ + Test the queue operation sequence. + """ + # 1. Reset the queue + api_request( + method="get", + path="/manager/queue/reset", + expected_status=200, + ) + + # 2. Queue a node operation (we'll use the manager itself) + api_request( + method="post", + path="/manager/queue/update", + json_data=node_data, + expected_status=200, + ) + + # 3. Check queue status - should have one operation + response, json_data = api_request( + method="get", + path="/manager/queue/status", + expected_status=200, + ) + + assert json_data["total_count"] > 0 + assert not json_data["is_processing"] # Queue hasn't started yet + + # 4. Start the queue + api_request( + method="get", + path="/manager/queue/start", + expected_status=200, + ) + + # 5. Check queue status again - should be processing + response, json_data = api_request( + method="get", + path="/manager/queue/status", + expected_status=200, + ) + + # Queue should be processing or already done + assert json_data["is_processing"] or json_data["done_count"] == json_data["total_count"] + + # 6. Wait for queue to complete (with timeout) + max_wait_time = 60 # seconds + start_time = time.time() + completed = False + + while time.time() - start_time < max_wait_time: + response, json_data = api_request( + method="get", + path="/manager/queue/status", + expected_status=200, + ) + + if json_data["done_count"] == json_data["total_count"] and not json_data["is_processing"]: + completed = True + break + + time.sleep(2) # Wait before checking again + + assert completed, "Queue did not complete within timeout period" + + @pytest.mark.skip(reason="State-modifying operation") + def test_concurrent_queue_operations( + self, + api_request: Callable, + node_data: Dict + ): + """ + Test concurrent queue operations. + """ + # This would test adding multiple operations to the queue + # and verifying they all complete correctly + pass \ No newline at end of file diff --git a/tests-api/test_snapshot_api.py b/tests-api/test_snapshot_api.py new file mode 100644 index 00000000..60a1159c --- /dev/null +++ b/tests-api/test_snapshot_api.py @@ -0,0 +1,198 @@ +""" +Tests for snapshot management endpoints. +""" +import pytest +import time +from datetime import datetime +from pathlib import Path +from typing import Callable, Dict, List, Optional + +from utils.validation import validate_response + + +def test_get_snapshot_list( + api_request: Callable, + openapi_spec: Dict +): + """ + Test the endpoint for listing snapshots. + """ + # Make the API request + path = "/snapshot/getlist" + response, json_data = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Verify the response contains the expected fields + assert "items" in json_data + assert isinstance(json_data["items"], list) + + +def test_get_current_snapshot( + api_request: Callable, + openapi_spec: Dict +): + """ + Test the endpoint for getting the current snapshot. + """ + # Make the API request + path = "/snapshot/get_current" + response, json_data = api_request( + method="get", + path=path, + expected_status=200, + ) + + # Validate response structure against the schema + assert json_data is not None + validate_response( + response_data=json_data, + path=path, + method="get", + spec=openapi_spec, + ) + + # Check for basic snapshot structure + assert "snapshot_date" in json_data + assert "custom_nodes" in json_data + + +@pytest.mark.skip(reason="This test creates a snapshot which is a state-modifying operation") +def test_save_snapshot( + api_request: Callable +): + """ + Test the endpoint for saving a new snapshot. + Skipped to avoid modifying state in tests. + """ + pass + + +@pytest.mark.skip(reason="This test removes a snapshot which is a destructive operation") +def test_remove_snapshot( + api_request: Callable +): + """ + Test the endpoint for removing a snapshot. + Skipped to avoid modifying state in tests. + """ + pass + + +@pytest.mark.skip(reason="This test restores a snapshot which is a state-modifying operation") +def test_restore_snapshot( + api_request: Callable +): + """ + Test the endpoint for restoring a snapshot. + Skipped to avoid modifying state in tests. + """ + pass + + +class TestSnapshotWorkflow: + """ + Test the complete snapshot workflow (create, list, get, remove). + These tests are grouped to ensure proper sequencing but are still skipped + to avoid modifying state in automated tests. + """ + + @pytest.fixture(scope="class") + def snapshot_name(self) -> str: + """ + Generate a unique snapshot name for testing. + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return f"test_snapshot_{timestamp}" + + @pytest.mark.skip(reason="State-modifying test") + def test_create_snapshot( + self, + api_request: Callable, + snapshot_name: str + ): + """ + Test creating a snapshot. + """ + # Make the API request to save a snapshot + response, _ = api_request( + method="get", + path="/snapshot/save", + expected_status=200, + ) + + # Verify a snapshot was created (would need to check the snapshot list) + response2, json_data = api_request( + method="get", + path="/snapshot/getlist", + expected_status=200, + ) + + # The most recently created snapshot should be first in the list + assert json_data["items"] + + # Store the snapshot name for later tests + self.actual_snapshot_name = json_data["items"][0] + + @pytest.mark.skip(reason="State-modifying test") + def test_get_snapshot_details( + self, + api_request: Callable, + openapi_spec: Dict + ): + """ + Test getting details of the created snapshot. + """ + # This would check the current snapshot, not a specific one + # since there's no direct API to get a specific snapshot + response, json_data = api_request( + method="get", + path="/snapshot/get_current", + expected_status=200, + ) + + # Validate the snapshot data + assert json_data is not None + validate_response( + response_data=json_data, + path="/snapshot/get_current", + method="get", + spec=openapi_spec, + ) + + @pytest.mark.skip(reason="State-modifying test") + def test_remove_test_snapshot( + self, + api_request: Callable + ): + """ + Test removing the test snapshot. + """ + # Make the API request to remove the snapshot + response, _ = api_request( + method="get", + path="/snapshot/remove", + params={"target": self.actual_snapshot_name}, + expected_status=200, + ) + + # Verify the snapshot was removed + response2, json_data = api_request( + method="get", + path="/snapshot/getlist", + expected_status=200, + ) + + # The snapshot should no longer be in the list + assert self.actual_snapshot_name not in json_data["items"] \ No newline at end of file diff --git a/tests-api/test_spec_validation.py b/tests-api/test_spec_validation.py new file mode 100644 index 00000000..55f7a9b8 --- /dev/null +++ b/tests-api/test_spec_validation.py @@ -0,0 +1,150 @@ +""" +Tests for validating the OpenAPI specification. +""" +import json +import pytest +import yaml +from typing import Dict, Any, List, Tuple +from pathlib import Path +from openapi_spec_validator import validate_spec +from utils.validation import load_openapi_spec +from utils.schema_utils import ( + get_all_paths, + get_methods_for_path, + find_paths_with_security, + get_required_parameters +) + + +def test_spec_is_valid(): + """ + Test that the OpenAPI specification is valid according to the spec validator. + """ + spec = load_openapi_spec() + validate_spec(spec) + + +def test_spec_has_info(): + """ + Test that the OpenAPI specification has basic info. + """ + spec = load_openapi_spec() + + assert "info" in spec + assert "title" in spec["info"] + assert "version" in spec["info"] + assert spec["info"]["title"] == "ComfyUI-Manager API" + + +def test_spec_has_paths(): + """ + Test that the OpenAPI specification has paths defined. + """ + spec = load_openapi_spec() + + assert "paths" in spec + assert len(spec["paths"]) > 0 + + +def test_paths_have_responses(): + """ + Test that all paths have responses defined. + """ + spec = load_openapi_spec() + + for path, path_item in spec["paths"].items(): + for method, operation in path_item.items(): + if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}: + continue + + assert "responses" in operation, f"Path {path} method {method} has no responses" + assert len(operation["responses"]) > 0, f"Path {path} method {method} has empty responses" + + +def test_responses_have_schemas(): + """ + Test that responses with application/json content type have schemas. + """ + spec = load_openapi_spec() + + for path, path_item in spec["paths"].items(): + for method, operation in path_item.items(): + if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}: + continue + + for status, response in operation["responses"].items(): + if "content" not in response: + continue + + if "application/json" in response["content"]: + assert "schema" in response["content"]["application/json"], ( + f"Path {path} method {method} status {status} " + f"application/json content has no schema" + ) + + +def test_required_parameters_have_schemas(): + """ + Test that all required parameters have schemas. + """ + spec = load_openapi_spec() + + for path, path_item in spec["paths"].items(): + for method, operation in path_item.items(): + if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}: + continue + + if "parameters" not in operation: + continue + + for param in operation["parameters"]: + if param.get("required", False): + assert "schema" in param, ( + f"Path {path} method {method} required parameter {param.get('name')} has no schema" + ) + + +def test_security_schemes_defined(): + """ + Test that security schemes are properly defined. + """ + spec = load_openapi_spec() + + # Get paths requiring security + secure_paths = find_paths_with_security(spec) + + if secure_paths: + assert "components" in spec, "Spec has secure paths but no components" + assert "securitySchemes" in spec["components"], "Spec has secure paths but no securitySchemes" + + # Check each security reference is defined + for path, method in secure_paths: + operation = spec["paths"][path][method] + for security_req in operation["security"]: + for scheme_name in security_req: + assert scheme_name in spec["components"]["securitySchemes"], ( + f"Security scheme {scheme_name} used by {method.upper()} {path} " + f"is not defined in components.securitySchemes" + ) + + +def test_common_endpoint_groups_present(): + """ + Test that the spec includes the main endpoint groups. + """ + spec = load_openapi_spec() + paths = get_all_paths(spec) + + # Define the expected endpoint prefixes + expected_prefixes = [ + "/customnode/", + "/externalmodel/", + "/manager/", + "/snapshot/", + "/comfyui_manager/", + ] + + # Check that at least one path exists for each expected prefix + for prefix in expected_prefixes: + matching_paths = [p for p in paths if p.startswith(prefix)] + assert matching_paths, f"No endpoints found with prefix {prefix}" \ No newline at end of file diff --git a/tests-api/utils/__init__.py b/tests-api/utils/__init__.py new file mode 100644 index 00000000..e2b11a0e --- /dev/null +++ b/tests-api/utils/__init__.py @@ -0,0 +1 @@ +# Make utils directory a proper package \ No newline at end of file diff --git a/tests-api/utils/schema_utils.py b/tests-api/utils/schema_utils.py new file mode 100644 index 00000000..a9d0e091 --- /dev/null +++ b/tests-api/utils/schema_utils.py @@ -0,0 +1,174 @@ +""" +Schema utilities for extracting and manipulating OpenAPI schemas. +""" +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple +from .validation import load_openapi_spec + + +def get_all_paths(spec: Dict[str, Any]) -> List[str]: + """ + Get all paths defined in the OpenAPI specification. + + Args: + spec: The OpenAPI specification + + Returns: + List of all paths + """ + return list(spec.get("paths", {}).keys()) + + +def get_grouped_paths(spec: Dict[str, Any]) -> Dict[str, List[str]]: + """ + Group paths by their top-level segment. + + Args: + spec: The OpenAPI specification + + Returns: + Dictionary mapping top-level segments to lists of paths + """ + result = {} + + for path in get_all_paths(spec): + segments = path.strip("/").split("/") + if not segments: + continue + + top_segment = segments[0] + if top_segment not in result: + result[top_segment] = [] + + result[top_segment].append(path) + + return result + + +def get_methods_for_path(spec: Dict[str, Any], path: str) -> List[str]: + """ + Get all HTTP methods defined for a path. + + Args: + spec: The OpenAPI specification + path: The API path + + Returns: + List of HTTP methods (lowercase) + """ + if path not in spec.get("paths", {}): + return [] + + return [ + method.lower() + for method in spec["paths"][path].keys() + if method.lower() in {"get", "post", "put", "delete", "patch", "options", "head"} + ] + + +def find_paths_with_security( + spec: Dict[str, Any], + security_scheme: Optional[str] = None +) -> List[Tuple[str, str]]: + """ + Find all paths that require security. + + Args: + spec: The OpenAPI specification + security_scheme: Optional specific security scheme to filter by + + Returns: + List of (path, method) tuples that require security + """ + result = [] + + for path, path_item in spec.get("paths", {}).items(): + for method, operation in path_item.items(): + if method.lower() not in {"get", "post", "put", "delete", "patch", "options", "head"}: + continue + + if "security" in operation: + if security_scheme is None: + result.append((path, method.lower())) + else: + # Check if this security scheme is required + for security_req in operation["security"]: + if security_scheme in security_req: + result.append((path, method.lower())) + break + + return result + + +def get_content_types_for_response( + spec: Dict[str, Any], + path: str, + method: str, + status_code: str = "200" +) -> List[str]: + """ + Get content types defined for a response. + + Args: + spec: The OpenAPI specification + path: The API path + method: The HTTP method + status_code: The HTTP status code + + Returns: + List of content types + """ + method = method.lower() + + if path not in spec["paths"]: + return [] + + if method not in spec["paths"][path]: + return [] + + if "responses" not in spec["paths"][path][method]: + return [] + + if status_code not in spec["paths"][path][method]["responses"]: + return [] + + response_def = spec["paths"][path][method]["responses"][status_code] + + if "content" not in response_def: + return [] + + return list(response_def["content"].keys()) + + +def get_required_parameters( + spec: Dict[str, Any], + path: str, + method: str +) -> List[Dict[str, Any]]: + """ + Get all required parameters for a path/method. + + Args: + spec: The OpenAPI specification + path: The API path + method: The HTTP method + + Returns: + List of parameter objects that are required + """ + method = method.lower() + + if path not in spec["paths"]: + return [] + + if method not in spec["paths"][path]: + return [] + + if "parameters" not in spec["paths"][path][method]: + return [] + + return [ + param for param in spec["paths"][path][method]["parameters"] + if param.get("required", False) + ] \ No newline at end of file diff --git a/tests-api/utils/validation.py b/tests-api/utils/validation.py new file mode 100644 index 00000000..4d0a05c1 --- /dev/null +++ b/tests-api/utils/validation.py @@ -0,0 +1,155 @@ +""" +Validation utilities for API tests. +""" +import json +import jsonschema +import yaml +from pathlib import Path +from typing import Any, Dict, Optional, Union + + +def load_openapi_spec(spec_path: Union[str, Path] = None) -> Dict[str, Any]: + """ + Load the OpenAPI specification document. + + Args: + spec_path: Path to the OpenAPI specification file + + Returns: + The OpenAPI specification as a dictionary + """ + if spec_path is None: + # Default to the root openapi.yaml file + spec_path = Path(__file__).parents[2] / "openapi.yaml" + + with open(spec_path, "r") as f: + if str(spec_path).endswith(".yaml") or str(spec_path).endswith(".yml"): + return yaml.safe_load(f) + else: + return json.load(f) + + +def get_schema_for_path( + spec: Dict[str, Any], + path: str, + method: str, + status_code: str = "200", + content_type: str = "application/json" +) -> Optional[Dict[str, Any]]: + """ + Extract the response schema for a specific path, method, and status code. + + Args: + spec: The OpenAPI specification + path: The API path (e.g., "/customnode/getlist") + method: The HTTP method (e.g., "get", "post") + status_code: The HTTP status code (default: "200") + content_type: The response content type (default: "application/json") + + Returns: + The schema for the specified path and method, or None if not found + """ + method = method.lower() + + if path not in spec["paths"]: + return None + + if method not in spec["paths"][path]: + return None + + if "responses" not in spec["paths"][path][method]: + return None + + if status_code not in spec["paths"][path][method]["responses"]: + return None + + response_def = spec["paths"][path][method]["responses"][status_code] + + if "content" not in response_def: + return None + + if content_type not in response_def["content"]: + return None + + if "schema" not in response_def["content"][content_type]: + return None + + return response_def["content"][content_type]["schema"] + + +def validate_response_schema( + response_data: Any, + schema: Dict[str, Any], + spec: Dict[str, Any] = None +) -> bool: + """ + Validate a response against a schema from the OpenAPI specification. + + Args: + response_data: The response data to validate + schema: The schema to validate against + spec: The complete OpenAPI specification (for resolving references) + + Returns: + True if validation succeeds, raises an exception otherwise + """ + if spec is None: + spec = load_openapi_spec() + + # Create a resolver for references within the schema + resolver = jsonschema.RefResolver.from_schema(spec) + + # Validate the response against the schema + jsonschema.validate( + instance=response_data, + schema=schema, + resolver=resolver + ) + + return True + + +def validate_response( + response_data: Any, + path: str, + method: str, + status_code: str = "200", + content_type: str = "application/json", + spec: Dict[str, Any] = None +) -> bool: + """ + Validate a response against the schema defined in the OpenAPI specification. + + Args: + response_data: The response data to validate + path: The API path + method: The HTTP method + status_code: The HTTP status code (default: "200") + content_type: The response content type (default: "application/json") + spec: The OpenAPI specification (loaded from default location if None) + + Returns: + True if validation succeeds, raises an exception otherwise + """ + if spec is None: + spec = load_openapi_spec() + + schema = get_schema_for_path( + spec=spec, + path=path, + method=method, + status_code=status_code, + content_type=content_type + ) + + if schema is None: + raise ValueError( + f"No schema found for {method.upper()} {path} " + f"with status {status_code} and content type {content_type}" + ) + + return validate_response_schema( + response_data=response_data, + schema=schema, + spec=spec + ) \ No newline at end of file