refactoring: use the same code for "scan task" and realtime DB population

This commit is contained in:
bigcat88 2025-08-25 13:31:56 +03:00
parent d7464e9e73
commit 09dabf95bc
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
7 changed files with 178 additions and 98 deletions

View File

@ -117,21 +117,25 @@ def upgrade() -> None:
{"name": "output", "tag_type": "system"}, {"name": "output", "tag_type": "system"},
# Core tags # Core tags
{"name": "checkpoint", "tag_type": "system"}, {"name": "configs", "tag_type": "system"},
{"name": "lora", "tag_type": "system"}, {"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"}, {"name": "vae", "tag_type": "system"},
{"name": "text-encoder", "tag_type": "system"}, {"name": "text_encoders", "tag_type": "system"},
{"name": "clip-vision", "tag_type": "system"}, {"name": "diffusion_models", "tag_type": "system"},
{"name": "embedding", "tag_type": "system"}, {"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"}, {"name": "controlnet", "tag_type": "system"},
{"name": "upscale", "tag_type": "system"},
{"name": "diffusion-model", "tag_type": "system"},
{"name": "hypernetwork", "tag_type": "system"},
{"name": "vae-approx", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"}, {"name": "gligen", "tag_type": "system"},
{"name": "style-model", "tag_type": "system"}, {"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"}, {"name": "photomaker", "tag_type": "system"},
{"name": "classifier", "tag_type": "system"}, {"name": "classifiers", "tag_type": "system"},
# Extra basic tags (used for vae_approx, ...)
{"name": "encoder", "tag_type": "system"}, {"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"},
], ],

99
app/_assets_helpers.py Normal file
View File

@ -0,0 +1,99 @@
import os
from pathlib import Path
from typing import Optional, Literal, Sequence
import folder_paths
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, normalize_tags([root_category, *parent_parts])
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]

View File

@ -1,7 +1,7 @@
import logging
import mimetypes import mimetypes
import os import os
from typing import Optional, Sequence from typing import Optional, Sequence
from pathlib import Path
from comfy.cli_args import args from comfy.cli_args import args
from comfy_api.internal import async_to_sync from comfy_api.internal import async_to_sync
@ -26,6 +26,7 @@ from .database.services import (
create_asset_info_for_existing_asset, create_asset_info_for_existing_asset,
) )
from .api import schemas_out from .api import schemas_out
from ._assets_helpers import get_name_and_tags_from_asset_path
async def asset_exists(*, asset_hash: str) -> bool: async def asset_exists(*, asset_hash: str) -> bool:
@ -33,16 +34,20 @@ async def asset_exists(*, asset_hash: str) -> bool:
return await asset_exists_by_hash(session, asset_hash=asset_hash) return await asset_exists_by_hash(session, asset_hash=asset_hash)
def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if not args.disable_model_processing: if not args.disable_model_processing:
p = Path(file_name) if tags is None:
dir_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] tags = []
async_to_sync.AsyncToSyncConverter.run_async_in_thread( try:
add_local_asset, asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
tags=list(dict.fromkeys([*tags, *dir_parts])), async_to_sync.AsyncToSyncConverter.run_async_in_thread(
file_name=p.name, add_local_asset,
file_path=file_path, tags=list(dict.fromkeys([*path_tags, *tags])),
) file_name=asset_name,
file_path=file_path,
)
except ValueError:
logging.exception("Cant parse '%s' as an asset file path.", file_path)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:

View File

@ -7,10 +7,11 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Literal, Optional, Sequence from typing import Literal, Optional, Sequence
import folder_paths
from . import assets_manager from . import assets_manager
from .api import schemas_out from .api import schemas_out
from ._assets_helpers import get_comfy_models_folders
import folder_paths
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
@ -36,7 +37,7 @@ class ScanProgress:
errors: int = 0 errors: int = 0
last_error: Optional[str] = None last_error: Optional[str] = None
# Optional details for diagnostics # Optional details for diagnostics (e.g., files per bucket)
details: dict[str, int] = field(default_factory=dict) details: dict[str, int] = field(default_factory=dict)
@ -49,8 +50,6 @@ def _new_scan_id(root: RootType) -> str:
def current_statuses() -> schemas_out.AssetScanStatusResponse: def current_statuses() -> schemas_out.AssetScanStatusResponse:
# make shallow copies to avoid external mutation
states = [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
return schemas_out.AssetScanStatusResponse( return schemas_out.AssetScanStatusResponse(
scans=[ scans=[
schemas_out.AssetScanStatus( schemas_out.AssetScanStatus(
@ -65,7 +64,7 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
errors=s.errors, errors=s.errors,
last_error=s.last_error, last_error=s.last_error,
) )
for s in states for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
] ]
) )
@ -94,15 +93,12 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes
results: list[ScanProgress] = [] results: list[ScanProgress] = []
for root in normalized: for root in normalized:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
# already running; return the live progress object
results.append(PROGRESS_BY_ROOT[root]) results.append(PROGRESS_BY_ROOT[root])
continue continue
# Create fresh progress
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog PROGRESS_BY_ROOT[root] = prog
# Start task
task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}") task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
RUNNING_TASKS[root] = task RUNNING_TASKS[root] = task
results.append(prog) results.append(prog)
@ -151,24 +147,21 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
prog.last_error = str(exc) prog.last_error = str(exc)
finally: finally:
prog.finished_at = time.time() prog.finished_at = time.time()
# Drop the task entry if it's the current one
t = RUNNING_TASKS.get(root) t = RUNNING_TASKS.get(root)
if t and t.done(): if t and t.done():
RUNNING_TASKS.pop(root, None) RUNNING_TASKS.pop(root, None)
async def _scan_models(prog: ScanProgress) -> None: async def _scan_models(prog: ScanProgress) -> None:
# Iterate all folder_names whose base paths lie under the Comfy 'models' directory """
models_root = os.path.abspath(os.path.join(folder_paths.base_path, "models")) Scan all configured model buckets from folder_paths.folder_names_and_paths,
restricted to entries whose base paths lie under folder_paths.models_dir
(per get_comfy_models_folders). We trust those mappings and do not try to
infer anything else here.
"""
targets: list[tuple[str, list[str]]] = get_comfy_models_folders()
# Build list of (folder_name, base_paths[]) that are configured for this category. plans: list[str] = [] # absolute file paths to ingest
# If any path for the category lies under 'models', include the category.
targets: list[tuple[str, list[str]]] = []
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
plans: list[tuple[str, str]] = [] # (abs_path, file_name_for_tags)
per_bucket: dict[str, int] = {} per_bucket: dict[str, int] = {}
for folder_name, bases in targets: for folder_name, bases in targets:
@ -198,13 +191,12 @@ async def _scan_models(prog: ScanProgress) -> None:
try: try:
if not os.path.getsize(abs_path): if not os.path.getsize(abs_path):
continue continue # skip empty files
except OSError as e: except OSError as e:
LOGGER.warning("Could not stat %s: %s skipping", abs_path, e) LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
continue continue
file_name_for_tags = os.path.join(folder_name, rel_path) plans.append(abs_path)
plans.append((abs_path, file_name_for_tags))
count_valid += 1 count_valid += 1
if count_valid: if count_valid:
@ -221,16 +213,12 @@ async def _scan_models(prog: ScanProgress) -> None:
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
for abs_path, name_for_tags in plans: for abs_path in plans:
async def worker(fp_abs: str = abs_path, fn_rel: str = name_for_tags): async def worker(fp_abs: str = abs_path):
try: try:
# Offload sync ingestion into a thread # Offload sync ingestion into a thread; populate_db_with_asset
await asyncio.to_thread( # derives name and tags from the path using _assets_helpers.
assets_manager.populate_db_with_asset, await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
["models"],
fn_rel,
fp_abs,
)
except Exception as e: except Exception as e:
prog.errors += 1 prog.errors += 1
prog.last_error = str(e) prog.last_error = str(e)
@ -260,7 +248,10 @@ def _count_files_in_tree(base_abs: str) -> int:
async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None: async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
# Guard: base_dir must be a directory """
Generic scanner for input/output roots. We pass only the absolute path to
populate_db_with_asset and let it derive the relative name and tags.
"""
base_abs = os.path.abspath(base_dir) base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs): if not os.path.isdir(base_abs):
LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs) LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
@ -272,24 +263,27 @@ async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames: for name in filenames:
rel = os.path.relpath(os.path.join(dirpath, name), base_abs) abs_path = os.path.abspath(os.path.join(dirpath, name))
abs_path = os.path.join(base_abs, rel)
# Safety: ensure within base # Safety: ensure within base
try: try:
if os.path.commonpath([os.path.abspath(abs_path), base_abs]) != base_abs: if os.path.commonpath([abs_path, base_abs]) != base_abs:
LOGGER.warning("Skipping path outside root %s: %s", root, abs_path) LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
continue continue
except ValueError: except ValueError:
continue continue
async def worker(fp_abs: str = abs_path, fn_rel: str = rel): # Skip empty files and handle stat errors
try:
if not os.path.getsize(abs_path):
continue
except OSError as e:
LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
continue
async def worker(fp_abs: str = abs_path):
try: try:
await asyncio.to_thread( await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
assets_manager.populate_db_with_asset,
[root],
fn_rel,
fp_abs,
)
except Exception as e: except Exception as e:
prog.errors += 1 prog.errors += 1
prog.last_error = str(e) prog.last_error = str(e)

View File

@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
from .timeutil import utcnow from .timeutil import utcnow
from .._assets_helpers import normalize_tags
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
@ -471,7 +471,7 @@ async def set_asset_info_tags(
Replace the tag set on an AssetInfo with `tags`. Idempotent. Replace the tag set on an AssetInfo with `tags`. Idempotent.
Creates missing tag names as 'user'. Creates missing tag names as 'user'.
""" """
desired = _normalize_tags(tags) desired = normalize_tags(tags)
# current links # current links
current = set( current = set(
@ -691,7 +691,7 @@ async def add_tags_to_asset_info(
if not info: if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found") raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = _normalize_tags(tags) norm = normalize_tags(tags)
if not norm: if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id) total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total} return {"added": [], "already_present": [], "total_tags": total}
@ -753,7 +753,7 @@ async def remove_tags_from_asset_info(
if not info: if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found") raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = _normalize_tags(tags) norm = normalize_tags(tags)
if not norm: if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id) total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total} return {"removed": [], "not_present": [], "total_tags": total}
@ -784,12 +784,8 @@ async def remove_tags_from_asset_info(
return {"removed": to_remove, "not_present": not_present, "total_tags": total} return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = _normalize_tags(list(names)) wanted = normalize_tags(list(names))
if not wanted: if not wanted:
return [] return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
@ -808,8 +804,8 @@ def _apply_tag_filters(
exclude_tags: Optional[Sequence[str]], exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select: ) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present.""" """include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = _normalize_tags(include_tags) include_tags = normalize_tags(include_tags)
exclude_tags = _normalize_tags(exclude_tags) exclude_tags = normalize_tags(exclude_tags)
if include_tags: if include_tags:
for tag_name in include_tags: for tag_name in include_tags:

View File

@ -29,6 +29,7 @@ import itertools
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
from einops import rearrange from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
from app.assets_manager import populate_db_with_asset
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap DISABLE_MMAP = args.disable_mmap
@ -102,6 +103,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
else: else:
sd = pl_sd sd = pl_sd
populate_db_with_asset(ckpt)
return (sd, metadata) if return_metadata else sd return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None): def save_torch_file(sd, ckpt, metadata=None):

View File

@ -31,7 +31,6 @@ from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions from comfy_api.version_list import supported_versions
from comfy_api.latest import io, ComfyExtension from comfy_api.latest import io, ComfyExtension
from app.assets_manager import populate_db_with_asset
import comfy.clip_vision import comfy.clip_vision
@ -555,9 +554,7 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name): def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
return out
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
@classmethod @classmethod
@ -579,7 +576,6 @@ class CheckpointLoaderSimple:
def load_checkpoint(self, ckpt_name): def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
return out[:3] return out[:3]
class DiffusersLoader: class DiffusersLoader:
@ -622,7 +618,6 @@ class unCLIPCheckpointLoader:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
return out return out
class CLIPSetLastLayer: class CLIPSetLastLayer:
@ -681,7 +676,6 @@ class LoraLoader:
self.loaded_lora = (lora_path, lora) self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
populate_db_with_asset(["models", "lora"], lora_name, lora_path)
return (model_lora, clip_lora) return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader): class LoraLoaderModelOnly(LoraLoader):
@ -746,15 +740,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
encoder_path = folder_paths.get_full_path_or_raise("vae_approx", encoder) enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
populate_db_with_asset(["models", "vae-approx", "encoder"], name, encoder_path)
enc = comfy.utils.load_torch_file(encoder_path)
for k in enc: for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k] sd["taesd_encoder.{}".format(k)] = enc[k]
decoder_path = folder_paths.get_full_path_or_raise("vae_approx", decoder) dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
populate_db_with_asset(["models", "vae-approx", "decoder"], name, decoder_path)
dec = comfy.utils.load_torch_file(decoder_path)
for k in dec: for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k] sd["taesd_decoder.{}".format(k)] = dec[k]
@ -787,7 +777,6 @@ class VAELoader:
else: else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
populate_db_with_asset(["models", "vae"], vae_name, vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid() vae.throw_exception_if_invalid()
return (vae,) return (vae,)
@ -807,7 +796,6 @@ class ControlNetLoader:
controlnet = comfy.controlnet.load_controlnet(controlnet_path) controlnet = comfy.controlnet.load_controlnet(controlnet_path)
if controlnet is None: if controlnet is None:
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.") raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path)
return (controlnet,) return (controlnet,)
class DiffControlNetLoader: class DiffControlNetLoader:
@ -824,7 +812,6 @@ class DiffControlNetLoader:
def load_controlnet(self, model, control_net_name): def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model) controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path)
return (controlnet,) return (controlnet,)
@ -932,7 +919,6 @@ class UNETLoader:
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
populate_db_with_asset(["models", "diffusion-model"], unet_name, unet_path)
return (model,) return (model,)
class CLIPLoader: class CLIPLoader:
@ -960,7 +946,6 @@ class CLIPLoader:
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
populate_db_with_asset(["models", "text-encoder"], clip_name, clip_path)
return (clip,) return (clip,)
class DualCLIPLoader: class DualCLIPLoader:
@ -991,8 +976,6 @@ class DualCLIPLoader:
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
populate_db_with_asset(["models", "text-encoder"], clip_name1, clip_path1)
populate_db_with_asset(["models", "text-encoder"], clip_name2, clip_path2)
return (clip,) return (clip,)
class CLIPVisionLoader: class CLIPVisionLoader:
@ -1010,7 +993,6 @@ class CLIPVisionLoader:
clip_vision = comfy.clip_vision.load(clip_path) clip_vision = comfy.clip_vision.load(clip_path)
if clip_vision is None: if clip_vision is None:
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
populate_db_with_asset(["models", "clip-vision"], clip_name, clip_path)
return (clip_vision,) return (clip_vision,)
class CLIPVisionEncode: class CLIPVisionEncode:
@ -1045,7 +1027,6 @@ class StyleModelLoader:
def load_style_model(self, style_model_name): def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name) style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path) style_model = comfy.sd.load_style_model(style_model_path)
populate_db_with_asset(["models", "style-model"], style_model_name, style_model_path)
return (style_model,) return (style_model,)
@ -1143,7 +1124,6 @@ class GLIGENLoader:
def load_gligen(self, gligen_name): def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name) gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path) gligen = comfy.sd.load_gligen(gligen_path)
populate_db_with_asset(["models", "gligen"], gligen_name, gligen_path)
return (gligen,) return (gligen,)
class GLIGENTextBoxApply: class GLIGENTextBoxApply: