[Quantization] [Eagle] Add complete quantization support to the draft model in Eagle (#28435)

Signed-off-by: Shreyas Kulkarni <shreyas.gp269@gmail.com>

parent 7765e5ba75
commit 95ae50b7d1
tests/model_executor/test_eagle_quantization.py | 169 lines (new file)
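
What the change means in practice, as a rough sketch: an EAGLE drafter that ships its own quantization config (for example an FP8 checkpoint) now has that config applied to its layers and kv-cache scales, instead of inheriting the verifier's. The snippet below is illustrative only; the model names are placeholders and the speculative_config keys are an assumption based on vLLM's speculative-decoding interface, not something this diff introduces.

from vllm import LLM

# Hypothetical usage sketch: both model names are placeholders.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # verifier / target model
    speculative_config={
        "method": "eagle",
        "model": "path/to/quantized-eagle-drafter",  # drafter carrying its own quant config
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The capital of France is"])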
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock, patch

import pytest
import torch

from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig
from vllm.model_executor.models.utils import get_draft_quant_config
from vllm.platforms import current_platform

DEVICES = (
    [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)


def test_get_draft_quant_config_with_draft_model():
    mock_draft_model_config = Mock(spec=ModelConfig)
    mock_load_config = Mock(spec=LoadConfig)
    mock_speculative_config = Mock(spec=SpeculativeConfig)
    mock_speculative_config.draft_model_config = mock_draft_model_config

    mock_vllm_config = Mock(spec=VllmConfig)
    mock_vllm_config.speculative_config = mock_speculative_config
    mock_vllm_config.load_config = mock_load_config

    mock_quant_config = Mock()
    with patch.object(
        VllmConfig, "get_quantization_config", return_value=mock_quant_config
    ):
        result = get_draft_quant_config(mock_vllm_config)

        # Verify the function calls get_quantization_config with draft model config
        VllmConfig.get_quantization_config.assert_called_once_with(
            mock_draft_model_config, mock_load_config
        )
        assert result == mock_quant_config


def test_get_draft_quant_config_without_draft_model():
    mock_speculative_config = Mock(spec=SpeculativeConfig)
    mock_speculative_config.draft_model_config = None

    mock_vllm_config = Mock(spec=VllmConfig)
    mock_vllm_config.speculative_config = mock_speculative_config
    mock_vllm_config.load_config = Mock(spec=LoadConfig)

    result = get_draft_quant_config(mock_vllm_config)

    assert result is None


@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
def test_fc_layer_quant_config_usage(dist_init, device) -> None:
    import torch

    from vllm.model_executor.layers.linear import ReplicatedLinear

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)

    input_size = 256
    output_size = 128

    fc_no_quant = ReplicatedLinear(
        input_size=input_size,
        output_size=output_size,
        bias=False,
        params_dtype=torch.float16,
        quant_config=None,
        prefix="fc",
    )

    assert fc_no_quant.quant_config is None
    assert fc_no_quant.input_size == input_size
    assert fc_no_quant.output_size == output_size

    mock_quant_config = Mock()
    fc_with_quant = ReplicatedLinear(
        input_size=input_size,
        output_size=output_size,
        bias=False,
        params_dtype=torch.float16,
        quant_config=mock_quant_config,
        prefix="fc",
    )

    assert fc_with_quant.quant_config == mock_quant_config

    # Check forward pass
    x = torch.randn(2, input_size, dtype=torch.float16)
    output, _ = fc_no_quant(x)
    assert output.shape == (2, output_size)


def test_kv_cache_scale_name_handling():
    # Mock a quant config that supports cache scales
    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

    # Condition check in load_weights
    name = "layers.0.self_attn.k_proj.weight"
    scale_name = mock_quant_config.get_cache_scale(name)

    # Check if get_cache_scale is called and returns expected value
    mock_quant_config.get_cache_scale.assert_called_once_with(name)
    assert scale_name == "layers.0.self_attn.kv_scale"


def test_kv_cache_scale_name_no_scale():
    # Mock a quant config that returns None for get_cache_scale
    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value=None)

    name = "layers.0.mlp.gate_proj.weight"
    scale_name = mock_quant_config.get_cache_scale(name)

    # Should return None for weights that don't have cache scales
    assert scale_name is None


def test_maybe_remap_kv_scale_name():
    from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name

    params_dict = {
        "layers.0.self_attn.kv_scale": Mock(),
        "layers.1.self_attn.kv_scale": Mock(),
    }

    name = "layers.0.self_attn.some_scale"
    remapped = maybe_remap_kv_scale_name(name, params_dict)

    assert remapped in params_dict or remapped == name or remapped is None


def test_load_weights_kv_scale_handling():
    kv_scale_param = Mock()
    kv_scale_param.weight_loader = Mock()

    params_dict = {
        "layers.0.self_attn.kv_scale": kv_scale_param,
    }

    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

    # Load_weights logic for KV cache scales
    name = "layers.0.self_attn.k_proj.weight"
    loaded_weight_tensor = torch.tensor([1.0, 2.0])

    if mock_quant_config is not None:
        scale_name = mock_quant_config.get_cache_scale(name)
        if scale_name:
            param = params_dict[scale_name]
            assert param is kv_scale_param
            weight_to_load = (
                loaded_weight_tensor
                if loaded_weight_tensor.dim() == 0
                else loaded_weight_tensor[0]
            )

            assert scale_name == "layers.0.self_attn.kv_scale"
            assert weight_to_load == loaded_weight_tensor[0]
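
The new tests can be run in isolation, e.g. with pytest tests/model_executor/test_eagle_quantization.py; note that test_fc_layer_quant_config_usage takes a dist_init fixture, which is assumed to be provided by the repository's existing test fixtures rather than by this file.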
@@ -11,13 +11,22 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

-from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
+from .utils import (
+    AutoWeightsLoader,
+    get_draft_quant_config,
+    maybe_prefix,
+    process_eagle_weight,
+)

 logger = init_logger(__name__)

@@ -40,14 +49,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):

     def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
         """Use drafter's quantization config instead of verifier's."""
-        draft_model_config = vllm_config.speculative_config.draft_model_config
-        draft_load_config = vllm_config.load_config
-
-        return (
-            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
-            if draft_model_config
-            else None
-        )
+        return get_draft_quant_config(vllm_config)


 @support_torch_compile
@@ -63,6 +65,9 @@ class LlamaModel(nn.Module):
         self.config = vllm_config.speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size

+        # Get drafter's quantization config
+        self.quant_config = get_draft_quant_config(vllm_config)
+
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
@@ -80,8 +85,14 @@
                 for i in range(self.config.num_hidden_layers)
             ]
         )
-        self.fc = torch.nn.Linear(
-            self.config.hidden_size * 2, self.config.hidden_size, bias=False
+        self.fc = ReplicatedLinear(
+            input_size=self.config.hidden_size * 2,
+            output_size=self.config.hidden_size,
+            bias=False,
+            params_dtype=vllm_config.model_config.dtype,
+            quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "fc"),
+            return_bias=False,
         )

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -117,6 +128,24 @@
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            # Handle kv cache quantization scales
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            # Remapping the name FP8 kv-scale
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
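
The kv-cache scale branch added to load_weights above can be distilled into a small, dependency-free sketch. Here get_cache_scale and the flat params_dict are stand-ins for vLLM's quant-config and parameter machinery, not the real API surface; the real code dispatches through each parameter's weight_loader.

import torch

def load_kv_scale(name, loaded_weight, params_dict, get_cache_scale):
    """Consume `name` as a kv-cache scale; return True if it was handled."""
    scale_name = get_cache_scale(name)
    if not scale_name:
        return False
    param = params_dict[scale_name]
    # Checkpoints may store the scale as a 0-d or 1-d tensor; keep a scalar.
    value = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
    param.copy_(value)
    return True

# Toy usage with a stand-in mapping from a k_proj weight to its kv_scale.
params = {"layers.0.self_attn.kv_scale": torch.zeros(())}
handled = load_kv_scale(
    "layers.0.self_attn.k_proj.weight",
    torch.tensor([0.5, 0.5]),
    params,
    lambda n: "layers.0.self_attn.kv_scale" if "k_proj" in n else None,
)
assert handled and params["layers.0.self_attn.kv_scale"].item() == 0.5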
@@ -11,19 +11,27 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import QKVParallelLinear
+from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors

-from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
+from .utils import (
+    AutoWeightsLoader,
+    get_draft_quant_config,
+    maybe_prefix,
+    process_eagle_weight,
+)

 logger = init_logger(__name__)

@@ -66,14 +74,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):

     def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
         """Use drafter's quantization config instead of verifier's."""
-        draft_model_config = vllm_config.speculative_config.draft_model_config
-        draft_load_config = vllm_config.load_config
-
-        return (
-            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
-            if draft_model_config
-            else None
-        )
+        return get_draft_quant_config(vllm_config)

     def _norm_before_residual(
         self, hidden_states: torch.Tensor
@@ -140,6 +141,9 @@ class LlamaModel(nn.Module):
         self.config = vllm_config.speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size

+        # Get drafter's quantization config
+        self.quant_config = get_draft_quant_config(vllm_config)
+
         current_vllm_config = get_current_vllm_config()

         self.embed_tokens = VocabParallelEmbedding(
@@ -160,13 +164,19 @@ class LlamaModel(nn.Module):
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(
-                self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
-            )
+            fc_input_size = self.config.target_hidden_size * 3
         else:
-            self.fc = torch.nn.Linear(
-                self.config.hidden_size * 3, self.config.hidden_size, bias=False
-            )
+            fc_input_size = self.config.hidden_size * 3
+        self.fc = ReplicatedLinear(
+            input_size=fc_input_size,
+            output_size=self.config.hidden_size,
+            bias=False,
+            params_dtype=vllm_config.model_config.dtype,
+            quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "fc"),
+            return_bias=False,
+        )

         self.norm = RMSNorm(
             self.config.hidden_size,
             eps=self.config.rms_norm_eps,
@@ -211,6 +221,24 @@ class LlamaModel(nn.Module):
         for name, loaded_weight in weights:
             if "midlayer." in name:
                 name = name.replace("midlayer.", "layers.0.")
+            # Handle kv cache quantization scales
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+            # Remapping the name FP8 kv-scale
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
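
The second variant above (the one with midlayer and target_hidden_size handling) differs from the first mainly in how the fc input width is chosen. A minimal sketch of that selection logic; the SimpleNamespace objects are stand-ins for hf_config:

from types import SimpleNamespace

def eagle_fc_input_size(config) -> int:
    # Mirrors the branch above: prefer target_hidden_size when the drafter
    # config defines it, otherwise fall back to the drafter's own hidden_size.
    if hasattr(config, "target_hidden_size"):
        return config.target_hidden_size * 3
    return config.hidden_size * 3

assert eagle_fc_input_size(SimpleNamespace(hidden_size=2048)) == 6144
assert eagle_fc_input_size(SimpleNamespace(hidden_size=2048, target_hidden_size=4096)) == 12288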
@@ -18,6 +18,9 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import supports_any_eagle
 from vllm.multimodal import NestedTensors
@@ -715,6 +718,30 @@ def maybe_prefix(prefix: str, name: str) -> str:
     return name if not prefix else f"{prefix}.{name}"


+def get_draft_quant_config(
+    vllm_config: VllmConfig,
+) -> QuantizationConfig | None:
+    """Get quantization config for draft models.
+
+    Draft models should use their own quantization config instead of the verifier/target
+    model's config. This helper retrieves the draft model's quantization config.
+
+    Args:
+        vllm_config: The vLLM configuration object.
+
+    Returns:
+        The draft model's quantization config if available, None otherwise.
+    """
+    draft_model_config = vllm_config.speculative_config.draft_model_config
+    draft_load_config = vllm_config.load_config
+
+    return (
+        VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
+        if draft_model_config
+        else None
+    )
+
+
 def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
     """
     Extract the layer index from the module name.