mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-09 00:23:38 +08:00
[Model] Deepseek GGUF support (#13167)
This commit is contained in:
parent
edf309ebbe
commit
7f0be2aa24
@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
|
|||||||
We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size.
|
We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
GGUF assumes that huggingface can convert the metadata to a config file. In case huggingface doesn't support your model you can manually create a config and pass it as hf-confing-path
|
||||||
|
|
||||||
|
```console
|
||||||
|
# If you model is not supported by huggingface you can manually provide a huggingface compatible config path
|
||||||
|
vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path Tinyllama/TInyLlama-1.1B-Chat-v1.0
|
||||||
|
```
|
||||||
|
|
||||||
You can also use the GGUF model directly through the LLM entrypoint:
|
You can also use the GGUF model directly through the LLM entrypoint:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@ -229,6 +229,7 @@ class ModelConfig:
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
dtype: Union[str, torch.dtype],
|
dtype: Union[str, torch.dtype],
|
||||||
seed: int,
|
seed: int,
|
||||||
|
hf_config_path: Optional[str] = None,
|
||||||
allowed_local_media_path: str = "",
|
allowed_local_media_path: str = "",
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
code_revision: Optional[str] = None,
|
code_revision: Optional[str] = None,
|
||||||
@ -259,6 +260,7 @@ class ModelConfig:
|
|||||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.hf_config_path = hf_config_path
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
self.trust_remote_code = trust_remote_code
|
self.trust_remote_code = trust_remote_code
|
||||||
@ -321,8 +323,9 @@ class ModelConfig:
|
|||||||
if self.enable_sleep_mode and not current_platform.is_cuda():
|
if self.enable_sleep_mode and not current_platform.is_cuda():
|
||||||
raise ValueError("Sleep mode is only supported on CUDA devices.")
|
raise ValueError("Sleep mode is only supported on CUDA devices.")
|
||||||
|
|
||||||
hf_config = get_config(self.model, trust_remote_code, revision,
|
hf_config = get_config(self.hf_config_path or self.model,
|
||||||
code_revision, config_format)
|
trust_remote_code, revision, code_revision,
|
||||||
|
config_format)
|
||||||
|
|
||||||
if hf_overrides_kw:
|
if hf_overrides_kw:
|
||||||
logger.info("Overriding HF config with %s", hf_overrides_kw)
|
logger.info("Overriding HF config with %s", hf_overrides_kw)
|
||||||
@ -947,7 +950,7 @@ class ModelConfig:
|
|||||||
def try_get_generation_config(self) -> Dict[str, Any]:
|
def try_get_generation_config(self) -> Dict[str, Any]:
|
||||||
if self.generation_config is None or self.generation_config == "auto":
|
if self.generation_config is None or self.generation_config == "auto":
|
||||||
config = try_get_generation_config(
|
config = try_get_generation_config(
|
||||||
self.model,
|
self.hf_config_path or self.model,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=self.trust_remote_code,
|
||||||
revision=self.revision,
|
revision=self.revision,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -93,6 +93,7 @@ class EngineArgs:
|
|||||||
model: str = 'facebook/opt-125m'
|
model: str = 'facebook/opt-125m'
|
||||||
served_model_name: Optional[Union[str, List[str]]] = None
|
served_model_name: Optional[Union[str, List[str]]] = None
|
||||||
tokenizer: Optional[str] = None
|
tokenizer: Optional[str] = None
|
||||||
|
hf_config_path: Optional[str] = None
|
||||||
task: TaskOption = "auto"
|
task: TaskOption = "auto"
|
||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
tokenizer_mode: str = 'auto'
|
tokenizer_mode: str = 'auto'
|
||||||
@ -262,6 +263,12 @@ class EngineArgs:
|
|||||||
default=EngineArgs.tokenizer,
|
default=EngineArgs.tokenizer,
|
||||||
help='Name or path of the huggingface tokenizer to use. '
|
help='Name or path of the huggingface tokenizer to use. '
|
||||||
'If unspecified, model name or path will be used.')
|
'If unspecified, model name or path will be used.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-config-path",
|
||||||
|
type=nullable_str,
|
||||||
|
default=EngineArgs.hf_config_path,
|
||||||
|
help='Name or path of the huggingface config to use. '
|
||||||
|
'If unspecified, model name or path will be used.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--skip-tokenizer-init',
|
'--skip-tokenizer-init',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@ -1076,6 +1083,7 @@ class EngineArgs:
|
|||||||
|
|
||||||
return ModelConfig(
|
return ModelConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
hf_config_path=self.hf_config_path,
|
||||||
task=self.task,
|
task=self.task,
|
||||||
# We know this is not None because we set it in __post_init__
|
# We know this is not None because we set it in __post_init__
|
||||||
tokenizer=cast(str, self.tokenizer),
|
tokenizer=cast(str, self.tokenizer),
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from enum import Enum
|
|||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.parameter import UninitializedParameter
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
@ -514,7 +515,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
# dimension intermediate_size_per_partition is used.
|
# dimension intermediate_size_per_partition is used.
|
||||||
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
|
||||||
|
|
||||||
expert_data = param.data[expert_id]
|
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||||
|
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||||
|
if is_gguf_weight_type:
|
||||||
|
param.weight_type = loaded_weight.item()
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
return
|
||||||
|
|
||||||
# is_transposed: if the dim to shard the weight
|
# is_transposed: if the dim to shard the weight
|
||||||
# should be flipped. Required by GPTQ, compressed-tensors
|
# should be flipped. Required by GPTQ, compressed-tensors
|
||||||
@ -524,6 +530,20 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if is_transposed:
|
if is_transposed:
|
||||||
shard_dim = int(not shard_dim)
|
shard_dim = int(not shard_dim)
|
||||||
|
|
||||||
|
full_load = len(loaded_weight.shape) == 3
|
||||||
|
if full_load:
|
||||||
|
shard_dim += 1
|
||||||
|
|
||||||
|
# Materialize GGUF UninitializedParameter
|
||||||
|
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||||
|
final_shape = list(loaded_weight.shape)
|
||||||
|
if shard_id in ["w1", "w3"]:
|
||||||
|
final_shape[1] *= 2
|
||||||
|
final_shape[shard_dim] = final_shape[
|
||||||
|
shard_dim] // get_tensor_model_parallel_world_size()
|
||||||
|
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||||
|
|
||||||
|
expert_data = param.data if full_load else param.data[expert_id]
|
||||||
# Case input scale: input_scale loading is only supported for fp8
|
# Case input scale: input_scale loading is only supported for fp8
|
||||||
if "input_scale" in weight_name:
|
if "input_scale" in weight_name:
|
||||||
# this is needed for compressed-tensors only
|
# this is needed for compressed-tensors only
|
||||||
|
|||||||
@ -235,10 +235,23 @@ class ReplicatedLinear(LinearBase):
|
|||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
# If the weight on disk does not have a shape, give it one
|
# If the weight on disk does not have a shape, give it one
|
||||||
# (such scales for AutoFp8).
|
# (such scales for AutoFp8).
|
||||||
|
# Special case for GGUF
|
||||||
|
|
||||||
|
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||||
|
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||||
|
if is_gguf_weight_type:
|
||||||
|
param.weight_type = loaded_weight.item()
|
||||||
|
|
||||||
|
# Materialize GGUF UninitializedParameter
|
||||||
|
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||||
|
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||||
|
|
||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
assert param.size() == loaded_weight.size()
|
assert param.size() == loaded_weight.size(), (
|
||||||
|
f"Tried to load weights of size {loaded_weight.size()}"
|
||||||
|
f"to a parameter of size {param.size()}")
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import gguf
|
import gguf
|
||||||
import torch
|
import torch
|
||||||
@ -8,6 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
|
|||||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||||
|
FusedMoEMethodBase)
|
||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
@ -29,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
|
|||||||
return "gguf"
|
return "gguf"
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||||
return [torch.half, torch.bfloat16]
|
return [torch.half]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -49,6 +52,8 @@ class GGUFConfig(QuantizationConfig):
|
|||||||
return GGUFLinearMethod(self)
|
return GGUFLinearMethod(self)
|
||||||
elif isinstance(layer, VocabParallelEmbedding):
|
elif isinstance(layer, VocabParallelEmbedding):
|
||||||
return GGUFEmbeddingMethod(self)
|
return GGUFEmbeddingMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return GGUFMoEMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -184,6 +189,124 @@ class GGUFLinearMethod(LinearMethodBase):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GGUFMoEMethod(FusedMoEMethodBase):
|
||||||
|
"""MoE method for GGUF.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quant_config: The GGUF quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: GGUFConfig):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
|
|
||||||
|
tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
|
||||||
|
hidden_size)
|
||||||
|
#gate up proj
|
||||||
|
w13_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||||
|
set_weight_attrs(
|
||||||
|
w13_qweight, {
|
||||||
|
"input_dim": 1,
|
||||||
|
"output_dim": 0,
|
||||||
|
"tensor_shape": tensor_shape,
|
||||||
|
"is_gguf_weight": True,
|
||||||
|
"data_container": [],
|
||||||
|
})
|
||||||
|
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||||
|
layer.register_parameter("w13_qweight", w13_qweight)
|
||||||
|
|
||||||
|
w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||||
|
requires_grad=False)
|
||||||
|
set_weight_attrs(w13_qweight_type, {
|
||||||
|
"is_gguf_weight_type": True,
|
||||||
|
"weight_type": 0,
|
||||||
|
"ignore_warning": True
|
||||||
|
})
|
||||||
|
set_weight_attrs(w13_qweight_type, extra_weight_attrs)
|
||||||
|
layer.register_parameter("w13_qweight_type", w13_qweight_type)
|
||||||
|
|
||||||
|
tensor_shape = (num_experts, intermediate_size_per_partition,
|
||||||
|
hidden_size)
|
||||||
|
#gate down proj
|
||||||
|
w2_qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||||
|
set_weight_attrs(
|
||||||
|
w2_qweight, {
|
||||||
|
"input_dim": 1,
|
||||||
|
"output_dim": 0,
|
||||||
|
"tensor_shape": tensor_shape,
|
||||||
|
"is_gguf_weight": True,
|
||||||
|
"data_container": [],
|
||||||
|
})
|
||||||
|
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||||
|
layer.register_parameter("w2_qweight", w2_qweight)
|
||||||
|
|
||||||
|
w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
|
||||||
|
requires_grad=False)
|
||||||
|
set_weight_attrs(w2_qweight_type, {
|
||||||
|
"is_gguf_weight_type": True,
|
||||||
|
"weight_type": 0,
|
||||||
|
"ignore_warning": True
|
||||||
|
})
|
||||||
|
|
||||||
|
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||||
|
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||||
|
self.act = SiluAndMul()
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
):
|
||||||
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias)
|
||||||
|
final_hidden_states = torch.empty_like(x)
|
||||||
|
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
|
||||||
|
inp = x[tok].reshape((1, ) + x.shape[1:])
|
||||||
|
current_hidden_state = None
|
||||||
|
for ww, ii in zip(w, idx):
|
||||||
|
expert_up = layer.w13_qweight[ii]
|
||||||
|
|
||||||
|
out = _fuse_mul_mat(inp, expert_up,
|
||||||
|
layer.w13_qweight_type.weight_type)
|
||||||
|
out = self.act(out)
|
||||||
|
|
||||||
|
expert_down = layer.w2_qweight[ii]
|
||||||
|
current_state = _fuse_mul_mat(
|
||||||
|
out, expert_down,
|
||||||
|
layer.w2_qweight_type.weight_type).mul_(ww)
|
||||||
|
if current_hidden_state is None:
|
||||||
|
current_hidden_state = current_state
|
||||||
|
else:
|
||||||
|
current_hidden_state.add_(current_state)
|
||||||
|
final_hidden_states[tok] = current_hidden_state
|
||||||
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||||
"""Embedding method for GGUF.
|
"""Embedding method for GGUF.
|
||||||
|
|
||||||
|
|||||||
@ -1245,9 +1245,24 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
"""
|
"""
|
||||||
config = model_config.hf_config
|
config = model_config.hf_config
|
||||||
model_type = config.model_type
|
model_type = config.model_type
|
||||||
|
gguf_to_hf_name_map = {}
|
||||||
# hack: ggufs have a different name than transformers
|
# hack: ggufs have a different name than transformers
|
||||||
if model_type == "cohere":
|
if model_type == "cohere":
|
||||||
model_type = "command-r"
|
model_type = "command-r"
|
||||||
|
if model_type in ("deepseek_v3", "deepseek_v2"):
|
||||||
|
model_type = "deepseek2"
|
||||||
|
# GGUF layer map assumes that we will have a merged expert weights
|
||||||
|
# so we need to map them manually
|
||||||
|
for idx in range(config.num_hidden_layers):
|
||||||
|
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
|
||||||
|
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
|
||||||
|
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
|
||||||
|
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
|
||||||
|
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
|
||||||
|
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||||
|
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||||
|
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||||
|
|
||||||
arch = None
|
arch = None
|
||||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||||
if value == model_type:
|
if value == model_type:
|
||||||
@ -1258,10 +1273,10 @@ class GGUFModelLoader(BaseModelLoader):
|
|||||||
num_layers = config.num_hidden_layers
|
num_layers = config.num_hidden_layers
|
||||||
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
||||||
with torch.device("meta"):
|
with torch.device("meta"):
|
||||||
dummy_model = AutoModelForCausalLM.from_config(config)
|
dummy_model = AutoModelForCausalLM.from_config(
|
||||||
|
config, trust_remote_code=model_config.trust_remote_code)
|
||||||
state_dict = dummy_model.state_dict()
|
state_dict = dummy_model.state_dict()
|
||||||
|
|
||||||
gguf_to_hf_name_map = {}
|
|
||||||
for hf_name in state_dict:
|
for hf_name in state_dict:
|
||||||
name, suffix = hf_name.rsplit(".", 1)
|
name, suffix = hf_name.rsplit(".", 1)
|
||||||
gguf_name = name_map.get_name(name)
|
gguf_name = name_map.get_name(name)
|
||||||
|
|||||||
@ -496,7 +496,6 @@ def gguf_quant_weights_iterator(
|
|||||||
weight = tensor.data
|
weight = tensor.data
|
||||||
weight_type = tensor.tensor_type
|
weight_type = tensor.tensor_type
|
||||||
name = gguf_to_hf_name_map[tensor.name]
|
name = gguf_to_hf_name_map[tensor.name]
|
||||||
|
|
||||||
if weight_type.name != "F32":
|
if weight_type.name != "F32":
|
||||||
name = name.replace("weight", "qweight")
|
name = name.replace("weight", "qweight")
|
||||||
param = torch.tensor(weight)
|
param = torch.tensor(weight)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user