Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-10 22:44:54 +08:00)
Remove useless code. (#10223)

parent 187f43696d
commit 195e0b0639
@@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
         else:
             self.source_sample_rate = source_sample_rate
 
-        # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
-
         self.transform = transforms.Compose([
             transforms.Normalize(0.5, 0.5),
         ])
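Context note: the Compose/Normalize pipeline kept in this hunk maps values from [0, 1] to [-1, 1], since Normalize(mean, std) computes (x - mean) / std elementwise. A minimal sketch, assuming `transforms` here is torchvision.transforms (consistent with the Compose/Normalize API used):

    import torch
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Normalize(0.5, 0.5),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
    ])

    x = torch.rand(1, 128, 64)  # hypothetical single-channel mel in [0, 1]
    y = transform(x)            # values now in [-1, 1]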
@@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
         self.scale_factor = 0.1786
         self.shift_factor = -1.9091
 
-    def load_audio(self, audio_path):
-        audio, sr = torchaudio.load(audio_path)
-        return audio, sr
-
     def forward_mel(self, audios):
         mels = []
         for i in range(len(audios)):
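The deleted load_audio was a one-line wrapper over torchaudio.load, so callers can use the library call directly. A minimal sketch (the file path is a placeholder):

    import torchaudio

    # Returns (waveform, sample_rate); waveform has shape [channels, samples].
    audio, sr = torchaudio.load("example.wav")  # hypothetical file path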
@@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
             latent = self.dcae.encoder(mel.unsqueeze(0))
             latents.append(latent)
         latents = torch.cat(latents, dim=0)
-        # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
         latents = (latents - self.shift_factor) * self.scale_factor
         return latents
-        # return latents, latent_lengths
 
     @torch.no_grad()
     def decode(self, latents, audio_lengths=None, sr=None):
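The surviving tail of encode() applies an affine normalization to the latents using the scale/shift constants kept as context in the previous hunk. A sketch of that map and its inverse; the inverse is inferred by convention and does not appear in this diff:

    scale_factor = 0.1786   # constants kept as context above
    shift_factor = -1.9091

    def normalize(latents):
        # applied at the end of encode()
        return (latents - shift_factor) * scale_factor

    def denormalize(latents):
        # inferred inverse a decoder would apply (assumption, not in the diff)
        return latents / scale_factor + shift_factor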
@@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
             wav = self.vocoder.decode(mels[0]).squeeze(1)
 
             if sr is not None:
-                # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
                 wav = torchaudio.functional.resample(wav, 44100, sr)
-                # wav = resampler(wav)
             else:
                 sr = 44100
             pred_wavs.append(wav)
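The commented-out lines removed here (and the one in the first hunk) referenced torchaudio's module-style Resample, which the live code already replaces with the stateless functional call. A minimal sketch of the two equivalent APIs (rates are illustrative):

    import torch
    import torchaudio

    wav = torch.randn(1, 44100)  # hypothetical 1 s mono clip at 44.1 kHz

    # Module form (what the dead comments referenced): build once, reuse.
    resampler = torchaudio.transforms.Resample(orig_freq=44100, new_freq=22050)
    out_a = resampler(wav)

    # Functional form (what decode() calls): one-off, nothing to construct.
    out_b = torchaudio.functional.resample(wav, orig_freq=44100, new_freq=22050)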
@@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
         if audio_lengths is not None:
             pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
         return torch.stack(pred_wavs)
-        # return sr, pred_wavs
 
     def forward(self, audios, audio_lengths=None, sr=None):
         latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
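For orientation, a hedged driver sketch of the forward() entry point shown in the last hunk; the construction arguments and input shapes are assumptions, and everything past the encode() call is outside this diff:

    import torch

    model = MusicDCAE()                     # assumes default construction suffices
    audios = torch.randn(1, 2, 44100 * 5)   # hypothetical batch: one 5 s stereo clip
    with torch.no_grad():
        out = model(audios, sr=44100)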