From 8e67b2557aae7204c697d7a5c61e00754da465be Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Mon, 13 Oct 2025 18:21:48 +0800
Subject: [PATCH] [Bugfix] Fix out of bound index issue for Jina-embedding-v3
 RoPE with cuda graph (#26687)

Signed-off-by: Isotr0py
---
 .../language/pooling_mteb_test/test_jina.py |  4 ----
 vllm/model_executor/models/config.py        | 16 +++++++++++++---
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/tests/models/language/pooling_mteb_test/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py
index dbdf82af33c7..c2065bcd6eb4 100644
--- a/tests/models/language/pooling_mteb_test/test_jina.py
+++ b/tests/models/language/pooling_mteb_test/test_jina.py
@@ -25,10 +25,6 @@ EMBEDDING_MODELS = [
         mteb_score=0.824413164,
         architecture="XLMRobertaModel",
         is_matryoshka=True,
-        # The default max length of the model is 8194, which will crash
-        # CUDAGraph due to odd length for Gemm. We set it to 8192 to avoid
-        # avoid this issue.
-        max_model_len=8192,
         dtype="float32",
     )
 ]
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index ee6a3ba773bb..662f2c9209f4 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
@@ -59,16 +59,26 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
 class JinaRobertaModelConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        config = model_config.hf_config
 
         if config.position_embedding_type == "rotary":
             assert config.__class__.__name__ == "XLMRobertaFlashConfig"
 
             head_dim = config.hidden_size // config.num_attention_heads
+            max_position = config.max_position_embeddings
+            # Jina-embeddings-v3 has max_position_embeddings=8194, which causes
+            # an out-of-bound index in RoPE for long prompts under torch.compile,
+            # because it is not divisible by the Triton num_warps (default 4 or 8).
+            # To deal with this, we round max_position up to a multiple of num_warps
+            # so that the Triton kernel never indexes past the RoPE cache.
+            if not model_config.enforce_eager:
+                max_position = round_up(max_position, 8)
+
             config.rotary_kwargs = {
                 "head_size": head_dim,
                 "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-                "max_position": config.max_position_embeddings,
+                "max_position": max_position,
                 "base": getattr(config, "rope_theta", config.rotary_emb_base),
                 "rope_scaling": getattr(config, "rope_scaling", None),
            }
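
Reviewer note (illustration only, not part of the patch): a minimal sketch of the alignment behavior the fix relies on. It assumes vllm.utils.round_up rounds an integer up to the nearest multiple of its second argument; the local round_up below is a hypothetical stand-in so the snippet runs on its own.

# Hypothetical stand-in for vllm.utils.round_up (assumed semantics:
# round x up to the nearest multiple of `multiple`).
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

# Jina-embeddings-v3 ships max_position_embeddings=8194, which is not a
# multiple of the Triton num_warps (4 or 8). Padding the RoPE cache length
# up to a multiple of 8 keeps indexing in bounds under CUDA graph /
# torch.compile; eager mode keeps the original 8194.
assert round_up(8194, 8) == 8200
assert round_up(8192, 8) == 8192  # already aligned, unchanged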