[Fix] Support cls pooling in ModernBertPooler (#20067)

Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
This commit is contained in:
lsz05 2025-06-26 04:07:45 +09:00 committed by GitHub
parent 02c97d9a92
commit 23a04e0895
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -258,6 +258,7 @@ class ModernBertPooler(nn.Module):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
config.classifier_bias)
self.pooling_type = config.classifier_pooling
self.act = nn.GELU()
self.norm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps,
@ -265,7 +266,13 @@ class ModernBertPooler(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Pool per-token hidden states into one vector and project it.

    The pooling strategy is chosen by ``self.pooling_type``:

    * ``"mean"`` averages over dim 0.
    * ``"cls"`` takes the entry at index 0 of dim 0.

    The pooled vector then goes through ``dense`` -> ``act`` -> ``norm``.

    Args:
        hidden_states: Tensor that is reduced along dim 0. Presumably laid
            out as (num_tokens, hidden_size) for a single sequence --
            TODO confirm against the caller.

    Returns:
        The pooled tensor after the dense projection, activation and
        normalization.

    Raises:
        ValueError: If ``self.pooling_type`` is neither ``"cls"`` nor
            ``"mean"``.
    """
    # Pool exactly once. Averaging unconditionally before this branch would
    # make "mean" reduce twice and leave "cls" indexing an already-averaged
    # tensor instead of the first position.
    if self.pooling_type == "mean":
        pooled_output = hidden_states.mean(dim=0, keepdim=False)
    elif self.pooling_type == "cls":
        pooled_output = hidden_states[0, :]
    else:
        raise ValueError("Pooling type should be either `cls` or `mean`, "
                         f"but got {self.pooling_type}")
    pooled_output = self.norm(self.act(self.dense(pooled_output)))
    return pooled_output