mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 15:49:10 +08:00
[Bugfix][CI/Build] Fix failing Mteb CI (#26638)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
76852017ea
commit
045b396d09
@ -191,7 +191,7 @@ def mteb_test_embed_models(
|
|||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_info.name,
|
model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=model_info.max_model_len,
|
||||||
**vllm_extra_kwargs,
|
**vllm_extra_kwargs,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
|
|||||||
@ -25,6 +25,11 @@ EMBEDDING_MODELS = [
|
|||||||
mteb_score=0.824413164,
|
mteb_score=0.824413164,
|
||||||
architecture="XLMRobertaModel",
|
architecture="XLMRobertaModel",
|
||||||
is_matryoshka=True,
|
is_matryoshka=True,
|
||||||
|
# The default max length of the model is 8194, which will crash
|
||||||
|
# CUDAGraph due to odd length for Gemm. We set it to 8192 to avoid
|
||||||
|
# avoid this issue.
|
||||||
|
max_model_len=8192,
|
||||||
|
dtype="float32",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,7 @@ ST_PROJECTOR_MODELS = [
|
|||||||
architecture="Gemma3TextModel",
|
architecture="Gemma3TextModel",
|
||||||
mteb_score=0.7473819294684156,
|
mteb_score=0.7473819294684156,
|
||||||
enable_test=True,
|
enable_test=True,
|
||||||
|
dtype="float32",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -369,6 +369,7 @@ class ModelInfo:
|
|||||||
name: str
|
name: str
|
||||||
architecture: str = ""
|
architecture: str = ""
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
|
max_model_len: Optional[int] = None
|
||||||
hf_dtype: str = "float32"
|
hf_dtype: str = "float32"
|
||||||
hf_overrides: Optional[dict[str, Any]] = None
|
hf_overrides: Optional[dict[str, Any]] = None
|
||||||
default_pooling_type: str = ""
|
default_pooling_type: str = ""
|
||||||
|
|||||||
@ -318,7 +318,11 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
"""PyTorch-native implementation equivalent to forward()."""
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
x = x + residual.float() if orig_dtype == torch.float16 else x + residual
|
x = (
|
||||||
|
x.float() + residual.float()
|
||||||
|
if orig_dtype == torch.float16
|
||||||
|
else x + residual
|
||||||
|
)
|
||||||
residual = x
|
residual = x
|
||||||
|
|
||||||
x = x.float()
|
x = x.float()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user