Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix out of bound index issue for Jina-embedding-v3 RoPE with cuda graph (#26687)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 4073c82c4e
commit 8e67b2557a
```diff
@@ -25,10 +25,6 @@ EMBEDDING_MODELS = [
         mteb_score=0.824413164,
         architecture="XLMRobertaModel",
         is_matryoshka=True,
-        # The default max length of the model is 8194, which will crash
-        # CUDAGraph due to odd length for Gemm. We set it to 8192 to avoid
-        # avoid this issue.
-        max_model_len=8192,
         dtype="float32",
     )
 ]
```
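The removed workaround had capped max_model_len at 8192 purely to dodge the CUDA graph crash; with the RoPE cache now padded in the model config (see the hunks below), the test can run the model at its native length again. A quick plain-Python check, independent of vLLM, shows why 8194 interacts badly with the default Triton warp counts while 8192 does not:

```python
# Illustrative arithmetic only. The compiled RoPE kernel walks positions in
# blocks tied to Triton's num_warps (default 4 or 8). 8194 is not a multiple
# of either, so the block covering the tail of a long prompt can index past
# the end of a cos/sin cache sized exactly to 8194; 8192 divides evenly.
for max_len in (8194, 8192):
    for num_warps in (4, 8):
        print(f"max_len={max_len} num_warps={num_warps} "
              f"remainder={max_len % num_warps}")
# 8194 leaves remainder 2 for both warp counts; 8192 leaves 0.
```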
```diff
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
```
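The newly imported round_up is vLLM's integer helper for rounding a value up to the nearest multiple of another; the sketch below shows the expected semantics (a minimal reimplementation for illustration, not the actual vllm.utils code):

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple` (sketch of the helper
    # imported from vllm.utils; the real implementation may differ in detail).
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(8194, 8) == 8200  # Jina-embeddings-v3's 8194 is padded to 8200
assert round_up(8192, 8) == 8192  # already-aligned values are unchanged
```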
```diff
@@ -59,16 +59,26 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
 class JinaRobertaModelConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        config = model_config.hf_config
 
         if config.position_embedding_type == "rotary":
             assert config.__class__.__name__ == "XLMRobertaFlashConfig"
 
             head_dim = config.hidden_size // config.num_attention_heads
+            max_position = config.max_position_embeddings
+            # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause
+            # out-of-bound index issue at RoPE for long prompts with torch.compile,
+            # because it can't be divided by triton num_warps(default=4 or 8).
+            # To deal with this, we increase max_position to multiple of n_warps,
+            # so that triton kernel won't hit out-of-bound index in RoPE cache.
+            if not model_config.enforce_eager:
+                max_position = round_up(max_position, 8)
+
             config.rotary_kwargs = {
                 "head_size": head_dim,
                 "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-                "max_position": config.max_position_embeddings,
+                "max_position": max_position,
                 "base": getattr(config, "rope_theta", config.rotary_emb_base),
                 "rope_scaling": getattr(config, "rope_scaling", None),
             }
```
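For context, the rotary_kwargs built above eventually size a per-position cos/sin cache. The sketch below uses a hypothetical build_rope_cache helper (not vLLM's actual RotaryEmbedding) to show what the padding buys: with the cache grown from 8194 to 8200 rows, warp-aligned reads that reach a few positions past the model's real limit still land inside the tensor.

```python
import torch

def build_rope_cache(max_position: int, rotary_dim: int, base: float = 10000.0):
    # Hypothetical stand-in for the cos/sin cache a RoPE layer precomputes;
    # vLLM's real RotaryEmbedding differs in layout and details.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    positions = torch.arange(max_position).float()
    freqs = torch.outer(positions, inv_freq)          # [max_position, rotary_dim // 2]
    return torch.cat((freqs.cos(), freqs.sin()), -1)  # [max_position, rotary_dim]

native = build_rope_cache(8194, rotary_dim=64)
padded_len = ((8194 + 7) // 8) * 8  # same result as round_up(8194, 8) -> 8200
padded = build_rope_cache(padded_len, rotary_dim=64)

# A Triton kernel launched under torch.compile / CUDA graphs may gather the
# cache in warp-aligned blocks, so the last block of a long prompt can touch
# rows up to index 8199. Those rows exist in the padded cache but would be
# out of bounds in the native 8194-row one.
print(native.shape, padded.shape)  # torch.Size([8194, 64]) torch.Size([8200, 64])
```

Because only the cache is padded (and only when enforce_eager is off), the HF config's max_position_embeddings stays at 8194 and the model's usable context length is unchanged; the extra rows just guarantee that block-aligned reads stay in bounds.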