[Model] Deepseek GGUF support (#13167)

Szymon Ożóg 2025-02-27 11:08:35 +01:00 committed by GitHub
parent edf309ebbe
commit 7f0be2aa24
8 changed files with 198 additions and 10 deletions

View File

@@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
We recommend using the tokenizer from the base model instead of the GGUF model, because tokenizer conversion from GGUF is time-consuming and unstable, especially for models with a large vocabulary.
:::
GGUF support assumes that Hugging Face can convert the GGUF metadata into a config. If Hugging Face does not support your model, you can manually create a compatible config and pass its path via `--hf-config-path`:
```console
# If your model is not supported by Hugging Face, you can manually provide a Hugging Face-compatible config path
vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
```
You can also use the GGUF model directly through the LLM entrypoint:
```python
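# A minimal sketch (the original snippet is truncated in this diff view);
# the GGUF path and tokenizer repo mirror the serve example above.
from vllm import LLM, SamplingParams

llm = LLM(model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
          tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```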

View File

@@ -229,6 +229,7 @@ class ModelConfig:
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
@@ -259,6 +260,7 @@ class ModelConfig:
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = model
self.hf_config_path = hf_config_path
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
@@ -321,8 +323,9 @@ class ModelConfig:
if self.enable_sleep_mode and not current_platform.is_cuda():
raise ValueError("Sleep mode is only supported on CUDA devices.")
hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format)
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
config_format)
if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
@@ -947,7 +950,7 @@ class ModelConfig:
def try_get_generation_config(self) -> Dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
config = try_get_generation_config(
self.model,
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
)

View File

@@ -93,6 +93,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
hf_config_path: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
@@ -262,6 +263,12 @@ class EngineArgs:
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=nullable_str,
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
@@ -1076,6 +1083,7 @@ class EngineArgs:
return ModelConfig(
model=self.model,
hf_config_path=self.hf_config_path,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
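For reference, a minimal sketch (not part of this commit) of driving the new flag programmatically instead of via `--hf-config-path`; the GGUF path and config repo below are hypothetical:
```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent to: vllm serve ./model.Q4_K_M.gguf --hf-config-path <repo>
engine_args = EngineArgs(
    model="./model.Q4_K_M.gguf",       # hypothetical GGUF file
    tokenizer="org/base-model",        # hypothetical base repo
    hf_config_path="org/base-model",   # hypothetical base repo
)
# EngineArgs.create_engine_config() then builds the ModelConfig shown
# above, where hf_config_path takes precedence when resolving the config.
```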

View File

@@ -5,6 +5,7 @@ from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -514,7 +515,12 @@ class FusedMoE(torch.nn.Module):
# dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
param.data.copy_(loaded_weight)
return
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
@@ -524,6 +530,20 @@
if is_transposed:
shard_dim = int(not shard_dim)
full_load = len(loaded_weight.shape) == 3
if full_load:
shard_dim += 1
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
final_shape = list(loaded_weight.shape)
if shard_id in ["w1", "w3"]:
final_shape[1] *= 2
final_shape[shard_dim] = final_shape[
shard_dim] // get_tensor_model_parallel_world_size()
param.materialize(final_shape, dtype=loaded_weight.dtype)
expert_data = param.data if full_load else param.data[expert_id]
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
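A standalone sketch (not part of the commit) of the shape arithmetic above for a merged 3-D GGUF expert tensor; the expert count, sizes, and tensor-parallel world size are hypothetical, and `shard_dim` is assumed to already include the `+1` for the leading expert dimension:
```python
def gguf_expert_final_shape(loaded_shape, shard_id, shard_dim, tp_size):
    # Mirrors the materialization branch above (full_load, 3-D case).
    final_shape = list(loaded_shape)
    if shard_id in ("w1", "w3"):
        final_shape[1] *= 2          # w13 stores gate and up fused
    final_shape[shard_dim] //= tp_size
    return final_shape

# Hypothetical: 64 experts, intermediate size 1408, hidden size 2048, TP=2.
assert gguf_expert_final_shape((64, 1408, 2048), "w1", 1, 2) == [64, 1408, 2048]
assert gguf_expert_final_shape((64, 2048, 1408), "w2", 2, 2) == [64, 2048, 704]
```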

View File

@@ -235,10 +235,23 @@ class ReplicatedLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If the weight on disk does not have a shape, give it one
# (such as scales for AutoFp8).
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param.size() == loaded_weight.size()
assert param.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}")
param.data.copy_(loaded_weight)
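The GGUF branch above relies on `UninitializedParameter`, which has no shape until the quantized tensor arrives. A self-contained sketch of that mechanic (shape and dtype below are hypothetical, not taken from the commit):
```python
import torch
from torch.nn.parameter import UninitializedParameter

# The parameter starts without a shape; the loader materializes it to
# match the incoming GGUF tensor, then copies the raw quantized bytes.
param = UninitializedParameter(requires_grad=False)
loaded_weight = torch.zeros(1024, 576, dtype=torch.uint8)  # hypothetical Q4_K bytes
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
param.data.copy_(loaded_weight)
assert param.shape == loaded_weight.shape
```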
def forward(self,

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import gguf
import torch
@@ -8,6 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -29,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
@@ -49,6 +52,8 @@
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self)
return None
@@ -184,6 +189,124 @@ class GGUFLinearMethod(LinearMethodBase):
return out
class GGUFMoEMethod(FusedMoEMethodBase):
"""MoE method for GGUF.
Args:
quant_config: The GGUF quantization config.
"""
def __init__(self, quant_config: GGUFConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
hidden_size)
# gate_up_proj (w13): fused gate and up projections
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w13_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w13_qweight, extra_weight_attrs)
layer.register_parameter("w13_qweight", w13_qweight)
w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w13_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
layer.register_parameter("w13_qweight_type", w13_qweight_type)
tensor_shape = (num_experts, intermediate_size_per_partition,
hidden_size)
# down_proj (w2)
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs(
w2_qweight, {
"input_dim": 1,
"output_dim": 0,
"tensor_shape": tensor_shape,
"is_gguf_weight": True,
"data_container": [],
})
set_weight_attrs(w2_qweight, extra_weight_attrs)
layer.register_parameter("w2_qweight", w2_qweight)
w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
requires_grad=False)
set_weight_attrs(w2_qweight_type, {
"is_gguf_weight_type": True,
"weight_type": 0,
"ignore_warning": True
})
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
self.act = SiluAndMul()
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
final_hidden_states = torch.empty_like(x)
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = layer.w13_qweight[ii]
out = _fuse_mul_mat(inp, expert_up,
layer.w13_qweight_type.weight_type)
out = self.act(out)
expert_down = layer.w2_qweight[ii]
current_state = _fuse_mul_mat(
out, expert_down,
layer.w2_qweight_type.weight_type).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
final_hidden_states[tok] = current_hidden_state
return final_hidden_states
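For intuition, a dense (unquantized) reference of the per-token loop in `apply()` above. This is a sketch, not the commit's execution path: plain matmuls stand in for `_fuse_mul_mat` on GGUF-quantized weights, and the activation reimplements SiluAndMul:
```python
import torch
import torch.nn.functional as F

def silu_and_mul(t: torch.Tensor) -> torch.Tensor:
    # Same split activation as vLLM's SiluAndMul: silu(gate) * up.
    gate, up = t.chunk(2, dim=-1)
    return F.silu(gate) * up

def moe_dense_reference(x, w13, w2, topk_weights, topk_ids):
    # x: (T, H); w13: (E, 2*I, H); w2: (E, H, I); topk_*: (T, K)
    out = torch.empty_like(x)
    for tok, (weights, ids) in enumerate(zip(topk_weights, topk_ids)):
        inp = x[tok:tok + 1]                      # (1, H)
        acc = None
        for w, i in zip(weights, ids):
            h = silu_and_mul(inp @ w13[i].t())    # (1, I)
            h = (h @ w2[i].t()).mul_(w)           # (1, H), scaled by routing weight
            acc = h if acc is None else acc.add_(h)
        out[tok] = acc
    return out
```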
class GGUFEmbeddingMethod(GGUFLinearMethod):
"""Embedding method for GGUF.

View File

@@ -1245,9 +1245,24 @@ class GGUFModelLoader(BaseModelLoader):
"""
config = model_config.hf_config
model_type = config.model_type
gguf_to_hf_name_map = {}
# hack: GGUF uses different model type names than transformers
if model_type == "cohere":
model_type = "command-r"
if model_type in ("deepseek_v3", "deepseek_v2"):
model_type = "deepseek2"
# The GGUF tensor name map assumes merged expert weights,
# so we need to map them to HF parameter names manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
@@ -1258,10 +1273,10 @@ class GGUFModelLoader(BaseModelLoader):
num_layers = config.num_hidden_layers
name_map = gguf.get_tensor_name_map(arch, num_layers)
with torch.device("meta"):
dummy_model = AutoModelForCausalLM.from_config(config)
dummy_model = AutoModelForCausalLM.from_config(
config, trust_remote_code=model_config.trust_remote_code)
state_dict = dummy_model.state_dict()
gguf_to_hf_name_map = {}
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
gguf_name = name_map.get_name(name)
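A standalone sketch (not part of the commit) of the merged-expert entries built in the loop further above, for a hypothetical two-layer config; the HF-side names reuse `experts.0` because the GGUF file stores all experts of a layer merged into one tensor:
```python
# Hypothetical: a 2-layer DeepSeek-style config.
num_hidden_layers = 2
merged_expert_map = {}
for idx in range(num_hidden_layers):
    merged_expert_map[f"blk.{idx}.exp_probs_b.bias"] = (
        f"model.layers.{idx}.mlp.gate.e_score_correction_bias")
    for proj in ("down", "gate", "up"):
        merged_expert_map[f"blk.{idx}.ffn_{proj}_exps.weight"] = (
            f"model.layers.{idx}.mlp.experts.0.{proj}_proj.weight")

assert merged_expert_map["blk.1.ffn_up_exps.weight"] == \
    "model.layers.1.mlp.experts.0.up_proj.weight"
```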

View File

@@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
weight = tensor.data
weight_type = tensor.tensor_type
name = gguf_to_hf_name_map[tensor.name]
if weight_type.name != "F32":
name = name.replace("weight", "qweight")
param = torch.tensor(weight)
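A standalone sketch (not from the commit) of the renaming rule above: F32 tensors load into the usual `weight` parameter, while quantized tensors are routed to the `qweight` parameters registered by the GGUF quantization methods. The tensor names and GGML types below are hypothetical:
```python
def target_param_name(hf_name: str, ggml_type_name: str) -> str:
    # Only non-F32 (i.e. quantized) tensors switch "weight" -> "qweight".
    if ggml_type_name == "F32":
        return hf_name
    return hf_name.replace("weight", "qweight")

assert target_param_name("model.norm.weight", "F32") == "model.norm.weight"
assert target_param_name(
    "model.layers.0.mlp.experts.0.gate_proj.weight",
    "Q4_K") == "model.layers.0.mlp.experts.0.gate_proj.qweight"
```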