mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI
add chroma-radiance-x0 mode (#11197)
commit b9fb542703 (parent cabc4d351f)
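
This commit adds an optional x0-prediction mode to the ComfyUI Chroma Radiance model: a `use_x0` flag on `ChromaRadianceParams`, an empty `__x0__` marker buffer registered when the flag is set, an `_apply_x0_residual` helper that turns the x0 prediction back into the velocity-style output the rest of the pipeline expects, and automatic detection of the mode from checkpoint keys in `detect_unet_config`.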
@@ -37,7 +37,8 @@ class ChromaRadianceParams(ChromaParams):
     nerf_final_head_type: str
     # None means use the same dtype as the model.
     nerf_embedder_dtype: Optional[torch.dtype]
+    use_x0: bool
 
 
 class ChromaRadiance(Chroma):
     """
@@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
         self.skip_dit = []
         self.lite = False
 
+        if params.use_x0:
+            self.register_buffer("__x0__", torch.tensor([]))
+
     @property
     def _nerf_final_layer(self) -> nn.Module:
         if self.params.nerf_final_head_type == "linear":
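
Registering an empty tensor here is a persistence trick rather than real state: `register_buffer` binds the name to the module instance and, because buffers are persistent by default, the `__x0__` key also lands in the state dict and survives a save/load round trip. A minimal sketch of that behavior, using a hypothetical `Marker` module that is not part of the commit:

import torch
from torch import nn

class Marker(nn.Module):
    def __init__(self, use_x0: bool):
        super().__init__()
        if use_x0:
            # Empty tensor: carries no weights, but the "__x0__" key
            # still appears in the state dict and on the instance.
            self.register_buffer("__x0__", torch.tensor([]))

m = Marker(use_x0=True)
print("__x0__" in m.state_dict())  # True -> detectable in saved checkpoints
print(hasattr(m, "__x0__"))        # True -> cheap runtime check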
@@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
         params_dict |= overrides
         return params.__class__(**params_dict)
 
+    def _apply_x0_residual(self, predicted, noisy, timesteps):
+
+        # non zero during training to prevent 0 div
+        eps = 0.0
+        return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
+
     def _forward(
         self,
         x: Tensor,
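
The helper converts an x0 (denoised-image) prediction into the velocity output the caller expects. Assuming the usual rectified-flow schedule x_t = (1 - t) * x0 + t * noise (an assumption; the diff does not state the schedule), the velocity target is v = noise - x0, and algebra gives (x_t - x0) / t = v. The `eps` slot exists because that division blows up as t approaches 0, which only matters during training when tiny timesteps can be sampled. A quick numerical check of the identity:

import torch

# Verify (x_t - x0) / t == noise - x0 under x_t = (1 - t)*x0 + t*noise.
x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.rand(2) * 0.9 + 0.1   # keep t well away from 0, as the eps guard hints
tt = t.view(-1, 1, 1, 1)

x_t = (1 - tt) * x0 + tt * noise
v_direct = noise - x0
v_from_x0 = (x_t - x0) / tt     # same algebra as _apply_x0_residual

print(torch.allclose(v_direct, v_from_x0, atol=1e-4))  # True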
@@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+        # If x0 variant → v-pred, just return this instead
+        if hasattr(self, "__x0__"):
+            out = self._apply_x0_residual(out, img, timestep)
+        return out
+
+
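
The runtime dispatch relies on `register_buffer` also binding the name as an attribute on the instance, so `hasattr(self, "__x0__")` is true exactly for models constructed with `use_x0`; standard checkpoints take the old return path untouched. The final hunk, in `detect_unet_config`, switches the mode on automatically when the marker key appears in a loaded checkpoint: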
@@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_tile_size"] = 512
         dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
         dit_config["nerf_embedder_dtype"] = torch.float32
+        if "__x0__" in state_dict_keys: # x0 pred
+            dit_config["use_x0"] = True
     else:
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
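
Together, the marker buffer and this key check make the checkpoint self-describing. A sketch of the detection side, assuming an empty key_prefix (real keys carry the model's prefix) and made-up tensor values:

import torch

# Mimic a loaded checkpoint: the x0 variant carries the extra "__x0__" key.
state_dict = {
    "nerf_final_layer_conv.norm.scale": torch.zeros(3),
    "__x0__": torch.tensor([]),
}
state_dict_keys = list(state_dict.keys())

dit_config = {}
dit_config["nerf_final_head_type"] = (
    "conv" if "nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
)
if "__x0__" in state_dict_keys:  # x0 pred
    dit_config["use_x0"] = True

print(dit_config)  # {'nerf_final_head_type': 'conv', 'use_x0': True}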