mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 11:17:04 +08:00
signature
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
ad85bded6c
commit
762ca9e38a
@ -534,9 +534,7 @@ class WhisperEncoder(nn.Module):
|
||||
sinusoids(*self.embed_positions.weight.shape)
|
||||
)
|
||||
|
||||
def forward_conv(
|
||||
self, input_features: torch.Tensor | list[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
def forward_conv(self, input_features: torch.Tensor) -> torch.Tensor:
|
||||
embeds = nn.functional.gelu(self.conv1(input_features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
@ -557,7 +555,7 @@ class WhisperEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
def forward(self, input_features: torch.Tensor):
|
||||
hidden_states = self.forward_conv(input_features)
|
||||
return self.forward_layers(hidden_states)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user