[Transform] [Quantization] Add transforms to compressed tensors (#22486)

This commit is contained in:
Kyle Sayers 2025-08-28 02:43:48 -04:00 committed by GitHub
parent c8851a4723
commit 22feac8e95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 661 additions and 36 deletions

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import math
import os
import tempfile
from enum import Enum
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast
import numpy as np
import pytest
@ -33,6 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.sequence import Logprob
from vllm.transformers_utils.utils import maybe_model_redirect
logger = init_logger(__name__)
@ -602,7 +604,7 @@ class HfRunner:
def _hidden_states_to_logprobs(
self,
hidden_states: tuple[tuple[torch.Tensor, ...], ...],
num_logprobs: int,
num_logprobs: Optional[int],
) -> tuple[list[dict[int, float]], int]:
seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
output_len = len(hidden_states)
@ -630,7 +632,7 @@ class HfRunner:
self,
prompts: list[str],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
@ -677,7 +679,7 @@ class HfRunner:
self,
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> list[TokensTextLogprobs]:
@ -966,7 +968,7 @@ class VllmRunner:
self,
prompts: list[str],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
@ -991,11 +993,40 @@ class VllmRunner:
videos=videos,
**kwargs)
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
"""
Return the perplexity score associated with generating the prompts
:param prompts: list of prompts to score
:return: perplexity score of each prompt
"""
outputs = self.generate_greedy_logprobs(prompts,
max_tokens=1,
num_logprobs=None,
num_prompt_logprobs=0)
perplexities = []
for output in outputs:
output = cast(TokensTextLogprobsPromptLogprobs, output)
token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])
assert token_datas[0] is None
token_log_probs = []
for token_data in token_datas[1:]:
assert token_data is not None
assert len(token_data) == 1
token_log_prob = list(token_data.values())[0].logprob
token_log_probs.append(token_log_prob)
perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs))
perplexities.append(perplexity)
return perplexities
def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
) -> Union[list[TokensTextLogprobs],

View File

@ -719,3 +719,25 @@ def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
assert output
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("model,prompt,exp_perplexity", [
(
"nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
"Flat is better than nested.\nSparse is better than dense.",
150.0,
),
(
"nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
"Flat is better than nested.\nSparse is better than dense.",
150.0,
),
])
def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
exp_perplexity):
with vllm_runner(model, enforce_eager=True) as llm:
perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity)
assert perplexity <= exp_perplexity

View File

@ -35,6 +35,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"CompressedTensorsLinearTransformMethod",
"BitBLASLinearMethod",
"GPTQBitBLASLinearMethod",
"AWQMarlinLinearMethod",
@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# special postprocessing for CPU SGL
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
N, K = layer.weight.size()
@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
"weight_loader": self.weight_loader_v1,
})
else:
self.bias = None
@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v
def weight_loader_v1(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# just like all other parameters, does not yet
# support loading bias with weight_loader_v2
layer = (self.q_proj_decoder
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,

View File

@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel
import vllm.envs as envs
@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[TransformConfig] = None,
):
super().__init__()
self.ignore = ignore
@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
if transform_config is not None:
self.transform_config = TransformConfig.model_validate(
transform_config)
else:
self.transform_config = None
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
input_tfms, output_tfms = get_linear_transform_schemes(
layer, prefix, self.transform_config,
self.packed_modules_mapping)
# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = CompressedTensorsLinearMethod(self)
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)
else:
return quant_method
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
)
@classmethod
@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None
# Will be empty for models with only sparsity
weight_quant = input_quant = None
@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")

View File

@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator
from itertools import accumulate
from typing import Callable, Optional
import torch
from compressed_tensors.transform import (TransformArgs, TransformConfig,
TransformLocation, TransformScheme)
from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
"""
Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
input and output transforms to either side of the original apply method
"""
@classmethod
def from_schemes(
cls, quant_method: LinearMethodBase, input_tfms: dict[int,
TransformTuple],
output_tfms: dict[int, TransformTuple]
) -> "CompressedTensorsLinearTransformMethod":
assert input_tfms or output_tfms
# TODO (@ksayers): implement QutlassLinearMethodNvFP4
# hadacore and fwht can be selected by Transform module
return cls(quant_method, input_tfms, output_tfms)
def __init__(self, quant_method: LinearMethodBase,
input_tfms: dict[int, TransformTuple],
output_tfms: dict[int, TransformTuple]):
self.quant_method = quant_method
self.input_tfms = input_tfms
self.output_tfms = output_tfms
self.input_transform: Optional[HadamardTransform] = None
self.output_transform: Optional[HadamardTransform] = None
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
# get weight loader for transforms
weight_loader: Callable = extra_weight_attrs.get(
"weight_loader") # type: ignore[assignment]
# HACK: UnquantizedLinearMethod does not support weight loader v2, but
# transforms (specifically SharedWeightParameter) requires
# weight loader v2. Until UnquantizedLinearMethod supports v2, we must
# hack around this by getting weight loader v1 so ULM can load correctly
quant_method_name = self.quant_method.__class__.__name__
if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
if isinstance(layer, QKVCrossParallelLinear):
weight_loader_v1 = layer.weight_loader_v1
else:
weight_loader_v1 = layer.weight_loader
extra_weight_attrs["weight_loader"] = weight_loader_v1
self.quant_method.create_weights(
layer=layer,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size,
params_dtype=params_dtype,
**extra_weight_attrs)
# validate schemes
num_partitions = len(output_partition_sizes)
self._validate_tfm_schemes(num_partitions)
# create submodules for weight loading
if len(self.input_tfms) > 0:
scheme_name = list(self.input_tfms.values())[0].scheme_name
location = list(self.input_tfms.values())[0].args.location
transform_name = f"{scheme_name}_{location}"
transform = HadamardTransform(self.input_tfms, layer,
weight_loader,
input_size_per_partition,
output_partition_sizes)
layer.register_module(transform_name, transform)
self.input_transform = transform
if len(self.output_tfms) > 0:
scheme_name = list(self.output_tfms.values())[0].scheme_name
location = list(self.output_tfms.values())[0].args.location
transform_name = f"{scheme_name}_{location}"
transform = HadamardTransform(self.output_tfms, layer,
weight_loader,
input_size_per_partition,
output_partition_sizes)
layer.register_module(transform_name, transform)
self.output_transform = transform
# compute partition ranges for slicing activations
starts = [0] + list(accumulate(output_partition_sizes))[:-1]
self.partition_ranges = list(zip(starts, output_partition_sizes))
def process_weights_after_loading(self, layer):
self.quant_method.process_weights_after_loading(layer)
for submodule in layer.children():
if isinstance(submodule, HadamardTransform):
submodule.process_weights_after_loading()
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.input_transform is not None:
x = self.input_transform(x)
assert bias is None
x = self.quant_method.apply(layer, x, bias)
# TODO (@ksayers): Write a triton kernel to do this in parallel
if self.output_transform is not None:
for part_id, (start, length) in enumerate(self.partition_ranges):
x[:, start:start + length] = self.output_transform(
x[:, start:start + length], part_id=part_id)
return x
def _validate_tfm_schemes(self, num_partitions: int):
if len(self.input_tfms) > 0:
if 0 not in self.input_tfms:
raise ValueError("Must have same input")
for part_index in range(num_partitions):
if self.input_tfms[part_index] != self.input_tfms[0]:
raise ValueError("Must have same input")
if len(self.output_tfms) > 0:
scheme_name = list(self.output_tfms.values())[0].scheme_name
location = list(self.output_tfms.values())[0].args.location
for tfm in self.output_tfms.values():
if tfm.scheme_name != scheme_name:
raise ValueError("Must have same scheme name")
if tfm.args.location != location:
raise ValueError("Must have same location")
return self.input_tfms, self.output_tfms
def get_linear_transform_schemes(
layer: torch.nn.Module, layer_name: str,
transform_config: Optional[TransformConfig],
packed_modules_mapping: dict[str, list[str]]
) -> tuple[dict[int, TransformTuple], dict[
int, TransformTuple]]: # [input_transform, [output_transform, ...]]
# there can only be one transform input scheme per (fused) module
input_tfms = {}
output_tfms = {}
partition_names = get_layer_partition_names(layer_name,
packed_modules_mapping)
for scheme_name, scheme, args in get_schemes_args(transform_config):
for part_index, part_name in enumerate(partition_names):
if is_match(part_name, layer, args.targets,
args.ignore) and args.is_online():
if args.location == TransformLocation.INPUT:
input_tfms[part_index] = TransformTuple(
scheme_name, scheme, args)
elif args.location == TransformLocation.OUTPUT:
output_tfms[part_index] = TransformTuple(
scheme_name, scheme, args)
else:
raise ValueError(f"Cannot apply `{args.location}` "
f"transform to `{layer_name}`")
return (input_tfms, output_tfms)
def get_schemes_args(
transform_config: Optional[TransformConfig]
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
if transform_config is None:
return
for scheme_name, scheme in transform_config.config_groups.items():
for args in scheme.apply:
yield (scheme_name, scheme, args)
def get_layer_partition_names(
layer_name: str, packed_modules_mapping: dict[str,
list[str]]) -> list[str]:
"""
Get all partition names associated with this layer.
Names are returned in order of their partition indices.
```python
mapping = {"gate_up_proj", "gate_proj", "up_proj"}
assert get_layer_partition_names(
"mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"]
assert get_layer_partition_names(
"mlp.down_proj", mapping) == ["down_proj"]
"""
for fused_suffix, part_suffixes in packed_modules_mapping.items():
if layer_name.endswith(fused_suffix):
return [
layer_name.removesuffix(fused_suffix) + part_suffix
for part_suffix in part_suffixes
]
return [layer_name]

View File

@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable, Optional
import torch
from compressed_tensors.transform import TransformLocation, TransformScheme
from torch import Tensor
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parameter import SharedWeightParameter
class HadamardTransform(torch.nn.Module):
"""
Class which handles weight loading, postprocessing, and application of
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
and attention transforms method (not implemented yet)
"""
transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors
kernel: Callable # function used during application
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
def __init__(self,
transforms: dict[int, TransformTuple],
layer: torch.nn.Module,
weight_loader: Callable,
input_size_per_partition: int,
output_partition_sizes: list[int],
kernel: Optional[Callable] = None):
super().__init__()
self.transforms = transforms
self.scales = {}
if get_tensor_model_parallel_world_size() > 1:
raise NotImplementedError("Online transforms with tensor "
"parallelism is not supported")
# Similar to row/col parallel params, but tensors are separate
# to allow for loading with shared memory
self.weight = SharedWeightParameter(weight_loader=weight_loader)
# create shared partition data for each partition of the original weight
input_size = input_size_per_partition
for part_index, (_scheme_name, scheme,
args) in self.transforms.items():
output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(layer, args.location,
input_size, output_size)
data_key = self._get_data_key(scheme, weight_size)
self.weight.add_partition(
part_index,
data_key,
size=(weight_size, weight_size),
dtype=scheme.precision,
)
# validate that shared tensors and schemes are correct
self._validate_input_transforms()
# select kernel based on transform schemes
self.kernel = self._infer_kernel() if kernel is None else kernel
def process_weights_after_loading(self):
for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data
# required by torch.compile
self.weight.process_weights_after_loading()
# precompute scale as a runtime multiply, not division
# do not fold into weight in order to utilize FWHT
self.scales[part_id] = 1 / math.sqrt(data.size(0))
# FUTURE: avoid runtime tranpose by processing weights
# prior to apply
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
if part_id not in self.weight.partitions:
return value
weight = self.weight.partitions[part_id]
weight = weight if self.transforms[
part_id].args.inverse else weight.T # linear := x(W.T)
scale = self.scales[part_id]
return self.kernel(self, value.to(weight.dtype), weight, None).to(
value.dtype) * scale
def _get_data_key(self, scheme: TransformScheme,
weight_size: int) -> Hashable:
return (id(scheme), weight_size)
def _get_weight_size(self, layer: torch.nn.Module,
location: TransformLocation, input_size: int,
output_size: int) -> int:
if isinstance(layer, LinearBase):
if location == TransformLocation.INPUT:
return input_size
elif location == TransformLocation.OUTPUT:
return output_size
elif isinstance(layer, VocabParallelEmbedding):
if location == TransformLocation.INPUT:
return output_size
elif location == TransformLocation.OUTPUT:
return input_size
raise ValueError()
def _validate_input_transforms(self):
assert len(self.transforms) > 0
location = list(self.transforms.values())[0].args.location
if location == TransformLocation.INPUT:
first_data = self.weight.partitions[0].data
for partition in self.weight.partitions.values():
if partition.data.data_ptr() != first_data.data_ptr():
raise ValueError("")
def _infer_kernel(self) -> Callable:
# TODO (@ksayers): use fwht, hadacore
return dispatch_unquantized_gemm()

View File

@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod)
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# fused hadamard quant linear method
raise NotImplementedError()

View File

@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
from compressed_tensors.transform import TransformArgs, TransformScheme
__all__ = ["TransformTuple"]
class TransformTuple(NamedTuple):
scheme_name: str
scheme: TransformScheme
args: TransformArgs

View File

@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Hashable
from fractions import Fraction
from typing import Callable, Optional, Union
from weakref import WeakValueDictionary
import torch
from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
into the parameter when the provided weight loader is called.
"""
def __new__(cls, data: torch.Tensor, **kwargs):
def __new__(cls, data: Optional[torch.Tensor], **kwargs):
return super().__new__(cls, data=data, requires_grad=False)
@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
self._assert_and_load(loaded_weight)
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
# if not int, assume shard_id for qkv
# map to int and return
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert isinstance(shard_id, str)
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]
class _ColumnvLLMParameter(BasevLLMParameter):
"""
@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance(
self,
(PackedColumnParameter,
@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance(
self,
(PackedColumnParameter,
@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
"""
def __init__(self, **kwargs):
self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
super().__init__(**kwargs)
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
pass
class SharedWeightParameter(BasevLLMParameter):
"""
Parameter for weights with many shared tensors across a model
For example, when applying transforms to the "gate" and "up" partitions of
`MergedColumnParallelLinear`, the transform weights must stay separate
tensors in order to allow for tensor memory sharing between layers.
"""
# global registry for sharing tensors based on passed `data_key`
# this dict holds weaksrefs to avoid memory leak after model cleanup
tensors_registry: WeakValueDictionary = WeakValueDictionary()
# local container for strong references to shared tensors
# this set compensates for the fact that torch.nn.Parameter
# and Parameter subclasses do not hold reliable references to tensors
local_tensors: set[torch.Tensor]
# dictionary mapping partition indices to associated parameters
partitions: dict[int, Union[ModelWeightParameter, Parameter]]
def __new__(cls, **kwargs):
return super().__new__(cls, data=None, **kwargs)
def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
weight_loader: Callable = kwargs.get(
"weight_loader") # type: ignore[assignment]
super().__init__(data=None, weight_loader=weight_loader)
self.local_tensors = set()
self.partitions = {}
self.kwargs = {
"input_dim": input_dim,
"output_dim": output_dim,
"weight_loader": self._fake_weight_loader
}
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > 1:
raise NotImplementedError(f"{self.__class__.__name__} does not "
"currently support tensor parallelism")
def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
"""
Add a partition to the weight parameter. Partitions whose `data_key`
is the same will share tensor data
:param index: index of partition to add
:param data_key: hashable key used to key shared tensors
:param *args: arguments for `torch.empty`
:param **kwargs: keyword arguments for `torch.empty`
"""
# load (shared) tensor using `data_key`
if data_key not in self.tensors_registry:
data = torch.empty(*args, **kwargs)
self.tensors_registry[data_key] = data
else:
data = self.tensors_registry[data_key]
# create associated model parameter
self.partitions[index] = ModelWeightParameter(
data=data, **self.kwargs) # type: ignore[arg-type]
# hold local reference, since ModelWeightParameter does not
# see https://github.com/pytorch/pytorch/issues/75932
self.local_tensors.add(data)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
assert len(self.partitions) == 1 and 0 in self.partitions
partition = self.partitions[0]
ModelWeightParameter.load_column_parallel_weight(
partition, loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
assert len(self.partitions) == 1 and 0 in self.partitions
partition = self.partitions[0]
ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
partition_id = kwargs.pop("shard_id")
partition_id = self._shard_id_as_int(partition_id)
partition = self.partitions[partition_id]
input_dim = self.kwargs.get("input_dim")
shard_size = partition.data.size(input_dim) // self.tp_size
shard_offset = self.tp_rank * shard_size
ModelWeightParameter.load_merged_column_weight(
partition,
loaded_weight,
shard_offset=shard_offset,
shard_size=shard_size)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
partition = self.partitions[partition_id]
input_dim = self.kwargs.get("input_dim")
shard_size = partition.data.size(input_dim) // self.tp_size
shard_offset = self.tp_rank * shard_size
shard_id = "q" # fake first partition
num_heads = kwargs.get("num_heads")
ModelWeightParameter.load_qkv_weight(
partition,
loaded_weight,
shard_offset=shard_offset,
shard_size=shard_size,
shard_id=shard_id,
num_heads=num_heads,
)
def process_weights_after_loading(self):
for key in self.partitions:
self.partitions[key] = torch.nn.Parameter(
data=self.partitions[key].data, requires_grad=False)
@property
def data(self):
raise ValueError("Accessing `data` of a "
"`PartitionedModelWeightParameter` is not allowed. "
"Instead, use `get_partition` to get the weight of "
"the particular partition you want to access")
def _fake_weight_loader(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_weight_shard_id: Optional[Union[str, int]]):
raise ValueError("When loading partition weights of "
f"{self.__class__.__name__}, use methods provided by "
f"{self.__class__.__name__}, not partition loader")
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
output_dim: int, **kwargs) -> BasevLLMParameter:
"""
@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
shard_offset=shard_offset,
bitblas_tile_size=bitblas_tile_size)
return shard_size, shard_offset
return shard_size, shard_offset