add chroma-radiance-x0 mode (#11197)

Lodestone 2025-12-09 11:33:29 +07:00 committed by GitHub
parent cabc4d351f
commit b9fb542703
2 changed files with 20 additions and 2 deletions

View File

@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
     nerf_final_head_type: str
     # None means use the same dtype as the model.
     nerf_embedder_dtype: Optional[torch.dtype]
+    use_x0: bool

 class ChromaRadiance(Chroma):
     """
@@ -159,6 +159,9 @@
         self.skip_dit = []
         self.lite = False

+        if params.use_x0:
+            self.register_buffer("__x0__", torch.tensor([]))
+
     @property
     def _nerf_final_layer(self) -> nn.Module:
         if self.params.nerf_final_head_type == "linear":
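The empty buffer is a lightweight marker: it adds no weights, but a registered buffer is visible both as a module attribute (which the hasattr check in _forward below relies on) and as a key in the saved state dict (which the loader-side detection in the second file relies on). A minimal sketch of the trick, using a hypothetical TaggedModule rather than the real ChromaRadiance:

import torch
from torch import nn

class TaggedModule(nn.Module):
    def __init__(self, use_x0: bool):
        super().__init__()
        if use_x0:
            # Empty tensor: contributes no parameters, but the buffer is
            # exposed as an attribute and serialized into the state dict.
            self.register_buffer("__x0__", torch.tensor([]))

m = TaggedModule(use_x0=True)
print(hasattr(m, "__x0__"))        # True: usable as a runtime flag
print("__x0__" in m.state_dict())  # True: the flag survives save/load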
@@ -276,6 +279,12 @@
         params_dict |= overrides
         return params.__class__(**params_dict)

+    def _apply_x0_residual(self, predicted, noisy, timesteps):
+        # eps stays 0.0 at inference; training would use a small non-zero
+        # value to avoid dividing by zero at timestep 0.
+        eps = 0.0
+        return (noisy - predicted) / (timesteps.view(-1, 1, 1, 1) + eps)
+
     def _forward(
         self,
         x: Tensor,
@@ -316,4 +325,11 @@
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+        # x0 variant: convert the x0 prediction into a v-prediction and
+        # return that instead.
+        if hasattr(self, "__x0__"):
+            out = self._apply_x0_residual(out, img, timestep)
+        return out
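The conversion in _apply_x0_residual is flow-matching algebra rearranged: assuming the rectified-flow parameterization x_t = (1 - t) * x0 + t * noise (a convention the diff itself does not state), the velocity target is v = noise - x0, so x_t - x0 = t * v and v = (x_t - x0) / t, which is exactly (noisy - predicted) / timesteps. A self-contained sketch with a hypothetical x0_to_velocity helper, checking the identity numerically:

import torch

def x0_to_velocity(x0_pred, noisy, t, eps=0.0):
    # Same algebra as _apply_x0_residual: v = (x_t - x0) / t.
    return (noisy - x0_pred) / (t.view(-1, 1, 1, 1) + eps)

x0 = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x0)
t = torch.tensor([0.3, 0.7]).view(-1, 1, 1, 1)
xt = (1 - t) * x0 + t * noise          # noisy sample on the flow path
v = x0_to_velocity(x0, xt, t.view(-1))
print(torch.allclose(v, noise - x0, atol=1e-5))  # True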

View File

@@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_tile_size"] = 512
         dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
         dit_config["nerf_embedder_dtype"] = torch.float32
+        if "__x0__" in state_dict_keys:  # marks the x0-prediction variant
+            dit_config["use_x0"] = True
     else:
         dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys