Remove head_mask from Ultravox and Swin (#30764)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-12-16 16:02:41 +00:00 committed by GitHub
parent af506fd76a
commit 0b0acc758e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 17 deletions

View File

@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
)

View File

@ -5,6 +5,7 @@
"""PyTorch Ultravox model."""
import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias
@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
)
hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers:
layer_outputs = layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]