[CORE] [QUANT] Support for GPTQModel's dynamic quantization per module override/control (#7086)
parent 2c2b560f48
commit 36a08630e8
tests/quantization/test_gptq_dynamic.py (new file, 68 lines)

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests whether gptq models with dynamic quantization can be loaded.

Run `pytest tests/quantization/test_gptq_dynamic.py --forked`.
"""

import pytest
import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_dynamic_override)

PROMPT = "On the surface of Mars, we found"

# The first layer is quantized using bits=4, group_size=128
# The second layer is quantized using bits=8, group_size=32
# All other layers (layer index >= 2) are not quantized
MODEL_QUANT = [
    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
     True),
    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
     False),
]


@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
def test_gptq_with_dynamic(vllm_runner, model_id: str,
                           use_marlin_kernel: bool):

    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)

    linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
        GPTQLinearMethod)

    for name, submodule in (vllm_model.model.llm_engine.model_executor.
                            driver_worker.model_runner.model.named_modules()):
        if name == "lm_head":
            assert isinstance(submodule.quant_method, linear_method_cls)
        elif name == 'model.layers.0.self_attn.qkv_proj':
            # The first layer is quantized using bits=4, group_size=128
            # desc_act=True
            assert isinstance(submodule.quant_method, linear_method_cls)
            config = submodule.quant_method.quant_config
            assert config.weight_bits == 4
            assert config.group_size == 128
            assert config.desc_act
        elif name == 'model.layers.1.self_attn.qkv_proj':
            # The second layer is quantized using bits=8, group_size=32
            # desc_act=False
            assert isinstance(submodule.quant_method, linear_method_cls)
            config = submodule.quant_method.quant_config
            assert get_dynamic_override(config, layer_name=name,
                                        key="bits") == 8
            assert get_dynamic_override(config,
                                        layer_name=name,
                                        key="group_size") == 32
            assert not get_dynamic_override(
                config, layer_name=name, key="desc_act")
        elif (name == 'model.layers.2.self_attn.qkv_proj'
              or name == 'model.layers.2.mlp.gate_up_proj'):
            # All other layers (layer index >= 2) are not quantized
            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)

    del vllm_model
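For context, a quantize_config `dynamic` mapping along the following lines (shown as a Python dict; the regex keys are illustrative and not copied from the ModelCloud checkpoints) would produce the layout the test above asserts: layer 0 keeps the base 4-bit/group_size=128 config, layer 1 is overridden to 8-bit/group_size=32, and layers at index 2 and above are skipped entirely.

# Illustrative only: a GPTQModel-style quantize_config, shown as a Python
# dict, whose `dynamic` rules match the per-layer layout asserted above.
# The regex keys are hypothetical, not taken from the test checkpoints.
quantize_config = {
    "bits": 4,                # base config: 4-bit, group_size 128
    "group_size": 128,
    "desc_act": True,
    "lm_head": True,          # also quantize the lm_head module
    "dynamic": {
        # positive match: override layer 1 to 8-bit, group_size 32
        r"+:.*\.layers\.1\..*": {"bits": 8, "group_size": 32, "desc_act": False},
        # negative match: skip (leave unquantized) layers 2 and above
        r"-:.*\.layers\.([2-9]|[1-9][0-9])\..*": {},
    },
}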
@@ -3,7 +3,6 @@

Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
"""
-from typing import Tuple

import pytest
import torch
@@ -17,31 +16,31 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (

PROMPT = "On the surface of Mars, we found"

-MODELS_QUANT = [(
-    "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
-    True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
-           ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
+MODELS_QUANT = [
+    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True),
+    ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False),
+    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
+    ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)
+]


-@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
+@pytest.mark.parametrize("model_id, lm_head_quantized", MODELS_QUANT)
def test_lm_head(
    vllm_runner,
-    model_lm_head_quant: Tuple[str, bool],
+    model_id: str,
+    lm_head_quantized: bool,
) -> None:
-    model, lm_head_quantized = model_lm_head_quant
-
-    with vllm_runner(model, dtype=torch.float16,
+    with vllm_runner(model_id, dtype=torch.float16,
                     max_model_len=2048) as vllm_model:

        def check_model(model):
            lm_head_layer = model.lm_head

            if lm_head_quantized:
-                assert isinstance(lm_head_layer.linear_method,
+                assert isinstance(lm_head_layer.quant_method,
                                  (GPTQLinearMethod, GPTQMarlinLinearMethod,
                                   MarlinLinearMethod))
            else:
-                assert isinstance(lm_head_layer.linear_method,
+                assert isinstance(lm_head_layer.quant_method,
                                  UnquantizedEmbeddingMethod)

        vllm_model.apply_model(check_model)
@@ -1039,7 +1039,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
-        logits = lm_head.linear_method.apply(lm_head, hidden_states)
+        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
@@ -108,9 +108,9 @@ class LogitsProcessor(nn.Module):
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
-        logits = lm_head.linear_method.apply(lm_head,
-                                             hidden_states,
-                                             bias=embedding_bias)
+        logits = lm_head.quant_method.apply(lm_head,
+                                            hidden_states,
+                                            bias=embedding_bias)

        # Gather logits for TP
        logits = self._gather_logits(logits)
@@ -3,16 +3,17 @@
import enum
from enum import Enum
from fractions import Fraction
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
@@ -32,7 +33,33 @@ class GPTQConfig(QuantizationConfig):
        group_size: int,
        desc_act: bool,
        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
    ) -> None:
+        # GPTQModel uses the `dynamic` config property to allow per-module
+        # quantization config, so each module can be individually optimized.
+        # The format is Dict[str, Dict] where the key is a regex string that
+        # can perform both positive ("+:" prefixed) and negative ("-:"
+        # prefixed) matching of a module. If no prefix is used, the rule
+        # defaults to a positive match that overrides the base quant config.
+        # The value is a dict of field keys and override values.
+        # Negative matching skips quantization init for the matched module
+        # entirely: it runs non-quantized inference. More details and
+        # quantization examples can be found at:
+        # https://github.com/ModelCloud/GPTQModel
+        # Example:
+        #  # last 1/2 of the layers (10-21) have 8bit vs 4bit for 0-9
+        #  # last 1/4 of the layers (16-21) have 8bit and group_size 64
+        # dynamic = {
+        #  # `.*\.` matches the layers_node prefix
+        #  # positive match: layers 10-15
+        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #  # positive match: layers 16-21
+        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #  r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic
+
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
@@ -47,7 +74,8 @@ class GPTQConfig(QuantizationConfig):
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}),"
-                f"lm_head_quantized={self.lm_head_quantized}")
+                f"lm_head_quantized={self.lm_head_quantized}), "
+                f"dynamic={self.dynamic}")

    @classmethod
    def get_name(cls) -> str:
@@ -68,19 +96,20 @@ class GPTQConfig(QuantizationConfig):

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
-        return cls(weight_bits, group_size, desc_act, lm_head_quantized)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized,
+                   dynamic)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQLinearMethod"]:
-        if (isinstance(layer, LinearBase) or
-                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
-            return GPTQLinearMethod(self)
-        return None
+        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)


class ExllamaState(Enum):
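A minimal sketch of the new config path, assuming a hand-written quant config dict (the values and layer prefix below are illustrative): `from_config` now picks up the optional `dynamic` key, and `get_quant_method` delegates per-module dispatch to the shared `get_linear_quant_method` helper added in gptq_utils.py.

from vllm.model_executor.layers.quantization.gptq import GPTQConfig

# Hypothetical quantization config dict; only the "dynamic" key is new here.
hf_quant_config = {
    "bits": 4,
    "group_size": 128,
    "desc_act": True,
    "lm_head": False,
    "dynamic": {r"+:.*\.layers\.1\..*": {"bits": 8, "group_size": 32}},
}
gptq_config = GPTQConfig.from_config(hf_quant_config)

# For a LinearBase module whose prefix matches the positive rule, the helper
# deep-copies the config, applies the overrides (bits=8, group_size=32,
# pack_factor recomputed), and returns a GPTQLinearMethod built from the copy:
#   gptq_config.get_quant_method(layer, prefix="model.layers.1.self_attn.qkv_proj")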
@@ -9,17 +9,21 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               UnquantizedLinearMethod,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, marlin_moe_permute_scales,
    marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
@@ -47,12 +51,41 @@ class GPTQMarlinConfig(QuantizationConfig):
        desc_act: bool,
        is_sym: bool,
        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
    ) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

+        # GPTQModel uses the `dynamic` config property to allow per-module
+        # quantization config, so each module can be individually optimized.
+        # The format is Dict[str, Dict] where the key is a regex string that
+        # can perform both positive ("+:" prefixed) and negative ("-:"
+        # prefixed) matching of a module. If no prefix is used, the rule
+        # defaults to a positive match that overrides the base quant config.
+        # The value is a dict of field keys and override values.
+        # Negative matching skips quantization init for the matched module
+        # entirely: it runs non-quantized inference. More details and
+        # quantization examples can be found at:
+        # https://github.com/ModelCloud/GPTQModel
+        # Example:
+        #  # last 1/2 of the layers (10-21) have 8bit vs 4bit for 0-9
+        #  # last 1/4 of the layers (16-21) have 8bit and group_size 64
+        # dynamic = {
+        #  # `.*\.` matches the layers_node prefix
+        #  # positive match: layers 10-15
+        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #  # positive match: layers 16-21
+        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #  r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic
+
        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
@@ -68,7 +101,8 @@ class GPTQMarlinConfig(QuantizationConfig):
        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"lm_head_quantized={self.lm_head_quantized}), "
+                f"dynamic={self.dynamic}")

    @classmethod
    def get_name(cls) -> str:
@@ -88,6 +122,9 @@ class GPTQMarlinConfig(QuantizationConfig):

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
@@ -95,7 +132,7 @@ class GPTQMarlinConfig(QuantizationConfig):
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized)
+                   lm_head_quantized, dynamic)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
@@ -120,17 +157,15 @@ class GPTQMarlinConfig(QuantizationConfig):

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
-        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
-                                             and self.lm_head_quantized):
-            return GPTQMarlinLinearMethod(self)
-        elif isinstance(layer, FusedMoE):
+    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
+                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)
-        return None
+        return get_linear_quant_method(self, layer, prefix,
+                                       GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
@@ -143,7 +178,7 @@ class GPTQMarlinConfig(QuantizationConfig):
        if quant_method != "gptq":
            return False

-        # If we cannot find the info needed in the config, cannot convert.
+        # Marlin conversion is only valid if required properties are found
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False
vllm/model_executor/layers/quantization/utils/gptq_utils.py (new file, 94 lines)

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import re
from copy import deepcopy
from typing import Dict, Optional, Union

import torch

from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, UnquantizedEmbeddingMethod)


# Match dynamic rules against the module name (prefix) and override the
# quantize config if the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits",
                                       config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size",
                                      config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act",
                                    config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={config.weight_bits}, sym={config.is_sym}")

        config.quant_type = config.TYPE_MAP[(config.weight_bits,
                                             config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits.")


def get_dynamic_override(
        config: QuantizationConfig,
        layer_name: str,
        key: Optional[str] = None,
        default_value: Union[int, bool,
                             None] = None) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant properties that override
        # the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = isinstance(
        layer, ParallelLMHead) and cloned_config.lm_head_quantized
    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = positive match
        if get_dynamic_override(  # noqa: E712
                cloned_config,  # noqa: E712
                layer_name=prefix) == False:  # noqa: E712
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None
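As a self-contained illustration of the matching rules implemented above (not part of the diff; the config stub, regex keys, and layer names are hypothetical), the same logic can be exercised standalone:

import re
from types import SimpleNamespace

def resolve(config, layer_name, key=None, default=None):
    # Mirrors get_dynamic_override: "-:" rules exclude a module from
    # quantized init (returns False); "+:"/unprefixed rules override fields.
    for pattern, overrides in config.dynamic.items():
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):
            return overrides if key is None else overrides.get(key, default)
    return default

cfg = SimpleNamespace(dynamic={
    r"+:.*\.layers\.1\..*": {"bits": 8, "group_size": 32},  # override layer 1
    r"-:.*\.layers\.([2-9]|[1-9][0-9])\..*": {},            # skip layers >= 2
})

assert resolve(cfg, "model.layers.1.self_attn.qkv_proj", "bits", 4) == 8
assert resolve(cfg, "model.layers.3.mlp.gate_up_proj") is False      # skipped
assert resolve(cfg, "model.layers.0.self_attn.qkv_proj", "bits", 4) == 4  # base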
@@ -226,24 +226,24 @@ class VocabParallelEmbedding(torch.nn.Module):
                                                   self.tp_size)
        self.embedding_dim = embedding_dim

-        linear_method = None
+        quant_method = None
        if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self, prefix=prefix)
-        if linear_method is None:
-            linear_method = UnquantizedEmbeddingMethod()
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
-        linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
-        if is_embedding_layer and not linear_method_implements_embedding:
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
            raise NotImplementedError(
-                f"The class {type(linear_method).__name__} must implement "
+                f"The class {type(quant_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

-        self.linear_method: QuantizeMethodBase = linear_method
+        self.quant_method: QuantizeMethodBase = quant_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
@@ -260,13 +260,13 @@ class VocabParallelEmbedding(torch.nn.Module):
                                  self.shard_indices.added_vocab_end_index -
                                  self.shard_indices.added_vocab_start_index)

-        self.linear_method.create_weights(self,
-                                          self.embedding_dim,
-                                          [self.num_embeddings_per_partition],
-                                          self.embedding_dim,
-                                          self.num_embeddings_padded,
-                                          params_dtype=params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.embedding_dim,
+                                         [self.num_embeddings_per_partition],
+                                         self.embedding_dim,
+                                         self.num_embeddings_padded,
+                                         params_dtype=params_dtype,
+                                         weight_loader=self.weight_loader)

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -412,8 +412,8 @@ class VocabParallelEmbedding(torch.nn.Module):
        else:
            masked_input = input_
        # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
-                                                       masked_input.long())
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)