[Misc] Directly use compressed-tensors for checkpoint definitions (#8909)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 5d264f4ab8
commit 22f8a69549
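This commit drops vLLM's vendored copies of the compressed-tensors checkpoint definitions (the CompressionFormat, QuantizationType, QuantizationStrategy, and ActivationOrdering enums and the QuantizationArgs model) and imports them directly from the compressed-tensors package, whose pin moves from 0.4.0 to 0.6.0. A minimal sketch of the import migration repeated across the hunks below, using only import paths that appear in the diff:

# Before: definitions vendored inside vLLM
# from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
#     ActivationOrdering, CompressionFormat, QuantizationArgs,
#     QuantizationStrategy, QuantizationType)

# After: definitions imported from the compressed-tensors package
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import (ActivationOrdering,
                                             QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)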
@@ -31,3 +31,4 @@ pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
 einops # Required for Qwen2-VL.
+compressed-tensors == 0.6.0 # required for compressed-tensors
@@ -17,7 +17,6 @@ requests
 ray[adag]==2.35
 sentence-transformers # required for embedding
 soundfile # required for audio test
-compressed-tensors==0.4.0 # required for compressed-tensors
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
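Net effect of the two requirements hunks above: the compressed-tensors pin moves out of the test-only requirements (0.4.0) and into the common requirements (0.6.0), matching the new runtime imports in the hunks that follow.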
@@ -6,13 +6,12 @@ from typing import Optional
 
 import pytest
 import torch
+from compressed_tensors.quantization import QuantizationType
 
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationType)
 
 
 @pytest.mark.parametrize(
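With QuantizationType now imported from compressed_tensors.quantization, the test module can use it exactly as before. A minimal, hypothetical sketch (the real test body is not shown in this hunk):

import pytest
from compressed_tensors.quantization import QuantizationType


@pytest.mark.parametrize("quant_type",
                         [QuantizationType.INT, QuantizationType.FLOAT])
def test_quant_type_is_str_enum(quant_type):
    # QuantizationType is a str-backed enum, so it round-trips
    # through its string value ("int" / "float").
    assert QuantizationType(quant_type.value) is quant_type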
@@ -1,6 +1,10 @@
 from typing import Any, Dict, List, Optional, cast
 
 import torch
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from pydantic import BaseModel
 
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -16,8 +20,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_matched_target, is_activation_quantization_format,
+    find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
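After these two hunks, vLLM's own utils module supplies only the helper functions (find_matched_target, is_activation_quantization_format, should_ignore_layer), while the shared definitions come from the library. A small sketch of how the retained helper interacts with the library enum, using the format value visible in the final hunk below:

from compressed_tensors import CompressionFormat

from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format)

# "naive-quantized" is listed in _ACTIVATION_QUANTIZATION_FORMATS, so
# checkpoints in that format quantize activations as well as weights.
assert is_activation_quantization_format(
    CompressionFormat.naive_quantized.value)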
@@ -3,14 +3,14 @@ from enum import Enum
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors import CompressionFormat
+from compressed_tensors.quantization import QuantizationStrategy
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    CompressionFormat, QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -1,11 +1,10 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -1,12 +1,11 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
     requantize_with_max_scale)
@@ -1,13 +1,12 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_int8_linear, convert_to_channelwise)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -1,12 +1,11 @@
 from typing import Callable, List, Optional, Set
 
 import torch
+from compressed_tensors.quantization import ActivationOrdering
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    ActivationOrdering)
 from vllm.model_executor.layers.quantization.kernels import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
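The four scheme hunks above each make the same one-line swap: QuantizationStrategy (or ActivationOrdering, for the WNA16 scheme) is imported from compressed_tensors.quantization instead of the local utils module, whose vendored copies are deleted in the final hunk below.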
@@ -1,111 +1,13 @@
 import re
-from enum import Enum
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Iterable, Optional
 
-from pydantic import BaseModel, Field, field_validator
+from compressed_tensors import CompressionFormat
 from torch.nn import Module
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     FUSED_LAYER_NAME_MAPPING)
 
 
-class CompressionFormat(Enum):
-    dense = "dense"
-    sparse_bitmask = "sparse-bitmask"
-    naive_quantized = "naive-quantized"
-    float_quantized = "float-quantized"
-    int_quantized = "int-quantized"
-    pack_quantized = "pack-quantized"
-    marlin_24 = "marlin-24"
-
-
-class QuantizationType(str, Enum):
-    """
-    Enum storing quantization type options
-    """
-
-    INT = "int"
-    FLOAT = "float"
-
-
-class QuantizationStrategy(str, Enum):
-    """
-    Enum storing quantization strategy options
-    """
-
-    TENSOR = "tensor"
-    CHANNEL = "channel"
-    GROUP = "group"
-    BLOCK = "block"
-    TOKEN = "token"
-
-
-class ActivationOrdering(str, Enum):
-    """
-    Enum storing strategies for activation ordering
-
-    Group: reorder groups and weight\n
-    Weight: only reorder weight, not groups. Slightly lower latency and
-    accuracy compared to group actorder\n
-    """
-
-    GROUP = "group"
-    WEIGHT = "weight"
-
-
-class QuantizationArgs(BaseModel):
-    """
-    User facing arguments used to define a quantization config
-    for weights or activations
-
-    :param num_bits: quantization bit depth
-    :param type: dtype to quantized to, either int or float
-    :param symmetric: whether or not quantization scale is symmetric
-    :param strategy: string determining the scope of scale/zero-point to apply
-    :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block
-        strategy, must be of the format "2x4", "8x16", etc.
-    :param dynamic: set True to perform dynamic quantization -
-        values will not be calibrated during calibration phase,
-        instead during inference new quantization ranges will be
-        observed with every sample. Defaults to False for static
-        quantization. Note that enabling dynamic quantization
-        will change the default observer to a memoryless one
-    :param actorder: whether to apply group quantization in decreasing order of
-        activation. Defaults to None for arbitrary ordering
-    """
-
-    num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
-    symmetric: bool = True
-    group_size: Optional[int] = None
-    strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
-    dynamic: bool = False
-    actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
-        default="minmax",
-        description=("The class to use to compute the quantization param - "
-                     "scale and zero-point'"),
-    )
-    observer_kwargs: Dict[str, Any] = Field(
-        default_factory=dict,
-        description=
-        ("optional dict of kwargs to be passed directly to torch quantization "
-         "Observers constructor excluding quantization range or symmetry"),
-    )
-
-    @field_validator("actorder", mode="before")
-    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
-        if isinstance(value, bool):
-            return ActivationOrdering.GROUP if value else None
-
-        if isinstance(value, str):
-            return ActivationOrdering(value.lower())
-
-        return value
-
-
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
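Since the deleted classes were copies of the compressed-tensors originals, code that consumed them keeps working against the library versions. A minimal sketch, assuming compressed-tensors==0.6.0 preserves the field names and actorder validator behavior shown in the deleted code:

from compressed_tensors.quantization import (ActivationOrdering,
                                             QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)

# Build args the way a checkpoint config would deserialize them.
args = QuantizationArgs(num_bits=4,
                        type=QuantizationType.INT,
                        strategy=QuantizationStrategy.GROUP,
                        group_size=128,
                        actorder=True)

# The "actorder" validator coerces booleans: True -> GROUP, False -> None.
assert args.actorder is ActivationOrdering.GROUP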