mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 05:35:02 +08:00
[Bugfix] In LongRoPE, decide short vs long based on max_model_len (#27431)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
7a865f2325
commit
44b5ce956d
@ -29,7 +29,7 @@ def multimodal_server(): # noqa: F811
|
||||
"--dtype",
|
||||
"half",
|
||||
"--max-model-len",
|
||||
"12800",
|
||||
"4096",
|
||||
"--enforce-eager",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
|
||||
@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
|
||||
# If the user didn't specify `max_model_len`, then use that derived from
|
||||
# the model config as a default value.
|
||||
if max_model_len is None:
|
||||
max_model_len = int(derived_max_model_len)
|
||||
# For LongRoPE, default to original_max_position_embeddings to avoid
|
||||
# performance degradation for shorter sequences
|
||||
if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
|
||||
max_model_len = int(
|
||||
getattr(
|
||||
hf_config, "original_max_position_embeddings", derived_max_model_len
|
||||
)
|
||||
)
|
||||
else:
|
||||
max_model_len = int(derived_max_model_len)
|
||||
max_model_len = current_platform.check_max_model_len(max_model_len)
|
||||
|
||||
# If the user specified a max length, make sure it is smaller than the
|
||||
# derived length from the HF model config.
|
||||
elif max_model_len > derived_max_model_len:
|
||||
|
||||
@ -5,8 +5,13 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .common import rotate_neox
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
"""Phi3 family of models scaled rotary embedding.
|
||||
@ -43,6 +48,22 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
self.short_factor = short_factor
|
||||
self.long_factor = long_factor
|
||||
|
||||
# Force long factors if max_model_len (runtime max length) exceeds
|
||||
# original_max_position_embeddings to prevent KV cache invalidation when
|
||||
# sequences cross this threshold during generation
|
||||
max_model_len = get_current_vllm_config().model_config.max_model_len
|
||||
self.use_long_rope = max_model_len > original_max_position_embeddings
|
||||
if self.use_long_rope:
|
||||
logger.warning_once(
|
||||
"Using LongRoPE scaling factors. This enables longer "
|
||||
"contexts (%d tokens vs original %d tokens) at the cost of "
|
||||
"some performance degradation for shorter sequences. If "
|
||||
"this is not desired, set `max_model_len` to be at most %d.",
|
||||
max_position_embeddings,
|
||||
original_max_position_embeddings,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
|
||||
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
||||
if scale <= 1.0:
|
||||
scaling_factor = 1.0
|
||||
@ -112,15 +133,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||
|
||||
k = self.original_max_position_embeddings
|
||||
long_prompt_offset = (
|
||||
torch.any(positions > k).float() * torch.full_like(positions, k)
|
||||
).long()
|
||||
idx = (
|
||||
torch.add(positions, long_prompt_offset)
|
||||
if long_prompt_offset is not None
|
||||
else positions
|
||||
)
|
||||
if self.use_long_rope:
|
||||
k = self.original_max_position_embeddings
|
||||
long_prompt_offset = torch.full_like(positions, k).long()
|
||||
idx = torch.add(positions, long_prompt_offset)
|
||||
else:
|
||||
idx = positions
|
||||
idx = torch.add(idx, offsets) if offsets is not None else idx
|
||||
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user