diff --git a/docs/source/features/quantization/gguf.md b/docs/source/features/quantization/gguf.md
index 65c181900f9be..4b1ff4a22a23b 100644
--- a/docs/source/features/quantization/gguf.md
+++ b/docs/source/features/quantization/gguf.md
@@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
 We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size.
 :::
 
+GGUF loading assumes that Hugging Face can convert the GGUF metadata into a config file. If Hugging Face does not support your model, you can manually create a compatible config and pass it via `--hf-config-path`:
+
+```console
+# If your model is not supported by Hugging Face, you can manually provide a Hugging Face-compatible config path
+vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
+```
+
 You can also use the GGUF model directly through the LLM entrypoint:
 
 ```python
diff --git a/vllm/config.py b/vllm/config.py
index a5d8ee9303d0e..d1384c6375f30 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -229,6 +229,7 @@ class ModelConfig:
         trust_remote_code: bool,
         dtype: Union[str, torch.dtype],
         seed: int,
+        hf_config_path: Optional[str] = None,
         allowed_local_media_path: str = "",
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
@@ -259,6 +260,7 @@ class ModelConfig:
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
         self.model = model
+        self.hf_config_path = hf_config_path
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
@@ -321,8 +323,9 @@ class ModelConfig:
         if self.enable_sleep_mode and not current_platform.is_cuda():
             raise ValueError("Sleep mode is only supported on CUDA devices.")
 
-        hf_config = get_config(self.model, trust_remote_code, revision,
-                               code_revision, config_format)
+        hf_config = get_config(self.hf_config_path or self.model,
+                               trust_remote_code, revision, code_revision,
+                               config_format)
 
         if hf_overrides_kw:
             logger.info("Overriding HF config with %s", hf_overrides_kw)
@@ -947,7 +950,7 @@ class ModelConfig:
     def try_get_generation_config(self) -> Dict[str, Any]:
         if self.generation_config is None or self.generation_config == "auto":
             config = try_get_generation_config(
-                self.model,
+                self.hf_config_path or self.model,
                 trust_remote_code=self.trust_remote_code,
                 revision=self.revision,
             )
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 26d4a84b841ce..1a2f794c9151d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -93,6 +93,7 @@ class EngineArgs:
     model: str = 'facebook/opt-125m'
     served_model_name: Optional[Union[str, List[str]]] = None
     tokenizer: Optional[str] = None
+    hf_config_path: Optional[str] = None
     task: TaskOption = "auto"
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
@@ -262,6 +263,12 @@ class EngineArgs:
             default=EngineArgs.tokenizer,
             help='Name or path of the huggingface tokenizer to use. '
             'If unspecified, model name or path will be used.')
+        parser.add_argument(
+            "--hf-config-path",
+            type=nullable_str,
+            default=EngineArgs.hf_config_path,
+            help='Name or path of the huggingface config to use. '
+            'If unspecified, model name or path will be used.')
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
@@ -1076,6 +1083,7 @@ class EngineArgs:
 
         return ModelConfig(
             model=self.model,
+            hf_config_path=self.hf_config_path,
             task=self.task,
             # We know this is not None because we set it in __post_init__
             tokenizer=cast(str, self.tokenizer),
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 42554b61f67ab..28a88571dab4b 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -5,6 +5,7 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch.nn.parameter import UninitializedParameter
 
 import vllm.envs as envs
 from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -514,7 +515,12 @@ class FusedMoE(torch.nn.Module):
         # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
-        expert_data = param.data[expert_id]
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+            param.data.copy_(loaded_weight)
+            return
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -524,6 +530,20 @@ class FusedMoE(torch.nn.Module):
         if is_transposed:
             shard_dim = int(not shard_dim)
 
+        full_load = len(loaded_weight.shape) == 3
+        if full_load:
+            shard_dim += 1
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            final_shape = list(loaded_weight.shape)
+            if shard_id in ["w1", "w3"]:
+                final_shape[1] *= 2
+            final_shape[shard_dim] = final_shape[
+                shard_dim] // get_tensor_model_parallel_world_size()
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
+
+        expert_data = param.data if full_load else param.data[expert_id]
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
             # this is needed for compressed-tensors only
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 521724765bebf..b9c85aaf50b53 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -235,10 +235,23 @@ class ReplicatedLinear(LinearBase):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         # If the weight on disk does not have a shape, give it one
         # (such scales for AutoFp8).
+        # Special case for GGUF
+
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param.size() == loaded_weight.size()
+        assert param.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"to a parameter of size {param.size()}")
         param.data.copy_(loaded_weight)
 
     def forward(self,
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index b1fecb32f4d80..ba176e4a567cc 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import gguf
 import torch
@@ -8,6 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -29,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
         return "gguf"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half, torch.bfloat16]
+        return [torch.half]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -49,6 +52,8 @@ class GGUFConfig(QuantizationConfig):
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GGUFMoEMethod(self)
         return None
 
 
@@ -184,6 +189,124 @@ class GGUFLinearMethod(LinearMethodBase):
         return out
 
+
+class GGUFMoEMethod(FusedMoEMethodBase):
+    """MoE method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
+                        hidden_size)
+        # gate up proj
+        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w13_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        layer.register_parameter("w13_qweight", w13_qweight)
+
+        w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                     requires_grad=False)
+        set_weight_attrs(w13_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+        tensor_shape = (num_experts, intermediate_size_per_partition,
+                        hidden_size)
+        # down proj
+        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w2_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+        w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                    requires_grad=False)
+        set_weight_attrs(w2_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+
+        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w2_qweight_type", w2_qweight_type)
+        self.act = SiluAndMul()
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ):
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+        final_hidden_states = torch.empty_like(x)
+        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+            inp = x[tok].reshape((1, ) + x.shape[1:])
+            current_hidden_state = None
+            for ww, ii in zip(w, idx):
+                expert_up = layer.w13_qweight[ii]
+
+                out = _fuse_mul_mat(inp, expert_up,
+                                    layer.w13_qweight_type.weight_type)
+                out = self.act(out)
+
+                expert_down = layer.w2_qweight[ii]
+                current_state = _fuse_mul_mat(
+                    out, expert_down,
+                    layer.w2_qweight_type.weight_type).mul_(ww)
+                if current_hidden_state is None:
+                    current_hidden_state = current_state
+                else:
+                    current_hidden_state.add_(current_state)
+            final_hidden_states[tok] = current_hidden_state
+        return final_hidden_states
+
+
 class GGUFEmbeddingMethod(GGUFLinearMethod):
     """Embedding method for GGUF.
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 4e8ef49235ed5..46247eaf2a60c 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1245,9 +1245,24 @@ class GGUFModelLoader(BaseModelLoader):
         """
        config = model_config.hf_config
         model_type = config.model_type
+        gguf_to_hf_name_map = {}
         # hack: ggufs have a different name than transformers
         if model_type == "cohere":
             model_type = "command-r"
+        if model_type in ("deepseek_v3", "deepseek_v2"):
+            model_type = "deepseek2"
+            # GGUF layer map assumes that we will have merged expert weights,
+            # so we need to map them manually
+            for idx in range(config.num_hidden_layers):
+                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
+                        f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
+                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
+                        f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
+                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
+                        f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
+                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
+                        f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
+
         arch = None
         for key, value in gguf.MODEL_ARCH_NAMES.items():
             if value == model_type:
@@ -1258,10 +1273,10 @@
         num_layers = config.num_hidden_layers
         name_map = gguf.get_tensor_name_map(arch, num_layers)
         with torch.device("meta"):
-            dummy_model = AutoModelForCausalLM.from_config(config)
+            dummy_model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=model_config.trust_remote_code)
         state_dict = dummy_model.state_dict()
 
-        gguf_to_hf_name_map = {}
         for hf_name in state_dict:
             name, suffix = hf_name.rsplit(".", 1)
             gguf_name = name_map.get_name(name)
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 18f6f40b32f05..245c199f75b18 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
         weight = tensor.data
         weight_type = tensor.tensor_type
         name = gguf_to_hf_name_map[tensor.name]
-
         if weight_type.name != "F32":
             name = name.replace("weight", "qweight")
         param = torch.tensor(weight)
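
Not part of the patch: a minimal sketch of exercising the new `hf_config_path` override from the Python `LLM` entrypoint rather than `vllm serve`. It assumes `LLM` forwards extra keyword arguments to `EngineArgs` (as it does for other engine options); the GGUF file path and the TinyLlama repo below are illustrative placeholders only.

```python
from vllm import LLM, SamplingParams

# Sketch only: hf_config_path is the engine argument added in this patch.
# The GGUF path and the Hugging Face repo used for the config/tokenizer
# are placeholders; substitute your own model.
llm = LLM(
    model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    hf_config_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```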