diff --git a/vllm/config.py b/vllm/config.py
index 5c300e327397b..dd59526471782 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1108,6 +1108,21 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+
+        else:
+            # Set quant_method for ModelOpt models.
+            producer_name = quant_cfg.get("producer", {}).get("name")
+            if producer_name == "modelopt":
+                quant_algo = quant_cfg.get("quantization",
+                                           {}).get("quant_algo")
+                if quant_algo == "FP8":
+                    quant_cfg["quant_method"] = "modelopt"
+                elif quant_algo == "NVFP4":
+                    quant_cfg["quant_method"] = "modelopt_fp4"
+                elif quant_algo is not None:
+                    raise ValueError(
+                        f"Unknown ModelOpt quant algo: {quant_algo}")
+
         return quant_cfg
 
     def _verify_quantization(self) -> None:
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 9e7296feeae1e..f155a1b11fbff 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -919,9 +919,13 @@ class FusedMoE(torch.nn.Module):
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_w13_weight_scale(self, shard_dim: int,
-                               loaded_weight: torch.Tensor,
-                               param: torch.Tensor, tp_rank: int):
+    def _load_combined_w13_weight_scale(self, shard_dim: int,
+                                        loaded_weight: torch.Tensor,
+                                        param: torch.Tensor, tp_rank: int):
+        """
+        Load w13 weight scales assuming that w1 weight scales and w3 weight
+        scales are stored in the same loaded_weight tensor.
+        """
         shard_size = param.shape[shard_dim]
         loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                              shard_size)
@@ -1168,24 +1172,43 @@ class FusedMoE(torch.nn.Module):
             uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
             )
 
-            # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
-            per_tensor_conditions = (
-                "weight_scale_2" in weight_name if uses_weight_scale_2 else
-                "weight_scale" in weight_name) or "input_scale" in weight_name
-
-            if "w13_weight_scale" in weight_name:
-                self._load_w13_weight_scale(shard_dim=shard_dim,
-                                            loaded_weight=loaded_weight,
-                                            param=param,
-                                            tp_rank=self.tp_rank)
-            elif per_tensor_conditions:
+            # Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
+            # weight scales.
+            # Input scales are always per-tensor.
+            # Weight scales: FP4 uses "weight_scale_2" and FP8 uses
+            # "weight_scale" for per-tensor scales.
+            is_per_tensor = ("weight_scale_2" in weight_name
+                             if uses_weight_scale_2 else "weight_scale"
+                             in weight_name) or "input_scale" in weight_name
+            if is_per_tensor:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
                     loaded_weight=loaded_weight,
                     expert_id=expert_id,
                 )
-            elif "weight" in weight_name:
+                return True if return_success else None
+
+            # If the weight is w13_weight_scale and the w1 and w3 weight scales
+            # are combined into a single loaded_weight, call
+            # _load_combined_w13_weight_scale() to load it.
+            # This is checked by comparing the hidden_out dims of the
+            # loaded_weight and the param.
+            if "w13_weight_scale" in weight_name:
+                loaded_weight_hidden_out = loaded_weight.shape[-2]
+                param_hidden_out = param.data.shape[-2] * self.tp_size
+                if loaded_weight_hidden_out == param_hidden_out:
+                    self._load_combined_w13_weight_scale(
+                        shard_dim=shard_dim,
+                        loaded_weight=loaded_weight,
+                        param=param,
+                        tp_rank=self.tp_rank,
+                    )
+                    return True if return_success else None
+
+            # For other weights, call _load_model_weight_or_group_weight_scale()
+            # to load them.
+            if "weight" in weight_name:
                 self._load_model_weight_or_group_weight_scale(
                     shard_id=shard_id,
                     shard_dim=shard_dim,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index cc41a771d06c2..8fe153464d360 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -449,6 +449,20 @@ def get_config(
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})
 
+    # ModelOpt 0.31.0 and after saves the quantization config in the model
+    # config file.
+    quantization_config = config_dict.get("quantization_config", None)
+
+    # ModelOpt 0.29.0 and before saves the quantization config in a separate
+    # "hf_quant_config.json" in the same directory as the model config file.
+    if quantization_config is None \
+            and file_or_path_exists(model, "hf_quant_config.json", revision):
+        quantization_config = get_hf_file_to_dict("hf_quant_config.json",
+                                                  model, revision)
+
+    if quantization_config is not None:
+        config.quantization_config = quantization_config
+
     if hf_overrides_kw:
         logger.debug("Overriding HF config with %s", hf_overrides_kw)
         config.update(hf_overrides_kw)