[Model] Use sigmoid for single-label classification (#18313)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn 2025-05-18 07:00:09 -07:00 committed by GitHub
parent 1a8f68bb90
commit 908733aca7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -242,9 +242,16 @@ class PoolerHead(nn.Module):
if self.softmax:
if isinstance(pooled_data, list):
pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
pooled_data = [
F.softmax(data, dim=-1)
if data.shape[-1] >= 2 else F.sigmoid(data)
for data in pooled_data
]
else:
pooled_data = F.softmax(pooled_data, dim=-1)
if pooled_data.shape[-1] >= 2:
pooled_data = F.softmax(pooled_data, dim=-1)
else:
pooled_data = F.sigmoid(pooled_data)
return pooled_data