[Bugfix] In LongRoPE, decide short vs long based on max_model_len (#27431)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni 2025-10-28 08:00:56 -04:00 committed by GitHub
parent 7a865f2325
commit 44b5ce956d
3 changed files with 39 additions and 11 deletions


@@ -29,7 +29,7 @@ def multimodal_server():  # noqa: F811
         "--dtype",
         "half",
         "--max-model-len",
-        "12800",
+        "4096",
         "--enforce-eager",
         # lora config below
         "--enable-lora",


@@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
     # If the user didn't specify `max_model_len`, then use that derived from
     # the model config as a default value.
     if max_model_len is None:
-        max_model_len = int(derived_max_model_len)
+        # For LongRoPE, default to original_max_position_embeddings to avoid
+        # performance degradation for shorter sequences
+        if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
+            max_model_len = int(
+                getattr(
+                    hf_config, "original_max_position_embeddings", derived_max_model_len
+                )
+            )
+        else:
+            max_model_len = int(derived_max_model_len)
         max_model_len = current_platform.check_max_model_len(max_model_len)
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     elif max_model_len > derived_max_model_len:
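
For reference, a minimal sketch of the defaulting behavior introduced above, with assumed Phi-3-style values (the real code reads rope_scaling, hf_config, and derived_max_model_len from the model config, and still lets an explicit user value take precedence):

# Sketch of the new default; the numbers are assumptions, not read from a model.
derived_max_model_len = 131072             # rope-scaled length from the HF config
original_max_position_embeddings = 4096    # pre-scaling context window

def default_max_model_len(rope_scaling):
    # LongRoPE models now default to the original window so that short
    # sequences are not penalized by the long scaling factors.
    if rope_scaling is not None and rope_scaling.get("rope_type") == "longrope":
        return original_max_position_embeddings
    return derived_max_model_len

print(default_max_model_len({"rope_type": "longrope"}))  # 4096
print(default_max_model_len(None))                       # 131072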


@@ -5,8 +5,13 @@ import math
 import torch
 import torch.nn as nn
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
 from .common import rotate_neox
+logger = init_logger(__name__)

 class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     """Phi3 family of models scaled rotary embedding.
@@ -43,6 +48,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         self.short_factor = short_factor
         self.long_factor = long_factor
+        # Force long factors if max_model_len (runtime max length) exceeds
+        # original_max_position_embeddings to prevent KV cache invalidation when
+        # sequences cross this threshold during generation
+        max_model_len = get_current_vllm_config().model_config.max_model_len
+        self.use_long_rope = max_model_len > original_max_position_embeddings
+        if self.use_long_rope:
+            logger.warning_once(
+                "Using LongRoPE scaling factors. This enables longer "
+                "contexts (%d tokens vs original %d tokens) at the cost of "
+                "some performance degradation for shorter sequences. If "
+                "this is not desired, set `max_model_len` to be at most %d.",
+                max_position_embeddings,
+                original_max_position_embeddings,
+                original_max_position_embeddings,
+            )
         scale = self.max_position_embeddings / self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
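
A standalone sketch of the init-time switch added above (values assumed; in the real layer max_model_len comes from get_current_vllm_config().model_config and the two limits come from the HF config):

# Sketch: the short/long decision is now made once from max_model_len,
# not per batch from the observed position ids.
max_position_embeddings = 131072          # assumed scaled limit
original_max_position_embeddings = 4096   # assumed original window

def choose_factors(max_model_len):
    use_long_rope = max_model_len > original_max_position_embeddings
    if use_long_rope:
        # Mirrors the warning above: longer contexts, some short-sequence cost.
        print(f"long factors: up to {max_position_embeddings} tokens")
    else:
        print(f"short factors: up to {original_max_position_embeddings} tokens")
    return use_long_rope

choose_factors(4096)    # short factors: cos/sin never change mid-generation
choose_factors(32768)   # long factors for every request
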
@@ -112,15 +133,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
-        k = self.original_max_position_embeddings
-        long_prompt_offset = (
-            torch.any(positions > k).float() * torch.full_like(positions, k)
-        ).long()
-        idx = (
-            torch.add(positions, long_prompt_offset)
-            if long_prompt_offset is not None
-            else positions
-        )
+        if self.use_long_rope:
+            k = self.original_max_position_embeddings
+            long_prompt_offset = torch.full_like(positions, k).long()
+            idx = torch.add(positions, long_prompt_offset)
+        else:
+            idx = positions
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
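
A minimal sketch of the index math above, assuming long_short_cos_sin_cache stores the short-factor rows first ([0, k)) followed by the long-factor rows, which is what the fixed offset of original_max_position_embeddings implies:

import torch

k = 4096                                  # original_max_position_embeddings (assumed)
positions = torch.tensor([0, 5, 4100])    # example position ids
use_long_rope = True                      # as decided once in __init__

if use_long_rope:
    # Shift every position into the long-factor half of the cache.
    idx = positions + torch.full_like(positions, k)
else:
    idx = positions

print(idx)  # tensor([4096, 4101, 8196])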