mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:35:01 +08:00
[Model] Deepseek GGUF support (#13167)
This commit is contained in:
parent
edf309ebbe
commit
7f0be2aa24
@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
|
||||
We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size.
|
||||
:::
|
||||
|
||||
GGUF assumes that huggingface can convert the metadata to a config file. In case huggingface doesn't support your model you can manually create a config and pass it as hf-confing-path
|
||||
|
||||
```console
|
||||
# If you model is not supported by huggingface you can manually provide a huggingface compatible config path
|
||||
vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path Tinyllama/TInyLlama-1.1B-Chat-v1.0
|
||||
```
|
||||
|
||||
You can also use the GGUF model directly through the LLM entrypoint:
|
||||
|
||||
```python
|
||||
|
||||
@ -229,6 +229,7 @@ class ModelConfig:
|
||||
trust_remote_code: bool,
|
||||
dtype: Union[str, torch.dtype],
|
||||
seed: int,
|
||||
hf_config_path: Optional[str] = None,
|
||||
allowed_local_media_path: str = "",
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
@ -259,6 +260,7 @@ class ModelConfig:
|
||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.hf_config_path = hf_config_path
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
@ -321,8 +323,9 @@ class ModelConfig:
|
||||
if self.enable_sleep_mode and not current_platform.is_cuda():
|
||||
raise ValueError("Sleep mode is only supported on CUDA devices.")
|
||||
|
||||
hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, config_format)
|
||||
hf_config = get_config(self.hf_config_path or self.model,
|
||||
trust_remote_code, revision, code_revision,
|
||||
config_format)
|
||||
|
||||
if hf_overrides_kw:
|
||||
logger.info("Overriding HF config with %s", hf_overrides_kw)
|
||||
@ -947,7 +950,7 @@ class ModelConfig:
|
||||
def try_get_generation_config(self) -> Dict[str, Any]:
|
||||
if self.generation_config is None or self.generation_config == "auto":
|
||||
config = try_get_generation_config(
|
||||
self.model,
|
||||
self.hf_config_path or self.model,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
@ -93,6 +93,7 @@ class EngineArgs:
|
||||
model: str = 'facebook/opt-125m'
|
||||
served_model_name: Optional[Union[str, List[str]]] = None
|
||||
tokenizer: Optional[str] = None
|
||||
hf_config_path: Optional[str] = None
|
||||
task: TaskOption = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
tokenizer_mode: str = 'auto'
|
||||
@ -262,6 +263,12 @@ class EngineArgs:
|
||||
default=EngineArgs.tokenizer,
|
||||
help='Name or path of the huggingface tokenizer to use. '
|
||||
'If unspecified, model name or path will be used.')
|
||||
parser.add_argument(
|
||||
"--hf-config-path",
|
||||
type=nullable_str,
|
||||
default=EngineArgs.hf_config_path,
|
||||
help='Name or path of the huggingface config to use. '
|
||||
'If unspecified, model name or path will be used.')
|
||||
parser.add_argument(
|
||||
'--skip-tokenizer-init',
|
||||
action='store_true',
|
||||
@ -1076,6 +1083,7 @@ class EngineArgs:
|
||||
|
||||
return ModelConfig(
|
||||
model=self.model,
|
||||
hf_config_path=self.hf_config_path,
|
||||
task=self.task,
|
||||
# We know this is not None because we set it in __post_init__
|
||||
tokenizer=cast(str, self.tokenizer),
|
||||
|
||||
@ -5,6 +5,7 @@ from enum import Enum
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
@ -514,7 +515,12 @@ class FusedMoE(torch.nn.Module):
|
||||
# dimension intermediate_size_per_partition is used.
|
||||
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||
|
||||
expert_data = param.data[expert_id]
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.weight_type = loaded_weight.item()
|
||||
param.data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
# is_transposed: if the dim to shard the weight
|
||||
# should be flipped. Required by GPTQ, compressed-tensors
|
||||
@ -524,6 +530,20 @@ class FusedMoE(torch.nn.Module):
|
||||
if is_transposed:
|
||||
shard_dim = int(not shard_dim)
|
||||
|
||||
full_load = len(loaded_weight.shape) == 3
|
||||
if full_load:
|
||||
shard_dim += 1
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
final_shape = list(loaded_weight.shape)
|
||||
if shard_id in ["w1", "w3"]:
|
||||
final_shape[1] *= 2
|
||||
final_shape[shard_dim] = final_shape[
|
||||
shard_dim] // get_tensor_model_parallel_world_size()
|
||||
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
expert_data = param.data if full_load else param.data[expert_id]
|
||||
# Case input scale: input_scale loading is only supported for fp8
|
||||
if "input_scale" in weight_name:
|
||||
# this is needed for compressed-tensors only
|
||||
|
||||
@ -235,10 +235,23 @@ class ReplicatedLinear(LinearBase):
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
# If the weight on disk does not have a shape, give it one
|
||||
# (such scales for AutoFp8).
|
||||
# Special case for GGUF
|
||||
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.weight_type = loaded_weight.item()
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert param.size() == loaded_weight.size()
|
||||
assert param.size() == loaded_weight.size(), (
|
||||
f"Tried to load weights of size {loaded_weight.size()}"
|
||||
f"to a parameter of size {param.size()}")
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@ -8,6 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
@ -29,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -49,6 +52,8 @@ class GGUFConfig(QuantizationConfig):
|
||||
return GGUFLinearMethod(self)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
return GGUFEmbeddingMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return GGUFMoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
@ -184,6 +189,124 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
return out
|
||||
|
||||
|
||||
class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for GGUF.
|
||||
|
||||
Args:
|
||||
quant_config: The GGUF quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GGUFConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
|
||||
hidden_size)
|
||||
#gate up proj
|
||||
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w13_qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
})
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
|
||||
w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(w13_qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w13_qweight_type", w13_qweight_type)
|
||||
|
||||
tensor_shape = (num_experts, intermediate_size_per_partition,
|
||||
hidden_size)
|
||||
#gate down proj
|
||||
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
w2_qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"data_container": [],
|
||||
})
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
|
||||
w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(w2_qweight_type, {
|
||||
"is_gguf_weight_type": True,
|
||||
"weight_type": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
|
||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||
self.act = SiluAndMul()
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
):
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
final_hidden_states = torch.empty_like(x)
|
||||
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||
current_hidden_state = None
|
||||
for ww, ii in zip(w, idx):
|
||||
expert_up = layer.w13_qweight[ii]
|
||||
|
||||
out = _fuse_mul_mat(inp, expert_up,
|
||||
layer.w13_qweight_type.weight_type)
|
||||
out = self.act(out)
|
||||
|
||||
expert_down = layer.w2_qweight[ii]
|
||||
current_state = _fuse_mul_mat(
|
||||
out, expert_down,
|
||||
layer.w2_qweight_type.weight_type).mul_(ww)
|
||||
if current_hidden_state is None:
|
||||
current_hidden_state = current_state
|
||||
else:
|
||||
current_hidden_state.add_(current_state)
|
||||
final_hidden_states[tok] = current_hidden_state
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""Embedding method for GGUF.
|
||||
|
||||
|
||||
@ -1245,9 +1245,24 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
"""
|
||||
config = model_config.hf_config
|
||||
model_type = config.model_type
|
||||
gguf_to_hf_name_map = {}
|
||||
# hack: ggufs have a different name than transformers
|
||||
if model_type == "cohere":
|
||||
model_type = "command-r"
|
||||
if model_type in ("deepseek_v3", "deepseek_v2"):
|
||||
model_type = "deepseek2"
|
||||
# GGUF layer map assumes that we will have a merged expert weights
|
||||
# so we need to map them manually
|
||||
for idx in range(config.num_hidden_layers):
|
||||
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
|
||||
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||
|
||||
arch = None
|
||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||
if value == model_type:
|
||||
@ -1258,10 +1273,10 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
num_layers = config.num_hidden_layers
|
||||
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
||||
with torch.device("meta"):
|
||||
dummy_model = AutoModelForCausalLM.from_config(config)
|
||||
dummy_model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=model_config.trust_remote_code)
|
||||
state_dict = dummy_model.state_dict()
|
||||
|
||||
gguf_to_hf_name_map = {}
|
||||
for hf_name in state_dict:
|
||||
name, suffix = hf_name.rsplit(".", 1)
|
||||
gguf_name = name_map.get_name(name)
|
||||
|
||||
@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
|
||||
weight = tensor.data
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user