import torch
from torch import Tensor, nn

from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    ModulationOut,
)

# TODO: remove this in a few months
# These placeholders fix import errors in custom nodes that still import the old
# block classes, but will break if a custom node actually tries to use them.
SingleStreamBlock = None
DoubleStreamBlock = None


class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        # Take three consecutive entries along dim 1 of the modulation tensor:
        # shift, scale and gate, each of shape (batch, 1, dim).
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )


class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)

        # Stack of residual MLP blocks with RMSNorm pre-normalization
        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))

        x = self.out_proj(x)

        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        # Adaptive LayerNorm: shift + (1 + scale) * LayerNorm(x)
        x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
        x = self.linear(x)
        return x
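

# Minimal smoke-test sketch. Assumptions for illustration only: `torch.nn` is
# passed as the `operations` namespace (sufficient here because only `Linear`
# and `LayerNorm` are looked up on it), whereas ComfyUI normally passes one of
# the comfy.ops classes; all tensor sizes below are arbitrary.
if __name__ == "__main__":
    hidden = 64

    # Project a dummy conditioning input into a stack of modulation rows.
    approx = Approximator(in_dim=16, out_dim=hidden, hidden_dim=128, n_layers=2, operations=nn)
    mods = approx(torch.randn(1, 3, 16))  # -> (1, 3, hidden)

    # Split three consecutive rows into shift/scale/gate, each (1, 1, hidden).
    mod = ChromaModulationOut.from_offset(mods, offset=0)

    # Apply the final adaptive-norm projection to a dummy token sequence.
    last = LastLayer(hidden_size=hidden, patch_size=1, out_channels=8, operations=nn)
    out = last(torch.randn(1, 10, hidden), (mod.shift, mod.scale))
    print(out.shape)  # torch.Size([1, 10, 8])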