diff --git a/tests/model_executor/test_eagle_quantization.py b/tests/model_executor/test_eagle_quantization.py new file mode 100644 index 0000000000000..1ab75933ee31e --- /dev/null +++ b/tests/model_executor/test_eagle_quantization.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import Mock, patch + +import pytest +import torch + +from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig +from vllm.model_executor.models.utils import get_draft_quant_config +from vllm.platforms import current_platform + +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) + + +def test_get_draft_quant_config_with_draft_model(): + mock_draft_model_config = Mock(spec=ModelConfig) + mock_load_config = Mock(spec=LoadConfig) + mock_speculative_config = Mock(spec=SpeculativeConfig) + mock_speculative_config.draft_model_config = mock_draft_model_config + + mock_vllm_config = Mock(spec=VllmConfig) + mock_vllm_config.speculative_config = mock_speculative_config + mock_vllm_config.load_config = mock_load_config + + mock_quant_config = Mock() + with patch.object( + VllmConfig, "get_quantization_config", return_value=mock_quant_config + ): + result = get_draft_quant_config(mock_vllm_config) + + # Verify the function calls get_quantization_config with draft model config + VllmConfig.get_quantization_config.assert_called_once_with( + mock_draft_model_config, mock_load_config + ) + assert result == mock_quant_config + + +def test_get_draft_quant_config_without_draft_model(): + mock_speculative_config = Mock(spec=SpeculativeConfig) + mock_speculative_config.draft_model_config = None + + mock_vllm_config = Mock(spec=VllmConfig) + mock_vllm_config.speculative_config = mock_speculative_config + mock_vllm_config.load_config = Mock(spec=LoadConfig) + + result = get_draft_quant_config(mock_vllm_config) + + assert result is None + + +@torch.inference_mode() +@pytest.mark.parametrize("device", DEVICES) +def test_fc_layer_quant_config_usage(dist_init, device) -> None: + import torch + + from vllm.model_executor.layers.linear import ReplicatedLinear + + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + + torch.set_default_device(device) + + input_size = 256 + output_size = 128 + + fc_no_quant = ReplicatedLinear( + input_size=input_size, + output_size=output_size, + bias=False, + params_dtype=torch.float16, + quant_config=None, + prefix="fc", + ) + + assert fc_no_quant.quant_config is None + assert fc_no_quant.input_size == input_size + assert fc_no_quant.output_size == output_size + + mock_quant_config = Mock() + fc_with_quant = ReplicatedLinear( + input_size=input_size, + output_size=output_size, + bias=False, + params_dtype=torch.float16, + quant_config=mock_quant_config, + prefix="fc", + ) + + assert fc_with_quant.quant_config == mock_quant_config + + # Check forward pass + x = torch.randn(2, input_size, dtype=torch.float16) + output, _ = fc_no_quant(x) + assert output.shape == (2, output_size) + + +def test_kv_cache_scale_name_handling(): + # Mock a quant config that supports cache scales + mock_quant_config = Mock() + mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") + + # Condition check in load_weights + name = "layers.0.self_attn.k_proj.weight" + scale_name = mock_quant_config.get_cache_scale(name) + + # Check if get_cache_scale is called and returns expected value + mock_quant_config.get_cache_scale.assert_called_once_with(name) + assert scale_name == "layers.0.self_attn.kv_scale" + + +def test_kv_cache_scale_name_no_scale(): + # Mock a quant config that returns None for get_cache_scale + mock_quant_config = Mock() + mock_quant_config.get_cache_scale = Mock(return_value=None) + + name = "layers.0.mlp.gate_proj.weight" + scale_name = mock_quant_config.get_cache_scale(name) + + # Should return None for weights that don't have cache scales + assert scale_name is None + + +def test_maybe_remap_kv_scale_name(): + from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + + params_dict = { + "layers.0.self_attn.kv_scale": Mock(), + "layers.1.self_attn.kv_scale": Mock(), + } + + name = "layers.0.self_attn.some_scale" + remapped = maybe_remap_kv_scale_name(name, params_dict) + + assert remapped in params_dict or remapped == name or remapped is None + + +def test_load_weights_kv_scale_handling(): + kv_scale_param = Mock() + kv_scale_param.weight_loader = Mock() + + params_dict = { + "layers.0.self_attn.kv_scale": kv_scale_param, + } + + mock_quant_config = Mock() + mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") + + # Load_weights logic for KV cache scales + name = "layers.0.self_attn.k_proj.weight" + loaded_weight_tensor = torch.tensor([1.0, 2.0]) + + if mock_quant_config is not None: + scale_name = mock_quant_config.get_cache_scale(name) + if scale_name: + param = params_dict[scale_name] + assert param is kv_scale_param + weight_to_load = ( + loaded_weight_tensor + if loaded_weight_tensor.dim() == 0 + else loaded_weight_tensor[0] + ) + + assert scale_name == "layers.0.self_attn.kv_scale" + assert weight_to_load == loaded_weight_tensor[0] diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 0287132c56375..90ab5c50361b6 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -11,13 +11,22 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM -from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight +from .utils import ( + AutoWeightsLoader, + get_draft_quant_config, + maybe_prefix, + process_eagle_weight, +) logger = init_logger(__name__) @@ -40,14 +49,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: """Use drafter's quantization config instead of verifier's.""" - draft_model_config = vllm_config.speculative_config.draft_model_config - draft_load_config = vllm_config.load_config - - return ( - VllmConfig.get_quantization_config(draft_model_config, draft_load_config) - if draft_model_config - else None - ) + return get_draft_quant_config(vllm_config) @support_torch_compile @@ -63,6 +65,9 @@ class LlamaModel(nn.Module): self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size + # Get drafter's quantization config + self.quant_config = get_draft_quant_config(vllm_config) + self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, @@ -80,8 +85,14 @@ class LlamaModel(nn.Module): for i in range(self.config.num_hidden_layers) ] ) - self.fc = torch.nn.Linear( - self.config.hidden_size * 2, self.config.hidden_size, bias=False + self.fc = ReplicatedLinear( + input_size=self.config.hidden_size * 2, + output_size=self.config.hidden_size, + bias=False, + params_dtype=vllm_config.model_config.dtype, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "fc"), + return_bias=False, ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -117,6 +128,24 @@ class LlamaModel(nn.Module): params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + # Handle kv cache quantization scales + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + # Remapping the name FP8 kv-scale + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index a3bcc5eeb32b9..75c671311b491 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -11,19 +11,27 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import QKVParallelLinear +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight +from .utils import ( + AutoWeightsLoader, + get_draft_quant_config, + maybe_prefix, + process_eagle_weight, +) logger = init_logger(__name__) @@ -66,14 +74,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: """Use drafter's quantization config instead of verifier's.""" - draft_model_config = vllm_config.speculative_config.draft_model_config - draft_load_config = vllm_config.load_config - - return ( - VllmConfig.get_quantization_config(draft_model_config, draft_load_config) - if draft_model_config - else None - ) + return get_draft_quant_config(vllm_config) def _norm_before_residual( self, hidden_states: torch.Tensor @@ -140,6 +141,9 @@ class LlamaModel(nn.Module): self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size + # Get drafter's quantization config + self.quant_config = get_draft_quant_config(vllm_config) + current_vllm_config = get_current_vllm_config() self.embed_tokens = VocabParallelEmbedding( @@ -160,13 +164,19 @@ class LlamaModel(nn.Module): ] ) if hasattr(self.config, "target_hidden_size"): - self.fc = torch.nn.Linear( - self.config.target_hidden_size * 3, self.config.hidden_size, bias=False - ) + fc_input_size = self.config.target_hidden_size * 3 else: - self.fc = torch.nn.Linear( - self.config.hidden_size * 3, self.config.hidden_size, bias=False - ) + fc_input_size = self.config.hidden_size * 3 + self.fc = ReplicatedLinear( + input_size=fc_input_size, + output_size=self.config.hidden_size, + bias=False, + params_dtype=vllm_config.model_config.dtype, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "fc"), + return_bias=False, + ) + self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, @@ -211,6 +221,24 @@ class LlamaModel(nn.Module): for name, loaded_weight in weights: if "midlayer." in name: name = name.replace("midlayer.", "layers.0.") + # Handle kv cache quantization scales + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + # Remapping the name FP8 kv-scale + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 0d811fbc7585d..ca5af358e2eed 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -18,6 +18,9 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import supports_any_eagle from vllm.multimodal import NestedTensors @@ -715,6 +718,30 @@ def maybe_prefix(prefix: str, name: str) -> str: return name if not prefix else f"{prefix}.{name}" +def get_draft_quant_config( + vllm_config: VllmConfig, +) -> QuantizationConfig | None: + """Get quantization config for Draft models. + + Draft models should use their own quantization config instead of the verifier/target + model's config. This helper retrieves the draft model's quantization config. + + Args: + vllm_config: The vLLM configuration object. + + Returns: + The draft model's config if available, None otherwise. + """ + draft_model_config = vllm_config.speculative_config.draft_model_config + draft_load_config = vllm_config.load_config + + return ( + VllmConfig.get_quantization_config(draft_model_config, draft_load_config) + if draft_model_config + else None + ) + + def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: """ Extract the layer index from the module name.