[Bugfix / Core] Prefix Caching Guards (merged with main) (#4846)

Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>

Parent: f17a1a8f96
Commit: 1102bef219
tests/prefix_caching/test_disable_sliding_window.py (new file, +44 lines)

@@ -0,0 +1,44 @@
+"""Compare model behavior with and without sliding window disabled.
+
+Run `pytest tests/prefix_caching/test_disable_sliding_window.py`.
+"""
+import pytest
+
+from tests.conftest import cleanup
+from vllm import LLM
+
+MODEL_LEN_LEN = [
+    # Example models with sliding window.
+    ("bigcode/starcoder2-3b", 4096, 16384),
+    # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI
+
+    # Confirm models without sliding window work.
+    # config has "use_sliding_window": false
+    ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768),
+    # config has no sliding window attribute.
+    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048),
+]
+
+
+@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN)
+def test_disable_sliding_window(model_len_len):
+    model, sliding_len, full_len = model_len_len
+    vllm_disabled_model = LLM(model, disable_sliding_window=True)
+    vllm_disabled_model.generate("Hi my name is")
+    model_config = vllm_disabled_model.llm_engine.model_config
+    assert model_config.max_model_len == sliding_len, (
+        f"Max len expected to equal sliding_len of {sliding_len}, "
+        f"but got {model_config.max_model_len}")
+
+    del vllm_disabled_model
+    cleanup()
+
+    vllm_enabled_model = LLM(model, disable_sliding_window=False)
+    vllm_enabled_model.generate("Hi my name is")
+    model_config = vllm_enabled_model.llm_engine.model_config
+    assert model_config.max_model_len == full_len, (
+        f"Max len expected to equal full_len of {full_len}, "
+        f"but got {model_config.max_model_len}")
+
+    del vllm_enabled_model
+    cleanup()
tests/test_config.py

@@ -1,5 +1,29 @@
+import pytest
+
 from vllm.config import ModelConfig
+
+MODEL_IDS_EXPECTED = [
+    ("Qwen/Qwen1.5-7B", 32768),
+    ("mistralai/Mistral-7B-v0.1", 4096),
+    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
+]
+
+
+@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
+def test_disable_sliding_window(model_id_expected):
+    model_id, expected = model_id_expected
+    model_config = ModelConfig(
+        model_id,
+        model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+        disable_sliding_window=True,
+    )
+    assert model_config.max_model_len == expected


 def test_get_sliding_window():
     TEST_SLIDING_WINDOW = 4096
vllm/attention/layer.py

@@ -30,7 +30,6 @@ class Attention(nn.Module):
         scale: float,
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
@@ -39,9 +38,11 @@ class Attention(nn.Module):
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
+            sliding_window = cache_config.sliding_window
         else:
             kv_cache_dtype = "auto"
             block_size = 16
+            sliding_window = None
         if num_kv_heads is None:
             num_kv_heads = num_heads

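The net effect of these two hunks is that Attention no longer takes a per-layer sliding_window argument; the value is always derived from CacheConfig, so every layer sees the same setting. A condensed sketch of that resolution order (the helper name is made up for illustration; only the branch added above is real):

from typing import Optional, Tuple

def resolve_cache_settings(cache_config) -> Tuple[str, int, Optional[int]]:
    """Return (kv_cache_dtype, block_size, sliding_window)."""
    if cache_config is not None:
        return (cache_config.cache_dtype, cache_config.block_size,
                cache_config.sliding_window)
    # Defaults used when no CacheConfig is supplied.
    return ("auto", 16, None)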
vllm/config.py

@@ -69,6 +69,10 @@ class ModelConfig:
         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode
+        disable_sliding_window: Whether to disable sliding window. If True,
+            we will disable the sliding window functionality of the model.
+            If the model does not support sliding window, this argument is
+            ignored.
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer.
         served_model_name: The model name used in metrics tag `model_name`,
@@ -96,6 +100,7 @@ class ModelConfig:
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
+        disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, List[str]]] = None,
     ) -> None:
@@ -118,14 +123,18 @@ class ModelConfig:
         self.max_seq_len_to_capture = (max_seq_len_to_capture
                                        or max_context_len_to_capture)
         self.max_logprobs = max_logprobs
+        self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init

         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
-        self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
-                                                     max_model_len)
+        self.max_model_len = _get_and_verify_max_len(
+            hf_config=self.hf_text_config,
+            max_model_len=max_model_len,
+            disable_sliding_window=self.disable_sliding_window,
+            sliding_window_len=self.get_hf_config_sliding_window())
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         if not self.skip_tokenizer_init:
@@ -220,7 +229,7 @@ class ModelConfig:
             "must be divisible by pipeline parallel size "
             f"({pipeline_parallel_size}).")

-    def get_sliding_window(self) -> Optional[int]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """

@@ -232,6 +241,15 @@ class ModelConfig:
             return None
         return getattr(self.hf_text_config, "sliding_window", None)

+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size, or None if disabled.
+        """
+        # If user disables sliding window, return None.
+        if self.disable_sliding_window:
+            return None
+        # Otherwise get the value from the hf config.
+        return self.get_hf_config_sliding_window()
+
     def get_vocab_size(self) -> int:
         return self.hf_text_config.vocab_size

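This split gives ModelConfig two accessors: get_hf_config_sliding_window() always reports what the HF config declares, while get_sliding_window() respects the new user toggle and is what the rest of the engine consults. A small sketch of the difference, following the ModelConfig construction used in tests/test_config.py; the expected values assume Mistral-7B-v0.1's config (sliding_window = 4096), the same case the new tests cover:

from vllm.config import ModelConfig

cfg = ModelConfig("mistralai/Mistral-7B-v0.1",
                  "mistralai/Mistral-7B-v0.1",
                  tokenizer_mode="auto",
                  trust_remote_code=False,
                  seed=0,
                  dtype="float16",
                  revision=None,
                  disable_sliding_window=True)

print(cfg.get_hf_config_sliding_window())  # 4096: raw value from the HF config
print(cfg.get_sliding_window())            # None: the user toggle wins
print(cfg.max_model_len)                   # 4096: capped to the sliding window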
@@ -336,6 +354,7 @@ class CacheConfig:
         self.enable_prefix_caching = enable_prefix_caching
         self._verify_args()
         self._verify_cache_dtype()
+        self._verify_prefix_caching()

         # Will be set after profiling.
         self.num_gpu_blocks = None

@@ -364,6 +383,19 @@ class CacheConfig:
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

+    def _verify_prefix_caching(self) -> None:
+        if not self.enable_prefix_caching:
+            return
+
+        if self.sliding_window is not None:
+            raise NotImplementedError(
+                "Prefix caching is not supported with sliding window. "
+                "Run with --disable-sliding-window to use prefix caching.")
+        if self.cache_dtype == "fp8":
+            raise NotImplementedError(
+                "Prefix caching is not supported for fp8 cache_dtype. "
+                "Run with --kv-cache-dtype auto to use prefix caching.")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
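A rough illustration of when the new guard fires, using the starcoder2-3b values from the new test (sliding window of 4096 tokens). Passing enable_prefix_caching through the LLM constructor is an assumption about the usual kwargs forwarding; the guard messages are the ones added above:

from vllm import LLM

# starcoder2-3b configures a sliding window, so this combination now raises
# NotImplementedError("Prefix caching is not supported with sliding window. ...")
# llm = LLM("bigcode/starcoder2-3b", enable_prefix_caching=True)

# Disabling the sliding window (which caps max_model_len to 4096) makes the
# combination legal again.
llm = LLM("bigcode/starcoder2-3b",
          enable_prefix_caching=True,
          disable_sliding_window=True)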
@@ -1116,6 +1148,8 @@ def _get_and_verify_dtype(
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     max_model_len: Optional[int],
+    disable_sliding_window: bool,
+    sliding_window_len: Optional[int],
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")

@@ -1135,6 +1169,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
+    # Choose the smallest "max_length" from the possible keys.
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)

@@ -1142,6 +1177,16 @@ def _get_and_verify_max_len(
             max_len_key = key if max_len < derived_max_model_len \
                 else max_len_key
             derived_max_model_len = min(derived_max_model_len, max_len)
+
+    # If sliding window is manually disabled, max_length should be less
+    # than the sliding window length in the model config.
+    if disable_sliding_window and sliding_window_len is not None:
+        max_len_key = "sliding_window" \
+            if sliding_window_len < derived_max_model_len else max_len_key
+        derived_max_model_len = min(derived_max_model_len, sliding_window_len)
+
+    # If none of the keys were found in the config, use a default and
+    # log a warning.
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
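The capping step only changes the derived length when the user opts out of sliding window; otherwise the window is still handled at runtime and the full context length stays available. A simplified, self-contained sketch of that arithmetic (not the real helper), using the starcoder2-3b numbers from the new test (max_position_embeddings = 16384, sliding_window = 4096):

from typing import Optional

def derived_max_len(max_position_embeddings: int,
                    disable_sliding_window: bool,
                    sliding_window_len: Optional[int]) -> int:
    # Smallest of the "max length" keys found in the HF config.
    derived = max_position_embeddings
    # New step: cap to the sliding window when the window is disabled.
    if disable_sliding_window and sliding_window_len is not None:
        derived = min(derived, sliding_window_len)
    return derived

assert derived_max_len(16384, False, 4096) == 16384  # window handled at runtime
assert derived_max_len(16384, True, 4096) == 4096    # context capped instead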
@@ -1157,6 +1202,13 @@ def _get_and_verify_max_len(

     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None and rope_scaling["type"] != "su":
+        if disable_sliding_window:
+            # TODO(robertgshaw): Find a model that supports rope_scaling
+            # with sliding window to see if this case should be allowed.
+            raise NotImplementedError(
+                "Disabling sliding window is not supported for models "
+                "with rope_scaling. Please raise an issue so we can "
+                "investigate.")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         if rope_scaling["type"] == "yarn":

@@ -1164,6 +1216,8 @@ def _get_and_verify_max_len(
                 "original_max_position_embeddings"]
         derived_max_model_len *= scaling_factor

+    # If the user specified a max length, make sure it is smaller than the
+    # derived length from the HF model config.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
     elif max_model_len > derived_max_model_len:

@@ -1172,6 +1226,13 @@ def _get_and_verify_max_len(
         # with model_max_length and allow this override when it's smaller.
         model_max_length = getattr(hf_config, "model_max_length", None)
         if model_max_length is not None and max_model_len <= model_max_length:
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that has model_max_length
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with model_max_length in the config. Please raise an "
+                    "issue so we can investigate.")
             pass
         else:
             raise ValueError(
vllm/engine/arg_utils.py

@@ -41,6 +41,7 @@ class EngineArgs:
     max_parallel_loading_workers: Optional[int] = None
     block_size: int = 16
     enable_prefix_caching: bool = False
+    disable_sliding_window: bool = False
     use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90

@@ -267,6 +268,10 @@ class EngineArgs:
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
                             help='Enables automatic prefix caching.')
+        parser.add_argument('--disable-sliding-window',
+                            action='store_true',
+                            help='Disables sliding window, '
+                            'capping to sliding window size')
         parser.add_argument('--use-v2-block-manager',
                             action='store_true',
                             help='Use BlockSpaceMangerV2.')

@@ -558,8 +563,8 @@ class EngineArgs:
             self.max_model_len, self.quantization,
             self.quantization_param_path, self.enforce_eager,
             self.max_context_len_to_capture, self.max_seq_len_to_capture,
-            self.max_logprobs, self.skip_tokenizer_init,
-            self.served_model_name)
+            self.max_logprobs, self.disable_sliding_window,
+            self.skip_tokenizer_init, self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,

@@ -645,7 +650,8 @@ class EngineArgs:
         if (model_config.get_sliding_window() is not None
                 and scheduler_config.chunked_prefill_enabled):
             raise ValueError(
-                "Chunked prefill is not supported with sliding window.")
+                "Chunked prefill is not supported with sliding window. "
+                "Set --disable-sliding-window to disable sliding window.")

         return EngineConfig(model_config=model_config,
                             cache_config=cache_config,
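End to end, the new flag travels from EngineArgs into ModelConfig (which caps max_model_len and hides the window from get_sliding_window()) and from there into CacheConfig, whose prefix-caching guard then passes. A minimal sketch of that flow; using EngineArgs.create_engine_config() as the entry point and the model name are assumptions, but the asserted values match the new tests:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="bigcode/starcoder2-3b",
                         enable_prefix_caching=True,
                         disable_sliding_window=True)
engine_config = engine_args.create_engine_config()

assert engine_config.model_config.get_sliding_window() is None
assert engine_config.model_config.max_model_len == 4096
assert engine_config.cache_config.enable_prefix_caching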
vllm/model_executor/models/llama.py

@@ -94,7 +94,6 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()

@@ -146,7 +145,6 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)

@@ -183,7 +181,6 @@ class LlamaDecoderLayer(nn.Module):
                               config.original_max_position_embeddings)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        sliding_window = getattr(config, "sliding_window", None)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(

@@ -198,7 +195,6 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
-            sliding_window=sliding_window,
             cache_config=cache_config,
         )
         self.mlp = LlamaMLP(
vllm/model_executor/models/mixtral.py

@@ -246,15 +246,16 @@ class MixtralMoE(nn.Module):

 class MixtralAttention(nn.Module):

-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()

@@ -276,7 +277,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window

         if isinstance(
                 quant_config,

@@ -312,7 +312,6 @@ class MixtralAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)

@@ -349,7 +348,6 @@ class MixtralDecoderLayer(nn.Module):
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(
vllm/model_executor/models/mixtral_quant.py

@@ -166,7 +166,6 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()

@@ -190,7 +189,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window

         self.qkv_proj = QKVParallelLinear(
             hidden_size,

@@ -217,7 +215,6 @@ class MixtralAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)

@@ -254,7 +251,6 @@ class MixtralDecoderLayer(nn.Module):
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,
vllm/model_executor/models/qwen2.py

@@ -86,10 +86,8 @@ class Qwen2Attention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
-                 use_sliding_window: bool = False,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None,
                  rope_scaling: Optional[Tuple] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size

@@ -112,7 +110,6 @@ class Qwen2Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window if use_sliding_window else None

         self.qkv_proj = QKVParallelLinear(
             hidden_size,

@@ -140,7 +137,6 @@ class Qwen2Attention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)

@@ -164,7 +160,6 @@ class Qwen2DecoderLayer(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        layer_idx: int,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:

@@ -173,18 +168,14 @@ class Qwen2DecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        use_sliding_window = (config.use_sliding_window
-                              and layer_idx < config.max_window_layers)
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            use_sliding_window=use_sliding_window,
             cache_config=cache_config,
             quant_config=quant_config,
-            sliding_window=config.sliding_window,
             rope_scaling=rope_scaling)
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,

@@ -244,8 +235,8 @@ class Qwen2Model(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
+            Qwen2DecoderLayer(config, cache_config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -302,6 +293,18 @@ class Qwen2ForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config
+        # TODO (@robertgshaw2): see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError("Sliding window for some but not all layers is "
+                             "not supported. This model uses sliding window "
+                             "but `max_window_layers` = %s is less than "
+                             "`num_hidden_layers` = %s. Please open an issue "
+                             "to discuss this feature." % (
+                                 config.max_window_layers,
+                                 config.num_hidden_layers,
+                             ))
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
vllm/model_executor/models/starcoder2.py

@@ -74,7 +74,6 @@ class Starcoder2Attention(nn.Module):
         self.rope_theta = config.rope_theta
         self.max_position_embeddings = config.max_position_embeddings
         self.use_bias = config.use_bias
-        self.sliding_window = config.sliding_window

         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,

@@ -101,7 +100,6 @@ class Starcoder2Attention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)
vllm/model_executor/models/xverse.py

@@ -88,7 +88,6 @@ class XverseAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
-        sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()

@@ -134,7 +133,6 @@ class XverseAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=sliding_window,
             cache_config=cache_config,
             quant_config=quant_config)

@@ -167,7 +165,6 @@ class XverseDecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        sliding_window = getattr(config, "sliding_window", None)
         self.self_attn = XverseAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,

@@ -178,7 +175,6 @@ class XverseDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=getattr(config, "bias", False),
-            sliding_window=sliding_window,
             cache_config=cache_config,
         )
         self.mlp = XverseMLP(