[Bugfix][CI/Build] Fix failing Mteb CI (#26638)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-10-12 17:42:42 +08:00 committed by GitHub
parent 76852017ea
commit 045b396d09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 13 additions and 2 deletions

View File

@ -191,7 +191,7 @@ def mteb_test_embed_models(
with vllm_runner(
model_info.name,
runner="pooling",
max_model_len=None,
max_model_len=model_info.max_model_len,
**vllm_extra_kwargs,
) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config

View File

@ -25,6 +25,11 @@ EMBEDDING_MODELS = [
mteb_score=0.824413164,
architecture="XLMRobertaModel",
is_matryoshka=True,
# The default max length of the model is 8194, which will crash
# CUDAGraph due to odd length for Gemm. We set it to 8192 to avoid
# avoid this issue.
max_model_len=8192,
dtype="float32",
)
]

View File

@ -23,6 +23,7 @@ ST_PROJECTOR_MODELS = [
architecture="Gemma3TextModel",
mteb_score=0.7473819294684156,
enable_test=True,
dtype="float32",
),
]

View File

@ -369,6 +369,7 @@ class ModelInfo:
name: str
architecture: str = ""
dtype: str = "auto"
max_model_len: Optional[int] = None
hf_dtype: str = "float32"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = ""

View File

@ -318,7 +318,11 @@ class GemmaRMSNorm(CustomOp):
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
x = x + residual.float() if orig_dtype == torch.float16 else x + residual
x = (
x.float() + residual.float()
if orig_dtype == torch.float16
else x + residual
)
residual = x
x = x.float()