[Bugfix] Fix out of bound index issue for Jina-embedding-v3 RoPE with cuda graph (#26687)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-10-13 18:21:48 +08:00 committed by GitHub
parent 4073c82c4e
commit 8e67b2557a
2 changed files with 13 additions and 7 deletions
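As a rough illustration of the failure mode named in the commit title (a back-of-the-envelope sketch, not vLLM kernel code): a compiled kernel that walks positions in blocks of size num_warps steps past the end of a cos/sin cache whose length, 8194, is not a multiple of 8.

# Illustrative only: positions processed in blocks of num_warps.
cache_len = 8194          # Jina-embeddings-v3 max_position_embeddings
num_warps = 8             # triton default assumed for the compiled RoPE kernel
num_blocks = -(-cache_len // num_warps)   # ceil(8194 / 8) = 1025
last_block_end = num_blocks * num_warps   # 8200: final block covers indices 8192..8199
print(last_block_end > cache_len)         # True -> indices 8194..8199 fall outside the cache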


@@ -25,10 +25,6 @@ EMBEDDING_MODELS = [
mteb_score=0.824413164,
architecture="XLMRobertaModel",
is_matryoshka=True,
# The default max length of the model is 8194, which will crash
# CUDAGraph due to odd length for Gemm. We set it to 8192 to
# avoid this issue.
max_model_len=8192,
dtype="float32",
)
]


@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
if TYPE_CHECKING:
@@ -59,16 +59,26 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
config = model_config.hf_config
if config.position_embedding_type == "rotary":
assert config.__class__.__name__ == "XLMRobertaFlashConfig"
head_dim = config.hidden_size // config.num_attention_heads
max_position = config.max_position_embeddings
# Jina-embeddings-v3 has max_position_embeddings=8194, which causes an
# out-of-bound index in RoPE for long prompts under torch.compile,
# because 8194 is not divisible by the triton num_warps (default 4 or 8).
# To handle this, round max_position up to a multiple of num_warps so the
# triton kernel never indexes past the end of the RoPE cache.
if not model_config.enforce_eager:
max_position = round_up(max_position, 8)
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"max_position": max_position,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None),
}