mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-19 08:26:58 +08:00
fp8
This commit is contained in:
parent
9a64e1ae5e
commit
2eb9b81d27
@ -36,10 +36,14 @@ def fp8_linear_forward(cls, original_dtype, input):
|
|||||||
else:
|
else:
|
||||||
return cls.original_forward(input)
|
return cls.original_forward(input)
|
||||||
|
|
||||||
def convert_fp8_linear(module, original_dtype):
|
def convert_fp8_linear(module, original_dtype, params_to_keep={}):
|
||||||
setattr(module, "fp8_matmul_enabled", True)
|
setattr(module, "fp8_matmul_enabled", True)
|
||||||
|
|
||||||
|
|
||||||
for name, module in module.named_modules():
|
for name, module in module.named_modules():
|
||||||
if isinstance(module, nn.Linear):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
original_forward = module.forward
|
if isinstance(module, nn.Linear):
|
||||||
setattr(module, "original_forward", original_forward)
|
print(name)
|
||||||
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
original_forward = module.forward
|
||||||
|
setattr(module, "original_forward", original_forward)
|
||||||
|
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
||||||
|
|||||||
@ -153,7 +153,6 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
base_path = os.path.join(download_path, (model.split("/")[-1]))
|
base_path = os.path.join(download_path, (model.split("/")[-1]))
|
||||||
download_path = base_path
|
download_path = base_path
|
||||||
repo_id = model
|
repo_id = model
|
||||||
|
|
||||||
|
|
||||||
if "2b" in model:
|
if "2b" in model:
|
||||||
scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
|
scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
|
||||||
@ -193,13 +192,15 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
#fp8
|
#fp8
|
||||||
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
||||||
for name, param in transformer.named_parameters():
|
for name, param in transformer.named_parameters():
|
||||||
params_to_keep = {"patch_embed", "lora", "pos_embedding"}
|
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
param.data = param.data.to(torch.float8_e4m3fn)
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
if fp8_transformer == "fastmode":
|
if fp8_transformer == "fastmode":
|
||||||
from .fp8_optimization import convert_fp8_linear
|
from .fp8_optimization import convert_fp8_linear
|
||||||
convert_fp8_linear(transformer, dtype)
|
if "1.5" in model:
|
||||||
|
params_to_keep = {"norm","ff"}
|
||||||
|
convert_fp8_linear(transformer, dtype, params_to_keep=params_to_keep)
|
||||||
|
|
||||||
with open(scheduler_path) as f:
|
with open(scheduler_path) as f:
|
||||||
scheduler_config = json.load(f)
|
scheduler_config = json.load(f)
|
||||||
|
|||||||
3
nodes.py
3
nodes.py
@ -861,8 +861,7 @@ class CogVideoSampler:
|
|||||||
pipe.transformer.fastercache_counter = 0
|
pipe.transformer.fastercache_counter = 0
|
||||||
|
|
||||||
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
|
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
|
||||||
autocastcondition = False ##todo
|
autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
|
||||||
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
|
||||||
with autocast_context:
|
with autocast_context:
|
||||||
latents = pipeline["pipe"](
|
latents = pipeline["pipe"](
|
||||||
num_inference_steps=steps,
|
num_inference_steps=steps,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user