refactoring: use the same code for "scan task" and realtime DB population

parent: d7464e9e73
commit: 09dabf95bc
@@ -117,21 +117,25 @@ def upgrade() -> None:
         {"name": "output", "tag_type": "system"},

         # Core tags
-        {"name": "checkpoint", "tag_type": "system"},
-        {"name": "lora", "tag_type": "system"},
+        {"name": "configs", "tag_type": "system"},
+        {"name": "checkpoints", "tag_type": "system"},
+        {"name": "loras", "tag_type": "system"},
         {"name": "vae", "tag_type": "system"},
-        {"name": "text-encoder", "tag_type": "system"},
-        {"name": "clip-vision", "tag_type": "system"},
-        {"name": "embedding", "tag_type": "system"},
+        {"name": "text_encoders", "tag_type": "system"},
+        {"name": "diffusion_models", "tag_type": "system"},
+        {"name": "clip_vision", "tag_type": "system"},
+        {"name": "style_models", "tag_type": "system"},
+        {"name": "embeddings", "tag_type": "system"},
+        {"name": "diffusers", "tag_type": "system"},
+        {"name": "vae_approx", "tag_type": "system"},
         {"name": "controlnet", "tag_type": "system"},
-        {"name": "upscale", "tag_type": "system"},
-        {"name": "diffusion-model", "tag_type": "system"},
-        {"name": "hypernetwork", "tag_type": "system"},
-        {"name": "vae-approx", "tag_type": "system"},
         {"name": "gligen", "tag_type": "system"},
-        {"name": "style-model", "tag_type": "system"},
+        {"name": "upscale_models", "tag_type": "system"},
+        {"name": "hypernetworks", "tag_type": "system"},
         {"name": "photomaker", "tag_type": "system"},
-        {"name": "classifier", "tag_type": "system"},
+        {"name": "classifiers", "tag_type": "system"},

         # Extra basic tags (used for vae_approx, ...)
         {"name": "encoder", "tag_type": "system"},
         {"name": "decoder", "tag_type": "system"},
     ],
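The seeded system tags now track the folder names ComfyUI already configures, so tags derived from a file's on-disk location resolve to existing rows. A minimal sanity check of that assumption, against a stock folder_paths configuration:

import folder_paths  # ComfyUI's folder registry

# The bucket-style tag names above are expected to mirror
# folder_names_and_paths keys, so path-derived tags such as
# "checkpoints" or "clip_vision" match seeded tag rows.
for name in ("checkpoints", "loras", "vae", "text_encoders", "clip_vision"):
    assert name in folder_paths.folder_names_and_paths, name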
app/_assets_helpers.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import os
from pathlib import Path
from typing import Optional, Literal, Sequence

import folder_paths


def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
    """Build a list of (folder_name, base_paths[]) categories that are configured for model locations.

    We trust `folder_paths.folder_names_and_paths` and include a category if
    *any* of its base paths lies under the Comfy `models_dir`.
    """
    targets: list[tuple[str, list[str]]] = []
    models_root = os.path.abspath(folder_paths.models_dir)
    for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
        if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
            targets.append((name, paths))
    return targets


def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
    """Given an absolute or relative file path, determine which root category the path belongs to:
    - 'input' if the file resides under `folder_paths.get_input_directory()`
    - 'output' if the file resides under `folder_paths.get_output_directory()`
    - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`

    Returns:
        (root_category, relative_path_inside_that_root)
        For 'models', the relative path is prefixed with the category name:
        e.g. ('models', 'vae/test/sub/ae.safetensors')

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    fp_abs = os.path.abspath(file_path)

    def _is_within(child: str, parent: str) -> bool:
        try:
            return os.path.commonpath([child, parent]) == parent
        except Exception:
            return False

    def _rel(child: str, parent: str) -> str:
        return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)

    # 1) input
    input_base = os.path.abspath(folder_paths.get_input_directory())
    if _is_within(fp_abs, input_base):
        return "input", _rel(fp_abs, input_base)

    # 2) output
    output_base = os.path.abspath(folder_paths.get_output_directory())
    if _is_within(fp_abs, output_base):
        return "output", _rel(fp_abs, output_base)

    # 3) models (check deepest matching base to avoid ambiguity)
    best: Optional[tuple[int, str, str]] = None  # (base_len, bucket, rel_inside_bucket)
    for bucket, bases in get_comfy_models_folders():
        for b in bases:
            base_abs = os.path.abspath(b)
            if not _is_within(fp_abs, base_abs):
                continue
            cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
            if best is None or cand[0] > best[0]:
                best = cand

    if best is not None:
        _, bucket, rel_inside = best
        combined = os.path.join(bucket, rel_inside)
        return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)

    raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")


def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
    """Return a tuple (name, tags) derived from a filesystem path.

    Semantics:
    - Root category is determined by `get_relative_to_root_category_path_of_asset`.
    - The returned `name` is the base filename with extension from the relative path.
    - The returned `tags` are:
        [root_category] + parent folders of the relative path (in order)
      For 'models', this means:
        file '/.../ModelsDir/vae/test_tag/ae.safetensors'
        -> root_category='models', some_path='vae/test_tag/ae.safetensors'
        -> name='ae.safetensors', tags=['models', 'vae', 'test_tag']

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
    p = Path(some_path)
    parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
    return p.name, normalize_tags([root_category, *parent_parts])


def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
    return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
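For orientation, a usage sketch of the two public helpers. The paths are hypothetical and assume a default layout where the `vae` bucket lives under `folder_paths.models_dir`; the helpers do pure path math and never touch the filesystem beyond `abspath`:

from app._assets_helpers import get_name_and_tags_from_asset_path, normalize_tags

# Assuming <models_dir> == "/ComfyUI/models":
# get_name_and_tags_from_asset_path("/ComfyUI/models/vae/test_tag/ae.safetensors")
#   -> ("ae.safetensors", ["models", "vae", "test_tag"])

# normalize_tags strips, lower-cases, and drops empty entries (order preserved):
assert normalize_tags(["  Models ", "", "VAE "]) == ["models", "vae"]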
@@ -1,7 +1,7 @@
 import logging
 import mimetypes
 import os
-from pathlib import Path
+from typing import Optional, Sequence

 from comfy.cli_args import args
 from comfy_api.internal import async_to_sync
@@ -26,6 +26,7 @@ from .database.services import (
     create_asset_info_for_existing_asset,
 )
 from .api import schemas_out
+from ._assets_helpers import get_name_and_tags_from_asset_path


 async def asset_exists(*, asset_hash: str) -> bool:
@@ -33,16 +34,20 @@ async def asset_exists(*, asset_hash: str) -> bool:
         return await asset_exists_by_hash(session, asset_hash=asset_hash)


-def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None:
+def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
     if not args.disable_model_processing:
-        p = Path(file_name)
-        dir_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
-        async_to_sync.AsyncToSyncConverter.run_async_in_thread(
-            add_local_asset,
-            tags=list(dict.fromkeys([*tags, *dir_parts])),
-            file_name=p.name,
-            file_path=file_path,
-        )
+        if tags is None:
+            tags = []
+        try:
+            asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
+            async_to_sync.AsyncToSyncConverter.run_async_in_thread(
+                add_local_asset,
+                tags=list(dict.fromkeys([*path_tags, *tags])),
+                file_name=asset_name,
+                file_path=file_path,
+            )
+        except ValueError:
+            logging.exception("Can't parse '%s' as an asset file path.", file_path)


 async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
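The path is now the single source of truth: callers pass only `file_path`, any extra `tags` are appended after the path-derived ones, and `dict.fromkeys` deduplicates while preserving order. A self-contained sketch of that merge (tag values hypothetical):

path_tags = ["models", "vae", "test_tag"]  # derived from the file path
extra_tags = ["vae", "favorite"]           # hypothetical caller-supplied tags
merged = list(dict.fromkeys([*path_tags, *extra_tags]))
assert merged == ["models", "vae", "test_tag", "favorite"]  # "vae" not duplicated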
@@ -7,10 +7,11 @@ from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Literal, Optional, Sequence

+import folder_paths
+
 from . import assets_manager
 from .api import schemas_out
-
-import folder_paths
+from ._assets_helpers import get_comfy_models_folders

 LOGGER = logging.getLogger(__name__)
@@ -36,7 +37,7 @@ class ScanProgress:
     errors: int = 0
     last_error: Optional[str] = None

-    # Optional details for diagnostics
+    # Optional details for diagnostics (e.g., files per bucket)
     details: dict[str, int] = field(default_factory=dict)
@@ -49,8 +50,6 @@ def _new_scan_id(root: RootType) -> str:


 def current_statuses() -> schemas_out.AssetScanStatusResponse:
-    # make shallow copies to avoid external mutation
-    states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
     return schemas_out.AssetScanStatusResponse(
         scans=[
             schemas_out.AssetScanStatus(
@@ -65,7 +64,7 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
                 errors=s.errors,
                 last_error=s.last_error,
             )
-            for s in states
+            for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
         ]
     )
@@ -94,15 +93,12 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
     results: list[ScanProgress] = []
     for root in normalized:
         if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
-            # already running; return the live progress object
             results.append(PROGRESS_BY_ROOT[root])
             continue

-        # Create fresh progress
         prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
         PROGRESS_BY_ROOT[root] = prog

-        # Start task
         task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
         RUNNING_TASKS[root] = task
         results.append(prog)
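Scheduling is idempotent per root: re-requesting a root with a live task returns its existing progress object instead of starting a second scan. A hedged usage sketch (the `assets_scanner` module path is assumed; the response shape follows `current_statuses` above):

import asyncio
from app import assets_scanner  # assumed import path for the scanner shown here

async def main():
    # Kick off scans; roots that are already running are returned as-is.
    await assets_scanner.schedule_scans(["models", "input", "output"])
    status = assets_scanner.current_statuses()  # poll progress at any time
    print(len(status.scans), "scan(s) tracked")

asyncio.run(main())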
@@ -151,24 +147,21 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
         prog.last_error = str(exc)
     finally:
         prog.finished_at = time.time()
-        # Drop the task entry if it's the current one
         t = RUNNING_TASKS.get(root)
         if t and t.done():
             RUNNING_TASKS.pop(root, None)


 async def _scan_models(prog: ScanProgress) -> None:
-    # Iterate all folder_names whose base paths lie under the Comfy 'models' directory
-    models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models"))
+    """
+    Scan all configured model buckets from folder_paths.folder_names_and_paths,
+    restricted to entries whose base paths lie under folder_paths.models_dir
+    (per get_comfy_models_folders). We trust those mappings and do not try to
+    infer anything else here.
+    """
+    targets: list[tuple[str, list[str]]] = get_comfy_models_folders()

-    # Build list of (folder_name, base_paths[]) that are configured for this category.
-    # If any path for the category lies under 'models', include the category.
-    targets: list[tuple[str, list[str]]] = []
-    for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
-        if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
-            targets.append((name, paths))
-
-    plans: list[tuple[str, str]] = []  # (abs_path, file_name_for_tags)
+    plans: list[str] = []  # absolute file paths to ingest
     per_bucket: dict[str, int] = {}

     for folder_name, bases in targets:
@@ -198,13 +191,12 @@ async def _scan_models(prog: ScanProgress) -> None:

             try:
                 if not os.path.getsize(abs_path):
-                    continue
+                    continue  # skip empty files
             except OSError as e:
                 LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e)
                 continue

-            file_name_for_tags = os.path.join(folder_name, rel_path)
-            plans.append((abs_path, file_name_for_tags))
+            plans.append(abs_path)
             count_valid += 1

         if count_valid:
@@ -221,16 +213,12 @@ async def _scan_models(prog: ScanProgress) -> None:
     sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
     tasks: list[asyncio.Task] = []

-    for abs_path, name_for_tags in plans:
-        async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags):
+    for abs_path in plans:
+        async def worker(fp_abs: str = abs_path):
             try:
-                # Offload sync ingestion into a thread
-                await asyncio.to_thread(
-                    assets_manager.populate_db_with_asset,
-                    ["models"],
-                    fn_rel,
-                    fp_abs,
-                )
+                # Offload sync ingestion into a thread; populate_db_with_asset
+                # derives name and tags from the path using _assets_helpers.
+                await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
             except Exception as e:
                 prog.errors += 1
                 prog.last_error = str(e)
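The fan-out caps concurrency with a semaphore and pushes the synchronous DB work onto threads. A generic, runnable sketch of the same pattern; the `async with sem:` acquisition sits outside the visible hunk, so its exact placement here is assumed:

import asyncio

def ingest(path: str) -> None:
    print("ingest", path)  # stand-in for assets_manager.populate_db_with_asset

async def ingest_all(paths: list[str], limit: int = 4) -> None:
    sem = asyncio.Semaphore(limit)  # cap simultaneous thread offloads

    async def worker(fp_abs: str) -> None:
        async with sem:
            await asyncio.to_thread(ingest, fp_abs)  # keep sync work off the event loop

    await asyncio.gather(*(worker(p) for p in paths))

asyncio.run(ingest_all(["/a.safetensors", "/b.safetensors"]))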
@@ -260,7 +248,10 @@ def _count_files_in_tree(base_abs: str) -> int:


 async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
-    # Guard: base_dir must be a directory
+    """
+    Generic scanner for input/output roots. We pass only the absolute path to
+    populate_db_with_asset and let it derive the relative name and tags.
+    """
     base_abs = os.path.abspath(base_dir)
     if not os.path.isdir(base_abs):
         LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
@@ -272,24 +263,27 @@ async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
     tasks: list[asyncio.Task] = []
     for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
         for name in filenames:
-            rel = os.path.relpath(os.path.join(dirpath, name), base_abs)
-            abs_path = os.path.join(base_abs, rel)
+            abs_path = os.path.abspath(os.path.join(dirpath, name))

             # Safety: ensure within base
             try:
-                if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs:
+                if os.path.commonpath([abs_path, base_abs]) != base_abs:
                     LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
                     continue
             except ValueError:
                 continue

-            async def worker(fp_abs: str = abs_path, fn_rel: str = rel):
+            # Skip empty files and handle stat errors
+            try:
+                if not os.path.getsize(abs_path):
+                    continue
+            except OSError as e:
+                LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e)
+                continue
+
+            async def worker(fp_abs: str = abs_path):
                 try:
-                    await asyncio.to_thread(
-                        assets_manager.populate_db_with_asset,
-                        [root],
-                        fn_rel,
-                        fp_abs,
-                    )
+                    await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
                 except Exception as e:
                     prog.errors += 1
                     prog.last_error = str(e)
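The containment guard uses `os.path.commonpath` rather than a string-prefix test, which also rejects `..` escapes once paths are absolutized; the `ValueError` branch covers paths that share no common root, such as different drives on Windows. A standalone sketch of the same check (POSIX paths):

import os

def is_within(child: str, parent: str) -> bool:
    child, parent = os.path.abspath(child), os.path.abspath(parent)
    try:
        return os.path.commonpath([child, parent]) == parent
    except ValueError:  # e.g. different drives on Windows
        return False

assert is_within("/data/out/run1/img.png", "/data/out")
assert not is_within("/data/out/../secret.txt", "/data/out")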
@@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError

 from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
 from .timeutil import utcnow
+from .._assets_helpers import normalize_tags

 async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
@@ -471,7 +471,7 @@ async def set_asset_info_tags(
     Replace the tag set on an AssetInfo with `tags`. Idempotent.
     Creates missing tag names as 'user'.
     """
-    desired = _normalize_tags(tags)
+    desired = normalize_tags(tags)

     # current links
     current = set(
@@ -691,7 +691,7 @@ async def add_tags_to_asset_info(
     if not info:
         raise ValueError(f"AssetInfo {asset_info_id} not found")

-    norm = _normalize_tags(tags)
+    norm = normalize_tags(tags)
     if not norm:
         total = await get_asset_tags(session, asset_info_id=asset_info_id)
         return {"added": [], "already_present": [], "total_tags": total}
@@ -753,7 +753,7 @@ async def remove_tags_from_asset_info(
     if not info:
         raise ValueError(f"AssetInfo {asset_info_id} not found")

-    norm = _normalize_tags(tags)
+    norm = normalize_tags(tags)
     if not norm:
         total = await get_asset_tags(session, asset_info_id=asset_info_id)
         return {"removed": [], "not_present": [], "total_tags": total}
@@ -784,12 +784,8 @@ async def remove_tags_from_asset_info(
     return {"removed": to_remove, "not_present": not_present, "total_tags": total}


-def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
-    return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
-
-
 async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
-    wanted = _normalize_tags(list(names))
+    wanted = normalize_tags(list(names))
     if not wanted:
         return []
     existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
@@ -808,8 +804,8 @@ def _apply_tag_filters(
     exclude_tags: Optional[Sequence[str]],
 ) -> sa.sql.Select:
     """include_tags: every tag must be present; exclude_tags: none may be present."""
-    include_tags = _normalize_tags(include_tags)
-    exclude_tags = _normalize_tags(exclude_tags)
+    include_tags = normalize_tags(include_tags)
+    exclude_tags = normalize_tags(exclude_tags)

     if include_tags:
         for tag_name in include_tags:
@@ -29,6 +29,7 @@ import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
 from comfy.cli_args import args
+from app.assets_manager import populate_db_with_asset

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -102,6 +103,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         else:
             sd = pl_sd

+    populate_db_with_asset(ckpt)
     return (sd, metadata) if return_metadata else sd

 def save_torch_file(sd, ckpt, metadata=None):
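With this hook, realtime ingestion shares the scan-task code path: every file that passes through `load_torch_file` is registered from its path alone. A sketch of the equivalent direct call, with a hypothetical path and layout:

from app.assets_manager import populate_db_with_asset

# Name and tags are derived from the location under models_dir:
populate_db_with_asset("/ComfyUI/models/checkpoints/sdxl/base.safetensors")
# -> name "base.safetensors", tags ["models", "checkpoints", "sdxl"] (hypothetical)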
nodes.py (26 lines changed)
@@ -31,7 +31,6 @@ from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
 from comfy_api.internal import register_versions, ComfyAPIWithVersion
 from comfy_api.version_list import supported_versions
 from comfy_api.latest import io, ComfyExtension
-from app.assets_manager import populate_db_with_asset

 import comfy.clip_vision
@@ -555,9 +554,7 @@ class CheckpointLoader:
     def load_checkpoint(self, config_name, ckpt_name):
         config_path = folder_paths.get_full_path("configs", config_name)
         ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
-        out = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
-        populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
-        return out
+        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))

 class CheckpointLoaderSimple:
     @classmethod
@@ -579,7 +576,6 @@ class CheckpointLoaderSimple:
     def load_checkpoint(self, ckpt_name):
         ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
-        populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
         return out[:3]

 class DiffusersLoader:
@@ -622,7 +618,6 @@ class unCLIPCheckpointLoader:
     def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
         ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
-        populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
         return out

 class CLIPSetLastLayer:
@@ -681,7 +676,6 @@ class LoraLoader:
             self.loaded_lora = (lora_path, lora)

         model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
-        populate_db_with_asset(["models", "lora"], lora_name, lora_path)
         return (model_lora, clip_lora)

 class LoraLoaderModelOnly(LoraLoader):
@@ -746,15 +740,11 @@ class VAELoader:
         encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
         decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

-        encoder_path = folder_paths.get_full_path_or_raise("vae_approx", encoder)
-        populate_db_with_asset(["models", "vae-approx", "encoder"], name, encoder_path)
-        enc = comfy.utils.load_torch_file(encoder_path)
+        enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
         for k in enc:
             sd["taesd_encoder.{}".format(k)] = enc[k]

-        decoder_path = folder_paths.get_full_path_or_raise("vae_approx", decoder)
-        populate_db_with_asset(["models", "vae-approx", "decoder"], name, decoder_path)
-        dec = comfy.utils.load_torch_file(decoder_path)
+        dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
         for k in dec:
             sd["taesd_decoder.{}".format(k)] = dec[k]
@@ -787,7 +777,6 @@ class VAELoader:
         else:
             vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
             sd = comfy.utils.load_torch_file(vae_path)
-            populate_db_with_asset(["models", "vae"], vae_name, vae_path)
         vae = comfy.sd.VAE(sd=sd)
         vae.throw_exception_if_invalid()
         return (vae,)
@@ -807,7 +796,6 @@ class ControlNetLoader:
         controlnet = comfy.controlnet.load_controlnet(controlnet_path)
         if controlnet is None:
             raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
-        populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path)
         return (controlnet,)

 class DiffControlNetLoader:
@@ -824,7 +812,6 @@ class DiffControlNetLoader:
     def load_controlnet(self, model, control_net_name):
         controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
         controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
-        populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path)
         return (controlnet,)
@@ -932,7 +919,6 @@ class UNETLoader:

         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
-        populate_db_with_asset(["models", "diffusion-model"], unet_name, unet_path)
         return (model,)

 class CLIPLoader:
@@ -960,7 +946,6 @@ class CLIPLoader:

         clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
-        populate_db_with_asset(["models", "text-encoder"], clip_name, clip_path)
         return (clip,)

 class DualCLIPLoader:
@@ -991,8 +976,6 @@ class DualCLIPLoader:
             model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")

         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
-        populate_db_with_asset(["models", "text-encoder"], clip_name1, clip_path1)
-        populate_db_with_asset(["models", "text-encoder"], clip_name2, clip_path2)
         return (clip,)

 class CLIPVisionLoader:
@@ -1010,7 +993,6 @@ class CLIPVisionLoader:
         clip_vision = comfy.clip_vision.load(clip_path)
         if clip_vision is None:
             raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
-        populate_db_with_asset(["models", "clip-vision"], clip_name, clip_path)
         return (clip_vision,)

 class CLIPVisionEncode:
@@ -1045,7 +1027,6 @@ class StyleModelLoader:
     def load_style_model(self, style_model_name):
         style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
         style_model = comfy.sd.load_style_model(style_model_path)
-        populate_db_with_asset(["models", "style-model"], style_model_name, style_model_path)
         return (style_model,)
@@ -1143,7 +1124,6 @@ class GLIGENLoader:
    def load_gligen(self, gligen_name):
        gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
        gligen = comfy.sd.load_gligen(gligen_path)
-       populate_db_with_asset(["models", "gligen"], gligen_name, gligen_path)
        return (gligen,)

 class GLIGENTextBoxApply: