[Transform] [Quantization] Add transforms to compressed tensors (#22486)
This commit is contained in:
parent c8851a4723
commit 22feac8e95
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import math
import os
import tempfile
from enum import Enum
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast

import numpy as np
import pytest
@@ -33,6 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.sequence import Logprob
from vllm.transformers_utils.utils import maybe_model_redirect

logger = init_logger(__name__)
@@ -602,7 +604,7 @@ class HfRunner:
    def _hidden_states_to_logprobs(
        self,
        hidden_states: tuple[tuple[torch.Tensor, ...], ...],
        num_logprobs: int,
        num_logprobs: Optional[int],
    ) -> tuple[list[dict[int, float]], int]:
        seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
        output_len = len(hidden_states)
@@ -630,7 +632,7 @@ class HfRunner:
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
@@ -677,7 +679,7 @@ class HfRunner:
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_logprobs: Optional[int],
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs]:
@@ -966,7 +968,7 @@ class VllmRunner:
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
@@ -991,11 +993,40 @@ class VllmRunner:
            videos=videos,
            **kwargs)

    def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
        """
        Return the perplexity score associated with generating the prompts

        :param prompts: list of prompts to score
        :return: perplexity score of each prompt
        """
        outputs = self.generate_greedy_logprobs(prompts,
                                                max_tokens=1,
                                                num_logprobs=None,
                                                num_prompt_logprobs=0)

        perplexities = []
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
            token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])
            assert token_datas[0] is None
            token_log_probs = []
            for token_data in token_datas[1:]:
                assert token_data is not None
                assert len(token_data) == 1
                token_log_prob = list(token_data.values())[0].logprob
                token_log_probs.append(token_log_prob)

            perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs))
            perplexities.append(perplexity)

        return perplexities

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_logprobs: Optional[int],
        num_prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
    ) -> Union[list[TokensTextLogprobs],
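For reference, `generate_prompt_perplexity` above scores each prompt as the exponential of the negative mean token logprob (the first prompt token has no logprob and is skipped). A minimal standalone sketch of the same arithmetic, using made-up logprob values rather than real model output:

```python
import math

# Hypothetical per-token prompt logprobs; the first prompt token has none.
token_log_probs = [-2.3, -0.7, -1.1, -0.9]

# Perplexity = exp(-(1/N) * sum(logprob_i)), matching the loop above.
perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs))
print(round(perplexity, 3))  # 3.49
```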
@@ -719,3 +719,25 @@ def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("model,prompt,exp_perplexity", [
    (
        "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
        "Flat is better than nested.\nSparse is better than dense.",
        150.0,
    ),
    (
        "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
        "Flat is better than nested.\nSparse is better than dense.",
        150.0,
    ),
])
def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
                                                  exp_perplexity):
    with vllm_runner(model, enforce_eager=True) as llm:
        perplexity = llm.generate_prompt_perplexity([prompt])[0]
        print(perplexity)
        assert perplexity <= exp_perplexity
@@ -35,6 +35,7 @@ logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
@@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # special postprocessing for CPU SGL
        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
            from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
            N, K = layer.weight.size()
@@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None
@@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
        k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel

import vllm.envs as envs
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import (  # noqa: E501
    CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    find_matched_target, is_activation_quantization_format,
    should_ignore_layer)
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
        transform_config: Optional[TransformConfig] = None,
    ):
        super().__init__()
        self.ignore = ignore
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
        self.sparsity_ignore_list = sparsity_ignore_list
        self.config = config

        if transform_config is not None:
            self.transform_config = TransformConfig.model_validate(
                transform_config)
        else:
            self.transform_config = None

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        # Check if the layer is skipped for quantization.
        # TODO (@robertgshaw2): support module names
        if should_ignore_layer(prefix,
                               ignore=self.ignore,
                               fused_mapping=self.packed_modules_mapping):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
            # collect schemes
            quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
            input_tfms, output_tfms = get_linear_transform_schemes(
                layer, prefix, self.transform_config,
                self.packed_modules_mapping)

            # choose quantization method
            quant_method: LinearMethodBase = UnquantizedLinearMethod()
            if quant_scheme is not None:
                layer.scheme = quant_scheme
                quant_method = CompressedTensorsLinearMethod(self)

            # choose transform method
            if any((input_tfms, output_tfms)):
                return CompressedTensorsLinearTransformMethod.from_schemes(
                    quant_method, input_tfms, output_tfms)

            else:
                return quant_method

        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
            config=config)
        sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
            config=config)
        transform_config = config.get("transform_config")

        return cls(
            target_scheme_map=target_scheme_map,
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
            sparsity_scheme_map=sparsity_scheme_map,
            sparsity_ignore_list=sparsity_ignore_list,
            config=config,
            transform_config=transform_config,
        )

    @classmethod
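For orientation, the `transform_config` read out of the checkpoint's quantization config above is the compressed-tensors transform section that `TransformConfig.model_validate` parses in `__init__`. A rough, hand-written illustration of the kind of structure involved follows; only the attribute names referenced elsewhere in this diff (`config_groups`, `apply`, `targets`, `location`, `inverse`) are taken from the code, while the group name, transform type, and target patterns are invented for illustration:

```python
# Illustrative only: the real schema is defined by compressed-tensors'
# TransformConfig; field values here are hypothetical.
transform_config = {
    "config_groups": {
        "R1": {  # hypothetical scheme name
            "type": "hadamard",
            "apply": [
                {"targets": ["Linear"], "location": "input", "inverse": False},
                {"targets": ["Linear"], "location": "output", "inverse": True},
            ],
        }
    }
}
```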
@@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO (@robertgshaw): add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this
        # TODO (@kylesayrs): support ignore module names with ct matching utils
        if should_ignore_layer(layer_name,
                               ignore=self.ignore,
                               fused_mapping=self.packed_modules_mapping):
            return None

        # Will be empty for models with only sparsity
        weight_quant = input_quant = None
@@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
        layer input. See LinearMethodBase for param details

        """

        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator
from itertools import accumulate
from typing import Callable, Optional

import torch
from compressed_tensors.transform import (TransformArgs, TransformConfig,
                                          TransformLocation, TransformScheme)
from compressed_tensors.utils import is_match

from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
                                               LinearMethodBase,
                                               QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import (  # noqa: E501
    HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import (  # noqa: E501
    TransformTuple)


class CompressedTensorsLinearTransformMethod(LinearMethodBase):
    """
    Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
    input and output transforms to either side of the original apply method
    """

    @classmethod
    def from_schemes(
            cls, quant_method: LinearMethodBase,
            input_tfms: dict[int, TransformTuple],
            output_tfms: dict[int, TransformTuple]
    ) -> "CompressedTensorsLinearTransformMethod":
        assert input_tfms or output_tfms

        # TODO (@ksayers): implement QutlassLinearMethodNvFP4
        # hadacore and fwht can be selected by Transform module

        return cls(quant_method, input_tfms, output_tfms)

    def __init__(self, quant_method: LinearMethodBase,
                 input_tfms: dict[int, TransformTuple],
                 output_tfms: dict[int, TransformTuple]):
        self.quant_method = quant_method
        self.input_tfms = input_tfms
        self.output_tfms = output_tfms

        self.input_transform: Optional[HadamardTransform] = None
        self.output_transform: Optional[HadamardTransform] = None

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):

        # get weight loader for transforms
        weight_loader: Callable = extra_weight_attrs.get(
            "weight_loader")  # type: ignore[assignment]

        # HACK: UnquantizedLinearMethod does not support weight loader v2, but
        # transforms (specifically SharedWeightParameter) requires
        # weight loader v2. Until UnquantizedLinearMethod supports v2, we must
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            if isinstance(layer, QKVCrossParallelLinear):
                weight_loader_v1 = layer.weight_loader_v1
            else:
                weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
            layer=layer,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            input_size=input_size,
            output_size=output_size,
            params_dtype=params_dtype,
            **extra_weight_attrs)

        # validate schemes
        num_partitions = len(output_partition_sizes)
        self._validate_tfm_schemes(num_partitions)

        # create submodules for weight loading
        if len(self.input_tfms) > 0:
            scheme_name = list(self.input_tfms.values())[0].scheme_name
            location = list(self.input_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"

            transform = HadamardTransform(self.input_tfms, layer,
                                          weight_loader,
                                          input_size_per_partition,
                                          output_partition_sizes)
            layer.register_module(transform_name, transform)
            self.input_transform = transform

        if len(self.output_tfms) > 0:
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"

            transform = HadamardTransform(self.output_tfms, layer,
                                          weight_loader,
                                          input_size_per_partition,
                                          output_partition_sizes)
            layer.register_module(transform_name, transform)
            self.output_transform = transform

        # compute partition ranges for slicing activations
        starts = [0] + list(accumulate(output_partition_sizes))[:-1]
        self.partition_ranges = list(zip(starts, output_partition_sizes))

    def process_weights_after_loading(self, layer):
        self.quant_method.process_weights_after_loading(layer)

        for submodule in layer.children():
            if isinstance(submodule, HadamardTransform):
                submodule.process_weights_after_loading()

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.input_transform is not None:
            x = self.input_transform(x)

        assert bias is None
        x = self.quant_method.apply(layer, x, bias)

        # TODO (@ksayers): Write a triton kernel to do this in parallel
        if self.output_transform is not None:
            for part_id, (start, length) in enumerate(self.partition_ranges):
                x[:, start:start + length] = self.output_transform(
                    x[:, start:start + length], part_id=part_id)

        return x

    def _validate_tfm_schemes(self, num_partitions: int):
        if len(self.input_tfms) > 0:
            if 0 not in self.input_tfms:
                raise ValueError("Must have same input")

            for part_index in range(num_partitions):
                if self.input_tfms[part_index] != self.input_tfms[0]:
                    raise ValueError("Must have same input")

        if len(self.output_tfms) > 0:
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location

            for tfm in self.output_tfms.values():
                if tfm.scheme_name != scheme_name:
                    raise ValueError("Must have same scheme name")
                if tfm.args.location != location:
                    raise ValueError("Must have same location")

        return self.input_tfms, self.output_tfms


def get_linear_transform_schemes(
    layer: torch.nn.Module, layer_name: str,
    transform_config: Optional[TransformConfig],
    packed_modules_mapping: dict[str, list[str]]
) -> tuple[dict[int, TransformTuple], dict[
        int, TransformTuple]]:  # [input_transform, [output_transform, ...]]
    # there can only be one transform input scheme per (fused) module
    input_tfms = {}
    output_tfms = {}

    partition_names = get_layer_partition_names(layer_name,
                                                packed_modules_mapping)

    for scheme_name, scheme, args in get_schemes_args(transform_config):
        for part_index, part_name in enumerate(partition_names):
            if is_match(part_name, layer, args.targets,
                        args.ignore) and args.is_online():
                if args.location == TransformLocation.INPUT:
                    input_tfms[part_index] = TransformTuple(
                        scheme_name, scheme, args)

                elif args.location == TransformLocation.OUTPUT:
                    output_tfms[part_index] = TransformTuple(
                        scheme_name, scheme, args)

                else:
                    raise ValueError(f"Cannot apply `{args.location}` "
                                     f"transform to `{layer_name}`")

    return (input_tfms, output_tfms)


def get_schemes_args(
    transform_config: Optional[TransformConfig]
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
    if transform_config is None:
        return

    for scheme_name, scheme in transform_config.config_groups.items():
        for args in scheme.apply:
            yield (scheme_name, scheme, args)


def get_layer_partition_names(
        layer_name: str,
        packed_modules_mapping: dict[str, list[str]]) -> list[str]:
    """
    Get all partition names associated with this layer.
    Names are returned in order of their partition indices.

    ```python
    mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}

    assert get_layer_partition_names(
        "mlp.gate_up_proj", mapping) == ["mlp.gate_proj", "mlp.up_proj"]
    assert get_layer_partition_names(
        "mlp.down_proj", mapping) == ["mlp.down_proj"]
    ```
    """
    for fused_suffix, part_suffixes in packed_modules_mapping.items():
        if layer_name.endswith(fused_suffix):
            return [
                layer_name.removesuffix(fused_suffix) + part_suffix
                for part_suffix in part_suffixes
            ]

    return [layer_name]
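To make the wrapper's control flow easy to see, here is a minimal, self-contained sketch of the composition performed by `CompressedTensorsLinearTransformMethod.apply` above: optional input transform, then the wrapped linear method, then per-partition output transforms. The helper name and identity callables are stand-ins, not vLLM APIs:

```python
import torch


def sketch_apply(x, input_tfm, quant_apply, output_tfm, partition_ranges):
    # Mirrors apply(): transform the input, run the wrapped linear method,
    # then rotate each output partition independently.
    if input_tfm is not None:
        x = input_tfm(x)
    x = quant_apply(x)
    if output_tfm is not None:
        for part_id, (start, length) in enumerate(partition_ranges):
            x[:, start:start + length] = output_tfm(
                x[:, start:start + length], part_id)
    return x


# Toy usage: two output partitions of width 2, identity stand-in callables.
x = torch.randn(3, 4)
out = sketch_apply(
    x,
    input_tfm=lambda t: t,            # e.g. a Hadamard rotation of the input
    quant_apply=lambda t: t,          # stands in for the quantized GEMM
    output_tfm=lambda t, part_id: t,  # per-partition output rotation
    partition_ranges=[(0, 2), (2, 2)],
)
assert out.shape == (3, 4)
```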
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable, Optional

import torch
from compressed_tensors.transform import TransformLocation, TransformScheme
from torch import Tensor

from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import (  # noqa: E501
    TransformTuple)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parameter import SharedWeightParameter


class HadamardTransform(torch.nn.Module):
    """
    Class which handles weight loading, postprocessing, and application of
    transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
    and attention transforms method (not implemented yet)
    """
    transforms: dict[int, TransformTuple]  # info parsed from transforms config
    weight: SharedWeightParameter  # container for shared tensors

    kernel: Callable  # function used during application
    scales: dict[int, float]  # hadamard scale, usually 1 / sqrt(matrix.size(0))

    def __init__(self,
                 transforms: dict[int, TransformTuple],
                 layer: torch.nn.Module,
                 weight_loader: Callable,
                 input_size_per_partition: int,
                 output_partition_sizes: list[int],
                 kernel: Optional[Callable] = None):
        super().__init__()
        self.transforms = transforms
        self.scales = {}

        if get_tensor_model_parallel_world_size() > 1:
            raise NotImplementedError("Online transforms with tensor "
                                      "parallelism is not supported")

        # Similar to row/col parallel params, but tensors are separate
        # to allow for loading with shared memory
        self.weight = SharedWeightParameter(weight_loader=weight_loader)

        # create shared partition data for each partition of the original weight
        input_size = input_size_per_partition
        for part_index, (_scheme_name, scheme,
                         args) in self.transforms.items():
            output_size = output_partition_sizes[part_index]
            weight_size = self._get_weight_size(layer, args.location,
                                                input_size, output_size)

            data_key = self._get_data_key(scheme, weight_size)
            self.weight.add_partition(
                part_index,
                data_key,
                size=(weight_size, weight_size),
                dtype=scheme.precision,
            )

        # validate that shared tensors and schemes are correct
        self._validate_input_transforms()

        # select kernel based on transform schemes
        self.kernel = self._infer_kernel() if kernel is None else kernel

    def process_weights_after_loading(self):
        for part_id in self.weight.partitions:
            data = self.weight.partitions[part_id].data

            # required by torch.compile
            self.weight.process_weights_after_loading()

            # precompute scale as a runtime multiply, not division
            # do not fold into weight in order to utilize FWHT
            self.scales[part_id] = 1 / math.sqrt(data.size(0))

        # FUTURE: avoid runtime transpose by processing weights
        # prior to apply

    def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
        if part_id not in self.weight.partitions:
            return value

        weight = self.weight.partitions[part_id]
        weight = weight if self.transforms[
            part_id].args.inverse else weight.T  # linear := x(W.T)
        scale = self.scales[part_id]
        return self.kernel(self, value.to(weight.dtype), weight, None).to(
            value.dtype) * scale

    def _get_data_key(self, scheme: TransformScheme,
                      weight_size: int) -> Hashable:
        return (id(scheme), weight_size)

    def _get_weight_size(self, layer: torch.nn.Module,
                         location: TransformLocation, input_size: int,
                         output_size: int) -> int:
        if isinstance(layer, LinearBase):
            if location == TransformLocation.INPUT:
                return input_size

            elif location == TransformLocation.OUTPUT:
                return output_size

        elif isinstance(layer, VocabParallelEmbedding):
            if location == TransformLocation.INPUT:
                return output_size

            elif location == TransformLocation.OUTPUT:
                return input_size

        raise ValueError(f"Unsupported location `{location}` "
                         f"for layer of type `{type(layer)}`")

    def _validate_input_transforms(self):
        assert len(self.transforms) > 0
        location = list(self.transforms.values())[0].args.location

        if location == TransformLocation.INPUT:
            first_data = self.weight.partitions[0].data
            for partition in self.weight.partitions.values():
                if partition.data.data_ptr() != first_data.data_ptr():
                    raise ValueError(
                        "Input transform partitions must share tensor data")

    def _infer_kernel(self) -> Callable:
        # TODO (@ksayers): use fwht, hadacore
        return dispatch_unquantized_gemm()
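The `1 / math.sqrt(data.size(0))` factor above is exactly what makes a ±1 Hadamard matrix orthogonal, so applying the transform in one direction and its transpose in the other round-trips the activations. A small self-contained check of that property with a Sylvester-constructed Hadamard matrix (purely illustrative; this is not the path by which the module loads its weights):

```python
import math

import torch

# 4x4 Sylvester Hadamard matrix from a Kronecker product of [[1, 1], [1, -1]].
h2 = torch.tensor([[1., 1.], [1., -1.]])
h4 = torch.kron(h2, h2)

scale = 1 / math.sqrt(h4.size(0))  # same role as self.scales[part_id]

# A scaled Hadamard matrix is orthogonal: (s*H) @ (s*H).T == I.
assert torch.allclose((scale * h4) @ (scale * h4).T, torch.eye(4))

# So a transform followed by its transpose recovers the original activations.
x = torch.randn(3, 4)
x_roundtrip = (x @ (scale * h4)) @ (scale * h4).T
assert torch.allclose(x_roundtrip, x, atol=1e-6)
```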
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import (  # noqa: E501
    CompressedTensorsLinearTransformMethod)


# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # fused hadamard quant linear method
        raise NotImplementedError()
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple

from compressed_tensors.transform import TransformArgs, TransformScheme

__all__ = ["TransformTuple"]


class TransformTuple(NamedTuple):
    scheme_name: str
    scheme: TransformScheme
    args: TransformArgs
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Hashable
from fractions import Fraction
from typing import Callable, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.nn import Parameter

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
@@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
    def __new__(cls, data: Optional[torch.Tensor], **kwargs):

        return super().__new__(cls, data=data, requires_grad=False)
@@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        assert isinstance(shard_id, str)
        assert shard_id in qkv_idxs
        return qkv_idxs[shard_id]


class _ColumnvLLMParameter(BasevLLMParameter):
    """
@@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):

        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
        if isinstance(
                self,
                (PackedColumnParameter,
@@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
        if isinstance(
                self,
                (PackedColumnParameter,
@@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    # For row parallel layers, no sharding needed
    # load weight into parameter as is
    def load_row_parallel_weight(self, *args, **kwargs):
@@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    pass


class SharedWeightParameter(BasevLLMParameter):
    """
    Parameter for weights with many shared tensors across a model

    For example, when applying transforms to the "gate" and "up" partitions of
    `MergedColumnParallelLinear`, the transform weights must stay separate
    tensors in order to allow for tensor memory sharing between layers.
    """
    # global registry for sharing tensors based on passed `data_key`
    # this dict holds weakrefs to avoid memory leak after model cleanup
    tensors_registry: WeakValueDictionary = WeakValueDictionary()

    # local container for strong references to shared tensors
    # this set compensates for the fact that torch.nn.Parameter
    # and Parameter subclasses do not hold reliable references to tensors
    local_tensors: set[torch.Tensor]

    # dictionary mapping partition indices to associated parameters
    partitions: dict[int, Union[ModelWeightParameter, Parameter]]

    def __new__(cls, **kwargs):
        return super().__new__(cls, data=None, **kwargs)

    def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
        weight_loader: Callable = kwargs.get(
            "weight_loader")  # type: ignore[assignment]
        super().__init__(data=None, weight_loader=weight_loader)

        self.local_tensors = set()
        self.partitions = {}
        self.kwargs = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "weight_loader": self._fake_weight_loader
        }

        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > 1:
            raise NotImplementedError(f"{self.__class__.__name__} does not "
                                      "currently support tensor parallelism")

    def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
        """
        Add a partition to the weight parameter. Partitions whose `data_key`
        is the same will share tensor data

        :param index: index of partition to add
        :param data_key: hashable key used to key shared tensors
        :param *args: arguments for `torch.empty`
        :param **kwargs: keyword arguments for `torch.empty`
        """
        # load (shared) tensor using `data_key`
        if data_key not in self.tensors_registry:
            data = torch.empty(*args, **kwargs)
            self.tensors_registry[data_key] = data
        else:
            data = self.tensors_registry[data_key]

        # create associated model parameter
        self.partitions[index] = ModelWeightParameter(
            data=data, **self.kwargs)  # type: ignore[arg-type]

        # hold local reference, since ModelWeightParameter does not
        # see https://github.com/pytorch/pytorch/issues/75932
        self.local_tensors.add(data)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_column_parallel_weight(
            partition, loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        partition_id = kwargs.pop("shard_id")
        partition_id = self._shard_id_as_int(partition_id)
        partition = self.partitions[partition_id]

        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size

        ModelWeightParameter.load_merged_column_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
        partition = self.partitions[partition_id]

        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size
        shard_id = "q"  # fake first partition
        num_heads = kwargs.get("num_heads")

        ModelWeightParameter.load_qkv_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size,
            shard_id=shard_id,
            num_heads=num_heads,
        )

    def process_weights_after_loading(self):
        for key in self.partitions:
            self.partitions[key] = torch.nn.Parameter(
                data=self.partitions[key].data, requires_grad=False)

    @property
    def data(self):
        raise ValueError("Accessing `data` of a "
                         "`SharedWeightParameter` is not allowed. "
                         "Instead, use the `partitions` mapping to access "
                         "the weight of the particular partition you want")

    def _fake_weight_loader(self, param: BasevLLMParameter,
                            loaded_weight: torch.Tensor,
                            loaded_weight_shard_id: Optional[Union[str, int]]):
        raise ValueError("When loading partition weights of "
                         f"{self.__class__.__name__}, use methods provided by "
                         f"{self.__class__.__name__}, not partition loader")


def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
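The `tensors_registry` / `local_tensors` split above leans on standard Python weak-reference semantics: the class-level `WeakValueDictionary` hands every partition that presents the same `data_key` the same underlying tensor, while only strong references (collected in `local_tensors`) keep those tensors alive, so the registry empties itself once a model is torn down. A small sketch of that behavior, independent of vLLM (the key is an arbitrary illustrative tuple, and the final assertion relies on CPython's reference counting):

```python
from weakref import WeakValueDictionary

import torch

registry: WeakValueDictionary = WeakValueDictionary()
strong_refs = []  # plays the role of `local_tensors`


def get_shared(data_key, size):
    # Reuse the tensor registered under this key, or allocate and register one.
    if data_key not in registry:
        data = torch.empty(size)
        registry[data_key] = data
    else:
        data = registry[data_key]
    strong_refs.append(data)
    return data


a = get_shared(("R1", 64), (64, 64))
b = get_shared(("R1", 64), (64, 64))
assert a.data_ptr() == b.data_ptr()  # both "layers" see the same storage

# Once all strong references are dropped, the weak registry lets go too.
strong_refs.clear()
del a, b
assert ("R1", 64) not in registry
```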
@@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                             shard_offset=shard_offset,
                                             bitblas_tile_size=bitblas_tile_size)

    return shard_size, shard_offset
    return shard_size, shard_offset