signature

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-12-24 14:29:36 +00:00
parent ad85bded6c
commit 762ca9e38a

View File

@ -534,9 +534,7 @@ class WhisperEncoder(nn.Module):
sinusoids(*self.embed_positions.weight.shape)
)
def forward_conv(
self, input_features: torch.Tensor | list[torch.Tensor]
) -> torch.Tensor:
def forward_conv(self, input_features: torch.Tensor) -> torch.Tensor:
embeds = nn.functional.gelu(self.conv1(input_features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.transpose(-1, -2)
@ -557,7 +555,7 @@ class WhisperEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
def forward(self, input_features: torch.Tensor):
hidden_states = self.forward_conv(input_features)
return self.forward_layers(hidden_states)