[Bugfix] Fix hidden_size for multimodal classification model (#24501)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent b9a1c4c8a2
commit 9ad0688e43
@@ -255,7 +255,7 @@ def as_seq_cls_model(cls: _T) -> _T:
     from vllm.model_executor.models.interfaces import SupportsCrossEncoding
     from vllm.sequence import IntermediateTensors
 
-    from .utils import maybe_prefix
+    from .utils import get_model_hidden_size, maybe_prefix
 
     class ModelForSequenceClassification(_create_pooling_model_cls(cls),
                                          SupportsCrossEncoding):
@@ -263,9 +263,10 @@ def as_seq_cls_model(cls: _T) -> _T:
         def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config
+            hidden_size = get_model_hidden_size(config)
 
             self.score = ReplicatedLinear(
-                config.hidden_size,
+                hidden_size,
                 config.num_labels,
                 bias=False,
                 params_dtype=torch.float32,
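For context, the hunk above only changes where the input dimension of the classification head comes from. Below is a minimal sketch of what _init_pooler wires up, using a plain torch.nn.Linear and made-up sizes instead of vLLM's ReplicatedLinear; it is illustrative only, not the actual vLLM implementation.

```python
import torch
import torch.nn as nn

# Assumed values for illustration only.
hidden_size = 1024   # for multimodal models this is resolved from the text config
num_labels = 2

# The classification head projects pooled hidden states to per-label logits,
# so its input dimension must match the backbone's hidden size.
score = nn.Linear(hidden_size, num_labels, bias=False).to(torch.float32)

pooled = torch.randn(4, hidden_size, dtype=torch.float32)  # batch of pooled states
logits = score(pooled)
print(logits.shape)  # torch.Size([4, 2])
```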
@@ -761,3 +761,10 @@ def fast_topk(values: torch.Tensor, topk: int,
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
+    if hasattr(hf_config, "hidden_size"):
+        return hf_config.hidden_size
+    text_config = hf_config.get_text_config()
+    return text_config.hidden_size
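A hedged sketch of why this helper is needed, using toy transformers.PretrainedConfig objects rather than any real model's config, and assuming a transformers version where PretrainedConfig.get_text_config() resolves a nested text_config attribute (the method the new helper relies on): text-only configs expose hidden_size at the top level, while multimodal configs typically keep it on their nested text config.

```python
from transformers import PretrainedConfig

# Toy stand-ins for real HF configs (illustrative values only).
text_only_cfg = PretrainedConfig(hidden_size=768)
multimodal_cfg = PretrainedConfig(text_config=PretrainedConfig(hidden_size=1024))

def resolve_hidden_size(hf_config: PretrainedConfig) -> int:
    """Same fallback logic as the get_model_hidden_size() added above."""
    if hasattr(hf_config, "hidden_size"):
        return hf_config.hidden_size
    return hf_config.get_text_config().hidden_size

print(resolve_hidden_size(text_only_cfg))   # 768
print(resolve_hidden_size(multimodal_cfg))  # 1024 (from the nested text config)
```

Before this change, the multimodal case would have raised AttributeError on config.hidden_size when building the classification head.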