[feat] Implement comprehensive batch tracking and OpenAPI-driven data models
Enhances ComfyUI Manager with robust batch execution tracking and a unified data model architecture:

- Implemented automatic batch history serialization with before/after system state snapshots
- Added comprehensive state management capturing installed nodes, models, and ComfyUI version info
- Enhanced task queue with proper client ID handling and WebSocket notifications
- Migrated all data models to OpenAPI-generated Pydantic models for consistency
- Added documentation for new TaskQueue methods (done_count, total_count, finalize)
- Fixed 64 linting errors with proper imports and code cleanup

Technical improvements:

- All models now auto-generated from openapi.yaml, ensuring API/implementation consistency
- Batch tracking captures complete system state at operation start and completion
- Enhanced REST endpoints with comprehensive documentation
- Removed manual model files in favor of a single source of truth
- Added helper methods for system state capture and batch lifecycle management
This commit is contained in: parent 601f1bf452, commit c8882dcb7c
67  comfyui_manager/data_models/README.md  (new file)
@@ -0,0 +1,67 @@
# Data Models

This directory contains Pydantic models for ComfyUI Manager, providing type safety, validation, and serialization for the API and internal data structures.

## Overview

- `generated_models.py` - All models, auto-generated from the OpenAPI spec
- `__init__.py` - Package exports for all models

**Note**: All models are now auto-generated from the OpenAPI specification. Manual model files (`task_queue.py`, `state_management.py`) have been deprecated in favor of a single source of truth.

## Generating Types from OpenAPI

The models are generated automatically from the OpenAPI specification using `datamodel-codegen`. This ensures type safety and consistency between the API specification and the Python code.

### Prerequisites

Install the code generator:

```bash
pipx install datamodel-code-generator
```

### Generation Command

To regenerate all models after updating the OpenAPI spec:

```bash
datamodel-codegen \
  --use-subclass-enum \
  --field-constraints \
  --strict-types bytes \
  --input openapi.yaml \
  --output comfyui_manager/data_models/generated_models.py \
  --output-model-type pydantic_v2.BaseModel
```

### When to Regenerate

Regenerate the models when:

1. **Adding new API endpoints** that return new data structures
2. **Modifying existing schemas** in the OpenAPI specification
3. **Adding new state management features** that require new models

### Important Notes

- **Single source of truth**: All models are generated from `openapi.yaml`
- **No manual models**: All previously manual models have been migrated to the OpenAPI spec
- **OpenAPI requirements**: New schemas must be referenced in API paths to be generated by `datamodel-codegen`
- **Validation**: Always validate the OpenAPI spec before generation:

  ```bash
  python3 -c "import yaml; yaml.safe_load(open('openapi.yaml'))"
  ```
### Example: Adding New State Models

1. Add your schema to `openapi.yaml` under `components/schemas/`
2. Reference the schema in an API endpoint response
3. Run the generation command above
4. Update `__init__.py` to export the new models
5. Import and use the models in your code, as in the sketch below
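Once exported, the generated models behave like ordinary Pydantic v2 models. A minimal usage sketch (the pack name and field values below are hypothetical):

```python
from comfyui_manager.data_models import InstalledNodeInfo

# Construct a model instance; required fields are validated on creation.
node = InstalledNodeInfo(
    name="ComfyUI-Example-Pack",  # hypothetical pack name
    version="1.2.3",
    install_method="cnr",
    enabled=True,
)

# Pydantic v2 serialization helpers work out of the box.
print(node.model_dump_json(indent=2))
```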
### Troubleshooting

- **Models not generated**: Ensure schemas are under `components/schemas/` (not `parameters/`)
- **Missing models**: Verify schemas are referenced in at least one API path
- **Import errors**: Check that new models are added to `__init__.py` exports
@@ -3,24 +3,105 @@ Data models for ComfyUI Manager.

This package contains Pydantic models used throughout the ComfyUI Manager
for data validation, serialization, and type safety.

All models are auto-generated from the OpenAPI specification to ensure
consistency between the API and implementation.
"""

from .task_queue import (
from .generated_models import (
    # Core Task Queue Models
    QueueTaskItem,
    TaskHistoryItem,
    TaskStateMessage,
    TaskExecutionStatus,

    # WebSocket Message Models
    MessageTaskDone,
    MessageTaskStarted,
    MessageTaskFailed,
    MessageUpdate,
    ManagerMessageName,

    # State Management Models
    BatchExecutionRecord,
    ComfyUISystemState,
    BatchOperation,
    InstalledNodeInfo,
    InstalledModelInfo,
    ComfyUIVersionInfo,

    # Other models
    Kind,
    StatusStr,
    ManagerPackInfo,
    ManagerPackInstalled,
    SelectedVersion,
    ManagerChannel,
    ManagerDatabaseSource,
    ManagerPackState,
    ManagerPackInstallType,
    ManagerPack,
    InstallPackParams,
    UpdateAllPacksParams,
    QueueStatus,
    ManagerMappings,
    ModelMetadata,
    NodePackageMetadata,
    SnapshotItem,
    Error,
    InstalledPacksResponse,
    HistoryResponse,
    HistoryListResponse,
    InstallType,
    OperationType,
    Result,
)

__all__ = [
    # Core Task Queue Models
    "QueueTaskItem",
    "TaskHistoryItem",
    "TaskStateMessage",
    "TaskExecutionStatus",

    # WebSocket Message Models
    "MessageTaskDone",
    "MessageTaskStarted",
    "MessageTaskFailed",
    "MessageUpdate",
    "ManagerMessageName",
]

    # State Management Models
    "BatchExecutionRecord",
    "ComfyUISystemState",
    "BatchOperation",
    "InstalledNodeInfo",
    "InstalledModelInfo",
    "ComfyUIVersionInfo",

    # Other models
    "Kind",
    "StatusStr",
    "ManagerPackInfo",
    "ManagerPackInstalled",
    "SelectedVersion",
    "ManagerChannel",
    "ManagerDatabaseSource",
    "ManagerPackState",
    "ManagerPackInstallType",
    "ManagerPack",
    "InstallPackParams",
    "UpdateAllPacksParams",
    "QueueStatus",
    "ManagerMappings",
    "ModelMetadata",
    "NodePackageMetadata",
    "SnapshotItem",
    "Error",
    "InstalledPacksResponse",
    "HistoryResponse",
    "HistoryListResponse",
    "InstallType",
    "OperationType",
    "Result",
]
417  comfyui_manager/data_models/generated_models.py  (new file)
@@ -0,0 +1,417 @@

# generated by datamodel-codegen:
#   filename: openapi.yaml
#   timestamp: 2025-06-08T08:07:38+00:00

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, RootModel


class Kind(str, Enum):
    install = 'install'
    uninstall = 'uninstall'
    update = 'update'
    update_all = 'update-all'
    update_comfyui = 'update-comfyui'
    fix = 'fix'
    disable = 'disable'
    enable = 'enable'
    install_model = 'install-model'


class QueueTaskItem(BaseModel):
    ui_id: str = Field(..., description='Unique identifier for the task')
    client_id: str = Field(..., description='Client identifier that initiated the task')
    kind: Kind = Field(..., description='Type of task being performed')


class StatusStr(str, Enum):
    success = 'success'
    error = 'error'
    skip = 'skip'


class TaskExecutionStatus(BaseModel):
    status_str: StatusStr = Field(..., description='Overall task execution status')
    completed: bool = Field(..., description='Whether the task completed')
    messages: List[str] = Field(..., description='Additional status messages')


class ManagerMessageName(str, Enum):
    cm_task_completed = 'cm-task-completed'
    cm_task_started = 'cm-task-started'
    cm_queue_status = 'cm-queue-status'


class ManagerPackInfo(BaseModel):
    id: str = Field(
        ...,
        description='Either github-author/github-repo or name of pack from the registry',
    )
    version: str = Field(..., description='Semantic version or Git commit hash')
    ui_id: Optional[str] = Field(None, description='Task ID - generated internally')


class ManagerPackInstalled(BaseModel):
    ver: str = Field(
        ...,
        description='The version of the pack that is installed (Git commit hash or semantic version)',
    )
    cnr_id: Optional[str] = Field(
        None, description='The name of the pack if installed from the registry'
    )
    aux_id: Optional[str] = Field(
        None,
        description='The name of the pack if installed from github (author/repo-name format)',
    )
    enabled: bool = Field(..., description='Whether the pack is enabled')


class SelectedVersion(str, Enum):
    latest = 'latest'
    nightly = 'nightly'


class ManagerChannel(str, Enum):
    default = 'default'
    recent = 'recent'
    legacy = 'legacy'
    forked = 'forked'
    dev = 'dev'
    tutorial = 'tutorial'


class ManagerDatabaseSource(str, Enum):
    remote = 'remote'
    local = 'local'
    cache = 'cache'


class ManagerPackState(str, Enum):
    installed = 'installed'
    disabled = 'disabled'
    not_installed = 'not_installed'
    import_failed = 'import_failed'
    needs_update = 'needs_update'


class ManagerPackInstallType(str, Enum):
    git_clone = 'git-clone'
    copy = 'copy'
    cnr = 'cnr'


class UpdateState(str, Enum):
    false = 'false'
    true = 'true'


class ManagerPack(ManagerPackInfo):
    author: Optional[str] = Field(
        None, description="Pack author name or 'Unclaimed' if added via GitHub crawl"
    )
    files: Optional[List[str]] = Field(None, description='Files included in the pack')
    reference: Optional[str] = Field(
        None, description='The type of installation reference'
    )
    title: Optional[str] = Field(None, description='The display name of the pack')
    cnr_latest: Optional[SelectedVersion] = None
    repository: Optional[str] = Field(None, description='GitHub repository URL')
    state: Optional[ManagerPackState] = None
    update_state: Optional[UpdateState] = Field(
        None, alias='update-state', description='Update availability status'
    )
    stars: Optional[int] = Field(None, description='GitHub stars count')
    last_update: Optional[datetime] = Field(None, description='Last update timestamp')
    health: Optional[str] = Field(None, description='Health status of the pack')
    description: Optional[str] = Field(None, description='Pack description')
    trust: Optional[bool] = Field(None, description='Whether the pack is trusted')
    install_type: Optional[ManagerPackInstallType] = None


class InstallPackParams(ManagerPackInfo):
    selected_version: Union[str, SelectedVersion] = Field(
        ..., description='Semantic version, Git commit hash, latest, or nightly'
    )
    repository: Optional[str] = Field(
        None,
        description='GitHub repository URL (required if selected_version is nightly)',
    )
    pip: Optional[List[str]] = Field(None, description='PyPi dependency names')
    mode: ManagerDatabaseSource
    channel: ManagerChannel
    skip_post_install: Optional[bool] = Field(
        None, description='Whether to skip post-installation steps'
    )


class UpdateAllPacksParams(BaseModel):
    mode: Optional[ManagerDatabaseSource] = None
    ui_id: Optional[str] = Field(None, description='Task ID - generated internally')


class QueueStatus(BaseModel):
    total_count: int = Field(
        ..., description='Total number of tasks (pending + running)'
    )
    done_count: int = Field(..., description='Number of completed tasks')
    in_progress_count: int = Field(..., description='Number of tasks currently running')
    pending_count: Optional[int] = Field(
        None, description='Number of tasks waiting to be executed'
    )
    is_processing: bool = Field(..., description='Whether the task worker is active')
    client_id: Optional[str] = Field(
        None, description='Client ID (when filtered by client)'
    )


class ManagerMapping(BaseModel):
    title_aux: Optional[str] = Field(None, description='The display name of the pack')


class ManagerMappings(
    RootModel[Optional[Dict[str, List[Union[List[str], ManagerMapping]]]]]
):
    root: Optional[Dict[str, List[Union[List[str], ManagerMapping]]]] = None


class ModelMetadata(BaseModel):
    name: str = Field(..., description='Name of the model')
    type: str = Field(..., description='Type of model')
    base: Optional[str] = Field(None, description='Base model type')
    save_path: Optional[str] = Field(None, description='Path for saving the model')
    url: str = Field(..., description='Download URL')
    filename: str = Field(..., description='Target filename')
    ui_id: Optional[str] = Field(None, description='ID for UI reference')


class InstallType(str, Enum):
    git = 'git'
    copy = 'copy'
    pip = 'pip'


class NodePackageMetadata(BaseModel):
    title: Optional[str] = Field(None, description='Display name of the node package')
    name: Optional[str] = Field(None, description='Repository/package name')
    files: Optional[List[str]] = Field(None, description='Source URLs for the package')
    description: Optional[str] = Field(
        None, description='Description of the node package functionality'
    )
    install_type: Optional[InstallType] = Field(None, description='Installation method')
    version: Optional[str] = Field(None, description='Version identifier')
    id: Optional[str] = Field(
        None, description='Unique identifier for the node package'
    )
    ui_id: Optional[str] = Field(None, description='ID for UI reference')
    channel: Optional[str] = Field(None, description='Source channel')
    mode: Optional[str] = Field(None, description='Source mode')


class SnapshotItem(RootModel[str]):
    root: str = Field(..., description='Name of the snapshot')


class Error(BaseModel):
    error: str = Field(..., description='Error message')


class InstalledPacksResponse(RootModel[Optional[Dict[str, ManagerPackInstalled]]]):
    root: Optional[Dict[str, ManagerPackInstalled]] = None


class HistoryListResponse(BaseModel):
    ids: Optional[List[str]] = Field(
        None, description='List of available batch history IDs'
    )


class InstalledNodeInfo(BaseModel):
    name: str = Field(..., description='Node package name')
    version: str = Field(..., description='Installed version')
    repository_url: Optional[str] = Field(None, description='Git repository URL')
    install_method: str = Field(
        ..., description='Installation method (cnr, git, pip, etc.)'
    )
    enabled: Optional[bool] = Field(
        True, description='Whether the node is currently enabled'
    )
    install_date: Optional[datetime] = Field(
        None, description='ISO timestamp of installation'
    )


class InstalledModelInfo(BaseModel):
    name: str = Field(..., description='Model filename')
    path: str = Field(..., description='Full path to model file')
    type: str = Field(..., description='Model type (checkpoint, lora, vae, etc.)')
    size_bytes: Optional[int] = Field(None, description='File size in bytes', ge=0)
    hash: Optional[str] = Field(None, description='Model file hash for verification')
    install_date: Optional[datetime] = Field(
        None, description='ISO timestamp when added'
    )


class ComfyUIVersionInfo(BaseModel):
    version: str = Field(..., description='ComfyUI version string')
    commit_hash: Optional[str] = Field(None, description='Git commit hash')
    branch: Optional[str] = Field(None, description='Git branch name')
    is_stable: Optional[bool] = Field(
        False, description='Whether this is a stable release'
    )
    last_updated: Optional[datetime] = Field(
        None, description='ISO timestamp of last update'
    )


class OperationType(str, Enum):
    install = 'install'
    update = 'update'
    uninstall = 'uninstall'
    fix = 'fix'
    disable = 'disable'
    enable = 'enable'
    install_model = 'install-model'


class Result(str, Enum):
    success = 'success'
    failed = 'failed'
    skipped = 'skipped'


class BatchOperation(BaseModel):
    operation_id: str = Field(..., description='Unique operation identifier')
    operation_type: OperationType = Field(..., description='Type of operation')
    target: str = Field(
        ..., description='Target of the operation (node name, model name, etc.)'
    )
    target_version: Optional[str] = Field(
        None, description='Target version for the operation'
    )
    result: Result = Field(..., description='Operation result')
    error_message: Optional[str] = Field(
        None, description='Error message if operation failed'
    )
    start_time: datetime = Field(
        ..., description='ISO timestamp when operation started'
    )
    end_time: Optional[datetime] = Field(
        None, description='ISO timestamp when operation completed'
    )
    client_id: Optional[str] = Field(
        None, description='Client that initiated the operation'
    )


class ComfyUISystemState(BaseModel):
    snapshot_time: datetime = Field(
        ..., description='ISO timestamp when snapshot was taken'
    )
    comfyui_version: ComfyUIVersionInfo
    frontend_version: Optional[str] = Field(
        None, description='ComfyUI frontend version if available'
    )
    python_version: str = Field(..., description='Python interpreter version')
    platform_info: str = Field(
        ..., description='Operating system and platform information'
    )
    installed_nodes: Optional[Dict[str, InstalledNodeInfo]] = Field(
        None, description='Map of installed node packages by name'
    )
    installed_models: Optional[Dict[str, InstalledModelInfo]] = Field(
        None, description='Map of installed models by name'
    )
    manager_config: Optional[Dict[str, Any]] = Field(
        None, description='ComfyUI Manager configuration settings'
    )


class BatchExecutionRecord(BaseModel):
    batch_id: str = Field(..., description='Unique batch identifier')
    start_time: datetime = Field(..., description='ISO timestamp when batch started')
    end_time: Optional[datetime] = Field(
        None, description='ISO timestamp when batch completed'
    )
    state_before: ComfyUISystemState
    state_after: Optional[ComfyUISystemState] = Field(
        None, description='System state after batch execution'
    )
    operations: Optional[List[BatchOperation]] = Field(
        None, description='List of operations performed in this batch'
    )
    total_operations: Optional[int] = Field(
        0, description='Total number of operations in batch', ge=0
    )
    successful_operations: Optional[int] = Field(
        0, description='Number of successful operations', ge=0
    )
    failed_operations: Optional[int] = Field(
        0, description='Number of failed operations', ge=0
    )
    skipped_operations: Optional[int] = Field(
        0, description='Number of skipped operations', ge=0
    )


class TaskHistoryItem(BaseModel):
    ui_id: str = Field(..., description='Unique identifier for the task')
    client_id: str = Field(..., description='Client identifier that initiated the task')
    kind: str = Field(..., description='Type of task that was performed')
    timestamp: datetime = Field(..., description='ISO timestamp when task completed')
    result: str = Field(..., description='Task result message or details')
    status: Optional[TaskExecutionStatus] = None


class TaskStateMessage(BaseModel):
    history: Dict[str, TaskHistoryItem] = Field(
        ..., description='Map of task IDs to their history items'
    )
    running_queue: List[QueueTaskItem] = Field(
        ..., description='Currently executing tasks'
    )
    pending_queue: List[QueueTaskItem] = Field(
        ..., description='Tasks waiting to be executed'
    )


class MessageTaskDone(BaseModel):
    ui_id: str = Field(..., description='Task identifier')
    result: str = Field(..., description='Task result message')
    kind: str = Field(..., description='Type of task')
    status: Optional[TaskExecutionStatus] = None
    timestamp: datetime = Field(..., description='ISO timestamp when task completed')
    state: TaskStateMessage


class MessageTaskStarted(BaseModel):
    ui_id: str = Field(..., description='Task identifier')
    kind: str = Field(..., description='Type of task')
    timestamp: datetime = Field(..., description='ISO timestamp when task started')
    state: TaskStateMessage


class MessageTaskFailed(BaseModel):
    ui_id: str = Field(..., description='Task identifier')
    error: str = Field(..., description='Error message')
    kind: str = Field(..., description='Type of task')
    timestamp: datetime = Field(..., description='ISO timestamp when task failed')
    state: TaskStateMessage


class MessageUpdate(
    RootModel[Union[MessageTaskDone, MessageTaskStarted, MessageTaskFailed]]
):
    root: Union[MessageTaskDone, MessageTaskStarted, MessageTaskFailed] = Field(
        ..., description='Union type for all possible WebSocket message updates'
    )


class HistoryResponse(BaseModel):
    history: Optional[Dict[str, TaskHistoryItem]] = Field(
        None, description='Map of task IDs to their history items'
    )
@@ -1,69 +0,0 @@

"""
Task queue data models for ComfyUI Manager.

Contains Pydantic models for task queue management, WebSocket messaging,
and task state tracking.
"""

from typing import Optional, Union, Dict
from enum import Enum
from pydantic import BaseModel


class QueueTaskItem(BaseModel):
    """Represents a task item in the queue."""

    ui_id: str
    client_id: str
    kind: str


class TaskHistoryItem(BaseModel):
    """Represents a completed task in the history."""

    ui_id: str
    client_id: str
    kind: str
    timestamp: str
    result: str
    status: Optional[dict] = None


class TaskStateMessage(BaseModel):
    """Current state of the task queue system."""

    history: Dict[str, TaskHistoryItem]
    running_queue: list[QueueTaskItem]
    pending_queue: list[QueueTaskItem]


class MessageTaskDone(BaseModel):
    """WebSocket message sent when a task completes."""

    ui_id: str
    result: str
    kind: str
    status: Optional[dict]
    timestamp: str
    state: TaskStateMessage


class MessageTaskStarted(BaseModel):
    """WebSocket message sent when a task starts."""

    ui_id: str
    kind: str
    timestamp: str
    state: TaskStateMessage


# Union type for all possible WebSocket message updates
MessageUpdate = Union[MessageTaskDone, MessageTaskStarted]


class ManagerMessageName(Enum):
    """WebSocket message type constants."""

    TASK_DONE = "cm-task-completed"
    TASK_STARTED = "cm-task-started"
    STATUS = "cm-queue-status"
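Since the manual `task_queue.py` module above is deleted in this commit, code that imported these models directly can switch to the package root, which now re-exports the OpenAPI-generated equivalents. A minimal before/after sketch:

```python
# Before (module deleted in this commit):
# from comfyui_manager.data_models.task_queue import QueueTaskItem, TaskStateMessage

# After: import from the package root, which re-exports the generated models
from comfyui_manager.data_models import QueueTaskItem, TaskStateMessage
```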
@@ -1,42 +1,39 @@
import traceback

import folder_paths
import locale
import subprocess  # don't remove this
import concurrent
import nodes
import os
import sys
import threading
import platform
import re
import shutil
import git
import uuid
from datetime import datetime
import heapq
import copy
from typing import NamedTuple, List, Literal, Optional, Union
from enum import Enum
from typing import NamedTuple, List, Literal, Optional
from comfy.cli_args import args
import latent_preview
from aiohttp import web
import aiohttp
import json
import zipfile
import urllib.request

from comfyui_manager.glob.utils import (
    environment_utils,
    formatting_utils,
    model_utils,
    security_utils,
    formatting_utils,
    node_pack_utils,
    environment_utils,
)


from server import PromptServer
import logging
import asyncio
from collections import deque

from . import manager_core as core
from ..common import manager_util
@@ -44,8 +41,6 @@ from ..common import cm_global
from ..common import manager_downloader
from ..common import context

from pydantic import BaseModel
import heapq

from ..data_models import (
    QueueTaskItem,
@@ -55,8 +50,30 @@ from ..data_models import (
    MessageTaskStarted,
    MessageUpdate,
    ManagerMessageName,
    BatchExecutionRecord,
    ComfyUISystemState,
    BatchOperation,
    InstalledNodeInfo,
    InstalledModelInfo,
    ComfyUIVersionInfo,
)

from .constants import (
    model_dir_name_map,
    SECURITY_MESSAGE_MIDDLE_OR_BELOW,
    SECURITY_MESSAGE_NORMAL_MINUS_MODEL,
    SECURITY_MESSAGE_GENERAL,
    SECURITY_MESSAGE_NORMAL_MINUS,
)

# For legacy compatibility - these may need to be implemented in the new structure
temp_queue_batch = []
task_worker_lock = threading.RLock()


def finalize_temp_queue_batch():
    """Temporary compatibility function - to be implemented with new queue system"""
    pass


if not manager_util.is_manager_pip_package():
    network_mode_description = "offline"
@@ -220,7 +237,9 @@ class TaskQueue:
        self.running_tasks = {}
        self.history_tasks = {}
        self.task_counter = 0
        self.batch_id = 0
        self.batch_id = None
        self.batch_start_time = None
        self.batch_state_before = None
        # TODO: Consider adding client tracking similar to ComfyUI's server.client_id
        # to track which client is currently executing for better session management

@@ -239,9 +258,11 @@ class TaskQueue:
        )

    @staticmethod
    def send_queue_state_update(msg: str, update: MessageUpdate, client_id: Optional[str] = None) -> None:
    def send_queue_state_update(
        msg: str, update: MessageUpdate, client_id: Optional[str] = None
    ) -> None:
        """Send queue state update to clients.

        Args:
            msg: Message type/event name
            update: Update data to send
@@ -252,8 +273,19 @@ class TaskQueue:

    def put(self, item: QueueTaskItem) -> None:
        with self.mutex:
            # Start a new batch if this is the first task after queue was empty
            if self.batch_id is None and len(self.pending_tasks) == 0 and len(self.running_tasks) == 0:
                self._start_new_batch()

            heapq.heappush(self.pending_tasks, item)
            self.not_empty.notify()

    def _start_new_batch(self) -> None:
        """Start a new batch session for tracking operations."""
        self.batch_id = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
        self.batch_start_time = datetime.now().isoformat()
        self.batch_state_before = self._capture_system_state()
        logging.info(f"[ComfyUI-Manager] Started new batch: {self.batch_id}")

    def get(
        self, timeout: Optional[float] = None
@@ -275,7 +307,9 @@ class TaskQueue:
                    timestamp=datetime.now().isoformat(),
                    state=self.get_current_state(),
                ),
                client_id=item["client_id"]  # Send task started only to the client that requested it
                client_id=item[
                    "client_id"
                ],  # Send task started only to the client that requested it
            )
            return item, task_index

@@ -319,7 +353,9 @@ class TaskQueue:
                timestamp=timestamp,
                state=self.get_current_state(),
            ),
            client_id=item["client_id"]  # Send completion only to the client that requested it
            client_id=item[
                "client_id"
            ],  # Send completion only to the client that requested it
        )

    def get_current_queue(self) -> tuple[list[QueueTaskItem], list[QueueTaskItem]]:
@@ -377,7 +413,7 @@ class TaskQueue:

    def done_count(self) -> int:
        """Get the number of completed tasks in history.

        Returns:
            int: Number of tasks that have been completed and are stored in history.
            Returns 0 if history_tasks is None (defensive programming).
@@ -386,7 +422,7 @@ class TaskQueue:

    def total_count(self) -> int:
        """Get the total number of tasks currently in the system (pending + running).

        Returns:
            int: Combined count of pending and running tasks.
            Returns 0 if either collection is None (defensive programming).
@@ -399,21 +435,142 @@ class TaskQueue:

    def finalize(self) -> None:
        """Finalize a completed task batch by saving execution history to disk.

        This method is intended to be called when the queue transitions from having
        tasks to being completely empty (no pending or running tasks). It will create
        a comprehensive snapshot of the ComfyUI state and all operations performed.

        Note: Currently incomplete - requires implementation of state management models.
        """
        if self.batch_id is not None:
            batch_path = os.path.join(
                context.manager_batch_history_path, self.batch_id + ".json"
            )
            # TODO: create a pydantic model for state of ComfyUI (installed nodes, models, ComfyUI version, ComfyUI frontend version) + the operations that occurred in the batch. Then add a serialization method that can work nicely for saving to json file. Finally, add post creation validation methods on the pydantic model. Then, anytime the queue goes from full to completely empty (also none running) -> run this finalize to save the snapshot.
            # Add logic here to instantiate the model, then save below using the serialization method of the object
            # with open(batch_path, "w") as json_file:
            #     json.dump(json_obj, json_file, indent=4)

            try:
                end_time = datetime.now().isoformat()
                state_after = self._capture_system_state()
                operations = self._extract_batch_operations()

                batch_record = BatchExecutionRecord(
                    batch_id=self.batch_id,
                    start_time=self.batch_start_time,
                    end_time=end_time,
                    state_before=self.batch_state_before,
                    state_after=state_after,
                    operations=operations,
                    total_operations=len(operations),
                    successful_operations=len([op for op in operations if op.result == "success"]),
                    failed_operations=len([op for op in operations if op.result == "failed"]),
                    skipped_operations=len([op for op in operations if op.result == "skipped"])
                )

                # Save to disk
                with open(batch_path, "w", encoding="utf-8") as json_file:
                    json.dump(batch_record.model_dump(), json_file, indent=4, default=str)

                logging.info(f"[ComfyUI-Manager] Batch history saved: {batch_path}")

                # Reset batch tracking
                self.batch_id = None
                self.batch_start_time = None
                self.batch_state_before = None

            except Exception as e:
                logging.error(f"[ComfyUI-Manager] Failed to save batch history: {e}")

    def _capture_system_state(self) -> ComfyUISystemState:
        """Capture current ComfyUI system state for batch record."""
        return ComfyUISystemState(
            snapshot_time=datetime.now().isoformat(),
            comfyui_version=self._get_comfyui_version_info(),
            python_version=platform.python_version(),
            platform_info=f"{platform.system()} {platform.release()} ({platform.machine()})",
            installed_nodes=self._get_installed_nodes(),
            installed_models=self._get_installed_models()
        )

    def _get_comfyui_version_info(self) -> ComfyUIVersionInfo:
        """Get ComfyUI version information."""
        try:
            version_info = core.get_comfyui_versions()
            current_version = version_info[1] if len(version_info) > 1 else "unknown"
            return ComfyUIVersionInfo(version=current_version)
        except Exception:
            return ComfyUIVersionInfo(version="unknown")

    def _get_installed_nodes(self) -> dict[str, InstalledNodeInfo]:
        """Get information about installed node packages."""
        installed_nodes = {}

        try:
            node_packs = core.get_installed_node_packs()
            for pack_name, pack_info in node_packs.items():
                installed_nodes[pack_name] = InstalledNodeInfo(
                    name=pack_name,
                    version=pack_info.get("ver", "unknown"),
                    install_method="unknown",
                    enabled=pack_info.get("enabled", True)
                )
        except Exception as e:
            logging.warning(f"[ComfyUI-Manager] Failed to get installed nodes: {e}")

        return installed_nodes

    def _get_installed_models(self) -> dict[str, InstalledModelInfo]:
        """Get information about installed models."""
        installed_models = {}

        try:
            model_dirs = ["checkpoints", "loras", "vae", "embeddings", "controlnet", "upscale_models"]

            for model_type in model_dirs:
                try:
                    files = folder_paths.get_filename_list(model_type)
                    for filename in files:
                        model_paths = folder_paths.get_folder_paths(model_type)
                        if model_paths:
                            full_path = os.path.join(model_paths[0], filename)
                            if os.path.exists(full_path):
                                installed_models[filename] = InstalledModelInfo(
                                    name=filename,
                                    path=full_path,
                                    type=model_type,
                                    size_bytes=os.path.getsize(full_path)
                                )
                except Exception:
                    continue

        except Exception as e:
            logging.warning(f"[ComfyUI-Manager] Failed to get installed models: {e}")

        return installed_models

    def _extract_batch_operations(self) -> list[BatchOperation]:
        """Extract operations from completed task history for this batch."""
        operations = []

        try:
            for ui_id, task in self.history_tasks.items():
                result_status = "success"
                if task.status:
                    status_str = task.status.get("status_str", "success")
                    if status_str == "error":
                        result_status = "failed"
                    elif status_str == "skip":
                        result_status = "skipped"

                operation = BatchOperation(
                    operation_id=ui_id,
                    operation_type=task.kind,
                    target=f"task_{ui_id}",
                    result=result_status,
                    start_time=task.timestamp,
                    client_id=task.client_id
                )
                operations.append(operation)
        except Exception as e:
            logging.warning(f"[ComfyUI-Manager] Failed to extract batch operations: {e}")

        return operations


task_queue = TaskQueue()
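The batch records written by `finalize()` can be loaded back through the same generated model, which also re-validates the file contents. A minimal sketch (the file path is hypothetical):

```python
import json

from comfyui_manager.data_models import BatchExecutionRecord


def load_batch_record(path: str) -> BatchExecutionRecord:
    """Read a batch history JSON file and validate it against the model."""
    with open(path, "r", encoding="utf-8") as f:
        return BatchExecutionRecord.model_validate(json.load(f))


# Example (hypothetical path under the Manager's batch history directory):
# record = load_batch_record("batch_20250608_080738_ab12cd34.json")
# print(record.total_operations, record.failed_operations)
```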
@@ -535,7 +692,7 @@ async def task_worker():
            return "success"
        except Exception:
            traceback.print_exc()
            return f"Installation failed:\n{node_spec_str}"
            return "Installation failed"

    async def do_enable(item) -> str:
        cnr_id = item.get("cnr_id")
@@ -668,7 +825,7 @@ async def task_worker():
    async def do_install_model(item) -> str:
        json_data = item.get("json_data")

        model_path = get_model_path(json_data)
        model_path = model_utils.get_model_path(json_data)
        model_url = json_data.get("url")

        res = False
@@ -702,7 +859,7 @@ async def task_worker():
                or model_url.startswith("https://huggingface.co")
                or model_url.startswith("https://heibox.uni-heidelberg.de")
            ):
                model_dir = get_model_dir(json_data, True)
                model_dir = model_utils.get_model_dir(json_data, True)
                download_url(model_url, model_dir, filename=json_data["filename"])
                if model_path.endswith(".zip"):
                    res = core.unzip(model_path)
@@ -736,18 +893,26 @@ async def task_worker():
        timeout = 4096
        task = task_queue.get(timeout)
        if task is None:
            logging.info("\n[ComfyUI-Manager] All tasks are completed.")
            logging.info("\nAfter restarting ComfyUI, please refresh the browser.")
            # Check if queue is truly empty (no pending or running tasks)
            if task_queue.total_count() == 0 and len(task_queue.running_tasks) == 0:
                logging.info("\n[ComfyUI-Manager] All tasks are completed.")

                # Trigger batch history serialization if there are completed tasks
                if task_queue.done_count() > 0:
                    logging.info("[ComfyUI-Manager] Finalizing batch history...")
                    task_queue.finalize()
                    logging.info("[ComfyUI-Manager] Batch history saved.")

                logging.info("\nAfter restarting ComfyUI, please refresh the browser.")

            res = {"status": "all-done"}
            res = {"status": "all-done"}

            # Broadcast general status updates to all clients
            PromptServer.instance.send_sync("cm-queue-status", res)
            # Broadcast general status updates to all clients
            PromptServer.instance.send_sync("cm-queue-status", res)

            return

        item, task_index = task
        ui_id = item["ui_id"]
        kind = item["kind"]

        print(f"Processing task: {kind} with item: {item} at index: {task_index}")
@@ -777,7 +942,9 @@ async def task_worker():
                msg = "Unexpected kind: " + kind
        except Exception:
            msg = f"Exception: {(kind, item)}"
            task_queue.task_done(item, msg, TaskQueue.ExecutionStatus("error", True, [msg]))
            task_queue.task_done(
                item, msg, TaskQueue.ExecutionStatus("error", True, [msg])
            )

        # Determine status and message for task completion
        if isinstance(msg, dict) and "msg" in msg:
@@ -799,13 +966,13 @@ async def task_worker():
@routes.post("/v2/manager/queue/task")
async def queue_task(request) -> web.Response:
    """Add a new task to the processing queue.

    Accepts task data via JSON POST and adds it to the TaskQueue for processing.
    The task worker will automatically pick up and process queued tasks.

    Args:
        request: aiohttp request containing JSON task data

    Returns:
        web.Response: HTTP 200 on successful queueing
    """
@@ -818,10 +985,10 @@ async def queue_task(request) -> web.Response:
@routes.get("/v2/manager/queue/history_list")
async def get_history_list(request) -> web.Response:
    """Get list of available batch history files.

    Returns a list of batch history IDs sorted by modification time (newest first).
    These IDs can be used with the history endpoint to retrieve detailed batch information.

    Returns:
        web.Response: JSON response with 'ids' array of history file IDs
    """
@@ -847,14 +1014,14 @@ async def get_history_list(request) -> web.Response:
@routes.get("/v2/manager/queue/history")
async def get_history(request):
    """Get task history with optional client filtering.

    Query parameters:
        id: Batch history ID (for file-based history)
        client_id: Optional client ID to filter current session history
        ui_id: Optional specific task ID to get single task history
        max_items: Maximum number of items to return
        offset: Offset for pagination

    Returns:
        JSON with filtered history data
    """
@@ -868,32 +1035,33 @@ async def get_history(request):
                json_str = file.read()
                json_obj = json.loads(json_str)
                return web.json_response(json_obj, content_type="application/json")

        # Handle current session history with optional filtering
        client_id = request.rel_url.query.get("client_id")
        ui_id = request.rel_url.query.get("ui_id")
        max_items = request.rel_url.query.get("max_items")
        offset = request.rel_url.query.get("offset", -1)

        if max_items:
            max_items = int(max_items)
        if offset:
            offset = int(offset)

        # Get history from TaskQueue
        if ui_id:
            history = task_queue.get_history(ui_id=ui_id)
        else:
            history = task_queue.get_history(max_items=max_items, offset=offset)

        # Filter by client_id if provided
        if client_id and isinstance(history, dict):
            filtered_history = {
                task_id: task_data for task_id, task_data in history.items()
                if hasattr(task_data, 'client_id') and task_data.client_id == client_id
                task_id: task_data
                for task_id, task_data in history.items()
                if hasattr(task_data, "client_id") and task_data.client_id == client_id
            }
            history = filtered_history

        return web.json_response({"history": history}, content_type="application/json")

    except Exception as e:
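The two history endpoints above can be exercised from a small client script. A hedged sketch (the base URL assumes a default local ComfyUI instance, which may differ in your setup):

```python
import asyncio

import aiohttp

BASE = "http://127.0.0.1:8188"  # assumption: default local ComfyUI address


async def show_latest_batch() -> None:
    async with aiohttp.ClientSession() as session:
        # List saved batch history IDs (newest first).
        async with session.get(f"{BASE}/v2/manager/queue/history_list") as resp:
            ids = (await resp.json()).get("ids", [])
        if not ids:
            return
        # Fetch the most recent batch record by ID.
        async with session.get(
            f"{BASE}/v2/manager/queue/history", params={"id": ids[0]}
        ) as resp:
            print(await resp.json())


asyncio.run(show_latest_batch())
```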
@@ -918,7 +1086,7 @@ async def fetch_customnode_mappings(request):
        json_obj = core.map_to_unified_keys(json_obj)

    if nickname_mode:
        json_obj = nickname_filter(json_obj)
        json_obj = node_pack_utils.nickname_filter(json_obj)

    all_nodes = set()
    patterns = []
@@ -974,7 +1142,7 @@ async def update_all(request):

async def _update_all(json_data):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(status=403)

@@ -1166,7 +1334,7 @@ async def get_snapshot_list(request):

@routes.get("/v2/snapshot/remove")
async def remove_snapshot(request):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(status=403)

@@ -1184,7 +1352,7 @@ async def remove_snapshot(request):

@routes.get("/v2/snapshot/restore")
async def restore_snapshot(request):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(status=403)

@@ -1277,8 +1445,8 @@ async def import_fail_info(request):

@routes.post("/v2/manager/queue/reinstall")
async def reinstall_custom_node(request):
    await uninstall_custom_node(request)
    await install_custom_node(request)
    await _uninstall_custom_node(await request.json())
    await _install_custom_node(await request.json())


@routes.get("/v2/manager/queue/reset")
@@ -1289,58 +1457,68 @@ async def reset_queue(request):

@routes.get("/v2/manager/queue/abort_current")
async def abort_queue(request):
    task_queue.abort()
    # task_queue.abort()  # Method not implemented yet
    task_queue.wipe_queue()
    return web.Response(status=200)


@routes.get("/v2/manager/queue/status")
async def queue_count(request):
    """Get current queue status with optional client filtering.

    Query parameters:
        client_id: Optional client ID to filter tasks

    Returns:
        JSON with queue counts and processing status
    """
    client_id = request.query.get("client_id")

    if client_id:
        # Filter tasks by client_id
        running_client_tasks = [
            task for task in task_queue.running_tasks.values()
            task
            for task in task_queue.running_tasks.values()
            if task.get("client_id") == client_id
        ]
        pending_client_tasks = [
            task for task in task_queue.pending_tasks
            task
            for task in task_queue.pending_tasks
            if task.get("client_id") == client_id
        ]
        history_client_tasks = {
            ui_id: task for ui_id, task in task_queue.history_tasks.items()
            if hasattr(task, 'client_id') and task.client_id == client_id
            ui_id: task
            for ui_id, task in task_queue.history_tasks.items()
            if hasattr(task, "client_id") and task.client_id == client_id
        }

        return web.json_response({
            "client_id": client_id,
            "total_count": len(pending_client_tasks) + len(running_client_tasks),
            "done_count": len(history_client_tasks),
            "in_progress_count": len(running_client_tasks),
            "pending_count": len(pending_client_tasks),
            "is_processing": task_worker_thread is not None and task_worker_thread.is_alive(),
        })
        return web.json_response(
            {
                "client_id": client_id,
                "total_count": len(pending_client_tasks) + len(running_client_tasks),
                "done_count": len(history_client_tasks),
                "in_progress_count": len(running_client_tasks),
                "pending_count": len(pending_client_tasks),
                "is_processing": task_worker_thread is not None
                and task_worker_thread.is_alive(),
            }
        )
    else:
        # Return overall status
        return web.json_response({
            "total_count": task_queue.total_count(),
            "done_count": task_queue.done_count(),
            "in_progress_count": len(task_queue.running_tasks),
            "pending_count": len(task_queue.pending_tasks),
            "is_processing": task_worker_thread is not None and task_worker_thread.is_alive(),
        })
        return web.json_response(
            {
                "total_count": task_queue.total_count(),
                "done_count": task_queue.done_count(),
                "in_progress_count": len(task_queue.running_tasks),
                "pending_count": len(task_queue.pending_tasks),
                "is_processing": task_worker_thread is not None
                and task_worker_thread.is_alive(),
            }
        )
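The status endpoint above lends itself to simple polling from a client. A hedged sketch (base URL and client ID are assumptions, not part of this codebase):

```python
import asyncio

import aiohttp

BASE = "http://127.0.0.1:8188"  # assumption: default local ComfyUI address


async def wait_for_queue(client_id: str) -> None:
    """Poll /v2/manager/queue/status until this client's tasks are finished."""
    async with aiohttp.ClientSession() as session:
        while True:
            async with session.get(
                f"{BASE}/v2/manager/queue/status", params={"client_id": client_id}
            ) as resp:
                status = await resp.json()
            # total_count covers the client's pending + running tasks.
            if status["total_count"] == 0:
                break
            await asyncio.sleep(1)


asyncio.run(wait_for_queue("my-client"))  # hypothetical client ID
```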
async def _install_custom_node(json_data):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(
            status=403,
@@ -1396,14 +1574,14 @@ async def _install_custom_node(json_data):
    # apply security policy if not cnr node (nightly isn't regarded as cnr node)
    if risky_level is None:
        if git_url is not None:
            risky_level = await get_risky_level(git_url, json_data.get("pip", []))
            risky_level = await security_utils.get_risky_level(git_url, json_data.get("pip", []))
        else:
            return web.Response(
                status=404,
                text=f"Following node pack doesn't provide `nightly` version: ${git_url}",
            )

    if not is_allowed_security_level(risky_level):
    if not security_utils.is_allowed_security_level(risky_level):
        logging.error(SECURITY_MESSAGE_GENERAL)
        return web.Response(
            status=404,
@@ -1424,12 +1602,14 @@ async def _install_custom_node(json_data):

task_worker_thread: threading.Thread = None


@routes.get("/v2/manager/queue/start")
async def queue_start(request):
    with task_worker_lock:
        finalize_temp_queue_batch()
        return _queue_start()


def _queue_start():
    global task_worker_thread

@@ -1442,16 +1622,11 @@ def _queue_start():
    return web.Response(status=200)


@routes.get("/v2/manager/queue/start")
async def queue_start(request):
    _queue_start()
    # with task_worker_lock:
    #     finalize_temp_queue_batch()
    #     return _queue_start()
# Duplicate queue_start function removed - using the earlier one with proper implementation


async def _fix_custom_node(json_data):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_GENERAL)
        return web.Response(
            status=403,
@@ -1474,7 +1649,7 @@ async def _fix_custom_node(json_data):

@routes.post("/v2/customnode/install/git_url")
async def install_custom_node_git_url(request):
    if not is_allowed_security_level("high"):
    if not security_utils.is_allowed_security_level("high"):
        logging.error(SECURITY_MESSAGE_NORMAL_MINUS)
        return web.Response(status=403)

@@ -1494,7 +1669,7 @@ async def install_custom_node_git_url(request):

@routes.post("/v2/customnode/install/pip")
async def install_custom_node_pip(request):
    if not is_allowed_security_level("high"):
    if not security_utils.is_allowed_security_level("high"):
        logging.error(SECURITY_MESSAGE_NORMAL_MINUS)
        return web.Response(status=403)

@@ -1505,7 +1680,7 @@ async def install_custom_node_pip(request):


async def _uninstall_custom_node(json_data):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(
            status=403,
@@ -1528,7 +1703,7 @@ async def _uninstall_custom_node(json_data):


async def _update_custom_node(json_data):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(
            status=403,
@@ -1628,7 +1803,7 @@ async def install_model(request):


async def _install_model(json_data):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(
            status=403,
@@ -1646,7 +1821,7 @@ async def _install_model(json_data):

    if not json_data["filename"].endswith(
        ".safetensors"
    ) and not is_allowed_security_level("high"):
    ) and not security_utils.is_allowed_security_level("high"):
        models_json = await core.get_data_by_mode("cache", "model-list.json", "default")

        is_belongs_to_whitelist = False

@@ -1671,7 +1846,7 @@ async def _install_model(json_data):

@routes.get("/v2/manager/preview_method")
async def preview_method(request):
    if "value" in request.rel_url.query:
        set_preview_method(request.rel_url.query["value"])
        environment_utils.set_preview_method(request.rel_url.query["value"])
        core.write_config()
    else:
        return web.Response(
@@ -1684,7 +1859,7 @@ async def preview_method(request):

@routes.get("/v2/manager/db_mode")
async def db_mode(request):
    if "value" in request.rel_url.query:
        set_db_mode(request.rel_url.query["value"])
        environment_utils.set_db_mode(request.rel_url.query["value"])
        core.write_config()
    else:
        return web.Response(text=core.get_config()["db_mode"], status=200)

@@ -1695,7 +1870,7 @@ async def db_mode(request):

@routes.get("/v2/manager/policy/update")
async def update_policy(request):
    if "value" in request.rel_url.query:
        set_update_policy(request.rel_url.query["value"])
        environment_utils.set_update_policy(request.rel_url.query["value"])
        core.write_config()
    else:
        return web.Response(text=core.get_config()["update_policy"], status=200)

@@ -1728,7 +1903,7 @@ async def channel_url_list(request):

@routes.get("/v2/manager/reboot")
def restart(self):
    if not is_allowed_security_level("middle"):
    if not security_utils.is_allowed_security_level("middle"):
        logging.error(SECURITY_MESSAGE_MIDDLE_OR_BELOW)
        return web.Response(status=403)
1280  openapi.yaml
File diff suppressed because it is too large.