[Bugfix / Core] Prefix Caching Guards (merged with main) (#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Authored by Zhuohan Li on 2024-05-27 15:18:17 -07:00; committed by GitHub
parent f17a1a8f96
commit 1102bef219
11 changed files with 167 additions and 44 deletions

View File

@ -0,0 +1,44 @@
"""Compare the with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
from tests.conftest import cleanup
from vllm import LLM
MODEL_LEN_LEN = [
# Example models with sliding window.
("bigcode/starcoder2-3b", 4096, 16384),
# ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI
# Confirm models without an active sliding window work.
# config has "use_sliding_window": false
("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768),
# config has no sliding window attribute.
("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048),
]
@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN)
def test_disable_sliding_window(model_len_len):
model, sliding_len, full_len = model_len_len
vllm_disabled_model = LLM(model, disable_sliding_window=True)
vllm_disabled_model.generate("Hi my name is")
model_config = vllm_disabled_model.llm_engine.model_config
assert model_config.max_model_len == sliding_len, (
f"Max len expected to equal sliding_len of {sliding_len}, "
f"but got {model_config.max_model_len}")
del vllm_disabled_model
cleanup()
vllm_enabled_model = LLM(model, disable_sliding_window=False)
vllm_enabled_model.generate("Hi my name is")
model_config = vllm_enabled_model.llm_engine.model_config
assert model_config.max_model_len == full_len, (
f"Max len expected to equal full_len of {full_len}, "
f"but got {model_config.max_model_len}")
del vllm_enabled_model
cleanup()
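
For reference, a minimal sketch (not part of this diff) of how the new prefix-caching guard surfaces through the LLM entrypoint. The model name is taken from the test above; the sketch assumes LLM forwards enable_prefix_caching and disable_sliding_window to EngineArgs the same way the test's disable_sliding_window kwarg is forwarded.

import pytest
from vllm import LLM


def test_prefix_caching_guard_sketch():
    # starcoder2-3b ships a sliding window in its HF config, so enabling
    # prefix caching without disabling the window should trip
    # CacheConfig._verify_prefix_caching().
    with pytest.raises(NotImplementedError):
        LLM("bigcode/starcoder2-3b", enable_prefix_caching=True)

    # Disabling the sliding window makes the combination legal again.
    llm = LLM("bigcode/starcoder2-3b",
              enable_prefix_caching=True,
              disable_sliding_window=True)
    llm.generate("Hi my name is")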

View File

@ -1,5 +1,29 @@
import pytest
from vllm.config import ModelConfig
MODEL_IDS_EXPECTED = [
("Qwen/Qwen1.5-7B", 32768),
("mistralai/Mistral-7B-v0.1", 4096),
("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]
@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
model_id, expected = model_id_expected
model_config = ModelConfig(
model_id,
model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
disable_sliding_window=True,
)
assert model_config.max_model_len == expected
def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
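
A hedged sketch of the semantics behind the new accessor split, mirroring the Mistral entry in MODEL_IDS_EXPECTED above (the 4096 window is the value from that table; fetching the HF config requires network access):

from vllm.config import ModelConfig

config = ModelConfig(
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-v0.1",
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="float16",
    revision=None,
    disable_sliding_window=True,
)
# The raw HF value is still visible through the new accessor...
assert config.get_hf_config_sliding_window() == 4096
# ...but the effective window is None once the flag is set...
assert config.get_sliding_window() is None
# ...and max_model_len is capped to the window size.
assert config.max_model_len == 4096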

View File

@ -30,7 +30,6 @@ class Attention(nn.Module):
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
@ -39,9 +38,11 @@ class Attention(nn.Module):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
if num_kv_heads is None:
num_kv_heads = num_heads
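
The model files below all follow the same pattern. A hedged sketch of the new call shape (build_attention and its parameters are illustrative helpers, not part of vLLM):

from typing import Optional

from vllm.attention import Attention
from vllm.config import CacheConfig


def build_attention(num_heads: int,
                    head_size: int,
                    num_kv_heads: int,
                    cache_config: Optional[CacheConfig] = None) -> Attention:
    # sliding_window is no longer threaded through each layer: when a
    # CacheConfig is provided, Attention reads cache_config.sliding_window
    # (None if --disable-sliding-window was set); otherwise it defaults to
    # no sliding window.
    return Attention(num_heads,
                     head_size,
                     head_size**-0.5,
                     num_kv_heads=num_kv_heads,
                     cache_config=cache_config)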

View File

@ -69,6 +69,10 @@ class ModelConfig:
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_sliding_window: Whether to disable sliding window. If True,
the model's sliding window functionality is disabled and
max_model_len is capped to the sliding window size. If the model
does not support sliding window, this argument is ignored.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
@ -96,6 +100,7 @@ class ModelConfig:
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
) -> None:
@ -118,14 +123,18 @@ class ModelConfig:
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
if not self.skip_tokenizer_init:
@ -220,7 +229,7 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
def get_sliding_window(self) -> Optional[int]:
def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
@ -232,6 +241,15 @@ class ModelConfig:
return None
return getattr(self.hf_text_config, "sliding_window", None)
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if self.disable_sliding_window:
return None
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size
@ -336,6 +354,7 @@ class CacheConfig:
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
# Will be set after profiling.
self.num_gpu_blocks = None
@ -364,6 +383,19 @@ class CacheConfig:
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
def _verify_prefix_caching(self) -> None:
if not self.enable_prefix_caching:
return
if self.sliding_window is not None:
raise NotImplementedError(
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
if self.cache_dtype == "fp8":
raise NotImplementedError(
"Prefix caching is not supported for fp8 cache_dtype. "
"Run with --kv-cache-dtype auto to use prefix caching.")
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
@ -1116,6 +1148,8 @@ def _get_and_verify_dtype(
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
@ -1135,6 +1169,7 @@ def _get_and_verify_max_len(
"max_seq_length",
"seq_len",
]
# Choose the smallest "max_length" from the possible keys.
max_len_key = None
for key in possible_keys:
max_len = getattr(hf_config, key, None)
@ -1142,6 +1177,16 @@ def _get_and_verify_max_len(
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
derived_max_model_len = min(derived_max_model_len, max_len)
# If sliding window is manually disabled, cap the derived max length to
# the sliding window length from the model config.
if disable_sliding_window and sliding_window_len is not None:
max_len_key = "sliding_window" \
if sliding_window_len < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len, sliding_window_len)
# If none of the keys were found in the config, use a default and
# log a warning.
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
@ -1157,6 +1202,13 @@ def _get_and_verify_max_len(
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
@ -1164,6 +1216,8 @@ def _get_and_verify_max_len(
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if max_model_len is None:
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
@ -1172,6 +1226,13 @@ def _get_and_verify_max_len(
# with model_max_length and allow this override when it's smaller.
model_max_length = getattr(hf_config, "model_max_length", None)
if model_max_length is not None and max_model_len <= model_max_length:
if disable_sliding_window:
# TODO(robertgshaw): Find a model that has model_max_length
# with sliding window to see if this case should be allowed.
raise NotImplementedError(
"Disabling sliding window is not supported for models "
"model_max_length in the config. Please raise an issue "
"so we can investigate.")
pass
else:
raise ValueError(
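
A worked sketch of the new _get_and_verify_max_len arguments, using the starcoder2-3b numbers from the test above (sliding_window 4096, max_position_embeddings 16384). It assumes the module-level helper can be imported directly and that the HF config is reachable:

from transformers import AutoConfig

from vllm.config import _get_and_verify_max_len

hf_config = AutoConfig.from_pretrained("bigcode/starcoder2-3b")

# Sliding window left enabled: the full context length is kept.
assert _get_and_verify_max_len(
    hf_config=hf_config,
    max_model_len=None,
    disable_sliding_window=False,
    sliding_window_len=4096) == 16384

# Sliding window disabled: the derived length is capped to the window.
assert _get_and_verify_max_len(
    hf_config=hf_config,
    max_model_len=None,
    disable_sliding_window=True,
    sliding_window_len=4096) == 4096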

View File

@ -41,6 +41,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
@ -267,6 +268,10 @@ class EngineArgs:
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching.')
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, capping '
'max_model_len to the sliding window size.')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceManagerV2.')
@ -558,8 +563,8 @@ class EngineArgs:
self.max_model_len, self.quantization,
self.quantization_param_path, self.enforce_eager,
self.max_context_len_to_capture, self.max_seq_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init,
self.served_model_name)
self.max_logprobs, self.disable_sliding_window,
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
@ -645,7 +650,8 @@ class EngineArgs:
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled):
raise ValueError(
"Chunked prefill is not supported with sliding window.")
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window.")
return EngineConfig(model_config=model_config,
cache_config=cache_config,
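
A hedged sketch of the new flag at the EngineArgs level; it assumes a CUDA-capable environment and that create_engine_config is the entrypoint shown above returning EngineConfig:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="bigcode/starcoder2-3b",
    enable_prefix_caching=True,
    disable_sliding_window=True,
)
engine_config = args.create_engine_config()

# With the window disabled, the prefix-caching guard in CacheConfig passes
# and the model's effective sliding window is None.
assert engine_config.model_config.get_sliding_window() is None
assert engine_config.cache_config.enable_prefix_caching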

View File

@ -94,7 +94,6 @@ class LlamaAttention(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
@ -146,7 +145,6 @@ class LlamaAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config,
quant_config=quant_config)
@ -183,7 +181,6 @@ class LlamaDecoderLayer(nn.Module):
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
@ -198,7 +195,6 @@ class LlamaDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
cache_config=cache_config,
)
self.mlp = LlamaMLP(

View File

@ -246,15 +246,16 @@ class MixtralMoE(nn.Module):
class MixtralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@ -276,7 +277,6 @@ class MixtralAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
if isinstance(
quant_config,
@ -312,7 +312,6 @@ class MixtralAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
@ -349,7 +348,6 @@ class MixtralDecoderLayer(nn.Module):
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(

View File

@ -166,7 +166,6 @@ class MixtralAttention(nn.Module):
max_position: int = 4096 * 32,
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
@ -190,7 +189,6 @@ class MixtralAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
@ -217,7 +215,6 @@ class MixtralAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
@ -254,7 +251,6 @@ class MixtralDecoderLayer(nn.Module):
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config,

View File

@ -86,10 +86,8 @@ class Qwen2Attention(nn.Module):
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None,
rope_scaling: Optional[Tuple] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -112,7 +110,6 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window if use_sliding_window else None
self.qkv_proj = QKVParallelLinear(
hidden_size,
@ -140,7 +137,6 @@ class Qwen2Attention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
@ -164,7 +160,6 @@ class Qwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@ -173,18 +168,14 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
use_sliding_window = (config.use_sliding_window
and layer_idx < config.max_window_layers)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
cache_config=cache_config,
quant_config=quant_config,
sliding_window=config.sliding_window,
rope_scaling=rope_scaling)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
@ -244,8 +235,8 @@ class Qwen2Model(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
Qwen2DecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -302,6 +293,18 @@ class Qwen2ForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = %s is less than "
"`num_hidden_layers` = %s. Please open an issue "
"to discuss this feature." % (
config.max_window_layers,
config.num_hidden_layers,
))
super().__init__()
self.config = config
self.quant_config = quant_config
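
A self-contained sketch of the guard above; the attribute values are hypothetical, standing in for a Qwen2-style config that keeps the sliding window for only some layers:

from types import SimpleNamespace

config = SimpleNamespace(max_window_layers=21, num_hidden_layers=24)
cache_config = SimpleNamespace(sliding_window=4096)

if (cache_config.sliding_window is not None
        and hasattr(config, "max_window_layers")):
    # Mirrors the ValueError raised in Qwen2ForCausalLM.__init__: per-layer
    # sliding window is unsupported, so such models must run with
    # --disable-sliding-window (which sets cache_config.sliding_window to None).
    print("rejected: sliding window for some but not all layers")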

View File

@ -74,7 +74,6 @@ class Starcoder2Attention(nn.Module):
self.rope_theta = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.use_bias = config.use_bias
self.sliding_window = config.sliding_window
self.qkv_proj = QKVParallelLinear(
self.hidden_size,
@ -101,7 +100,6 @@ class Starcoder2Attention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)

View File

@ -88,7 +88,6 @@ class XverseAttention(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
) -> None:
super().__init__()
@ -134,7 +133,6 @@ class XverseAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config,
quant_config=quant_config)
@ -167,7 +165,6 @@ class XverseDecoderLayer(nn.Module):
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
self.self_attn = XverseAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -178,7 +175,6 @@ class XverseDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
cache_config=cache_config,
)
self.mlp = XverseMLP(