mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:25:32 +08:00
[Bugfix] Fix out of bound index issue for Jina-embedding-v3 RoPE with cuda graph (#26687)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 4073c82c4e
commit 8e67b2557a
@@ -25,10 +25,6 @@ EMBEDDING_MODELS = [
         mteb_score=0.824413164,
         architecture="XLMRobertaModel",
         is_matryoshka=True,
-        # The default max length of the model is 8194, which will crash
-        # CUDAGraph due to odd length for Gemm. We set it to 8192 to avoid
-        # avoid this issue.
-        max_model_len=8192,
         dtype="float32",
     )
 ]
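With the config-level fix in the hunks below, the test entry no longer needs to cap max_model_len at 8192; the RoPE cache length is padded instead. A minimal sketch of the alignment arithmetic behind the removed workaround (the factor of 8 is taken from the num_warps comment in the next file; values here are for illustration only):

# 8194 (the model's default length) is not a multiple of the Triton warp
# count (4 or 8), which is what made the old cap to 8192 necessary.
default_len = 8194
assert default_len % 8 != 0   # misaligned length -> CUDA graph / compiled kernel trouble
assert 8192 % 8 == 0          # old workaround: shrink max_model_len to an aligned value
assert 8200 % 8 == 0          # new fix: pad the RoPE cache upward instead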
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
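For reference, a minimal sketch of the rounding behaviour the newly imported round_up helper is used for here (assumed to round an integer up to the nearest multiple):

def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(8194, 8) == 8200   # Jina-v3's max_position_embeddings, padded
assert round_up(8192, 8) == 8192   # already-aligned values are left unchanged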
@@ -59,16 +59,26 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
 class JinaRobertaModelConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        config = model_config.hf_config
 
         if config.position_embedding_type == "rotary":
             assert config.__class__.__name__ == "XLMRobertaFlashConfig"
 
             head_dim = config.hidden_size // config.num_attention_heads
+            max_position = config.max_position_embeddings
+            # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
+            # out-of-bound index issue at RoPE for long prompts with torch.compile,
+            # because it can't be divided by triton num_warps(default=4 or 8).
+            # To deal with this, we increase max_position to multiple of n_warps,
+            # so that triton kernel won't hit out-of-bound index in RoPE cache.
+            if not model_config.enforce_eager:
+                max_position = round_up(max_position, 8)
+
             config.rotary_kwargs = {
                 "head_size": head_dim,
                 "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-                "max_position": config.max_position_embeddings,
+                "max_position": max_position,
                 "base": getattr(config, "rope_theta", config.rotary_emb_base),
                 "rope_scaling": getattr(config, "rope_scaling", None),
             }
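A standalone sketch of the padding logic above, using assumed Jina-embeddings-v3-like values (hidden_size=1024, 16 heads) and a SimpleNamespace in place of the real XLMRobertaFlashConfig and ModelConfig, with enforce_eager=False to mimic the CUDA graph / torch.compile path:

from types import SimpleNamespace

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

# Stand-in for the HF config; field values are assumptions for illustration.
config = SimpleNamespace(
    hidden_size=1024,
    num_attention_heads=16,
    max_position_embeddings=8194,
    rotary_emb_base=10000,
)
enforce_eager = False  # compiled / CUDA graph path

head_dim = config.hidden_size // config.num_attention_heads
max_position = config.max_position_embeddings
if not enforce_eager:
    # Pad the RoPE cache length to a multiple of 8 so the Triton kernel
    # never indexes past the end of the cache.
    max_position = round_up(max_position, 8)

rotary_kwargs = {
    "head_size": head_dim,
    "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
    "max_position": max_position,
    "base": getattr(config, "rope_theta", config.rotary_emb_base),
    "rope_scaling": getattr(config, "rope_scaling", None),
}
assert rotary_kwargs["max_position"] == 8200  # 8194 rounded up to a multiple of 8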