mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 06:55:00 +08:00
[Fix] Support cls pooling in ModernBertPooler (#20067)
Signed-off-by: shengzhe.li <shengzhe.li@sbintuitions.co.jp>
This commit is contained in:
parent
02c97d9a92
commit
23a04e0895
@@ -258,6 +258,7 @@ class ModernBertPooler(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size,
|
||||||
config.classifier_bias)
|
config.classifier_bias)
|
||||||
|
self.pooling_type = config.classifier_pooling
|
||||||
self.act = nn.GELU()
|
self.act = nn.GELU()
|
||||||
self.norm = nn.LayerNorm(config.hidden_size,
|
self.norm = nn.LayerNorm(config.hidden_size,
|
||||||
eps=config.norm_eps,
|
eps=config.norm_eps,
|
||||||
@@ -265,7 +266,13 @@ class ModernBertPooler(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
pooled_output = hidden_states
|
pooled_output = hidden_states
|
||||||
pooled_output = pooled_output.mean(dim=0, keepdim=False)
|
if self.pooling_type == "mean":
|
||||||
|
pooled_output = pooled_output.mean(dim=0, keepdim=False)
|
||||||
|
elif self.pooling_type == "cls":
|
||||||
|
pooled_output = pooled_output[0, :]
|
||||||
|
else:
|
||||||
|
raise ValueError("Pooling type should be either `cls` or `mean`, "
|
||||||
|
f"but got {self.pooling_type}")
|
||||||
pooled_output = self.norm(self.act(self.dense(pooled_output)))
|
pooled_output = self.norm(self.act(self.dense(pooled_output)))
|
||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user