[TPU][Quantization] TPU W8A8 (#11785)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Parent: 47de8821d3
Commit: 56fe4c297c
@@ -14,4 +14,13 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it \
+    -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
+    vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
+    && python3 -m pip install pytest \
+    && python3 -m pip install lm_eval[api]==0.4.4 \
+    && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
+    && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
+    && python3 /workspace/vllm/tests/tpu/test_compilation.py \
+    && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
+    && python3 /workspace/vllm/examples/offline_inference/offline_inference_tpu.py"
tests/tpu/test_quantization_accuracy.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+
+import lm_eval
+import pytest
+
+TASK = "gsm8k"
+FILTER = "exact_match,strict-match"
+RTOL = 0.03
+
+
+@dataclass
+class GSM8KAccuracyTestConfig:
+    model_name: str
+    excepted_value: float
+
+    def get_model_args(self) -> str:
+        return (f"pretrained={self.model_name},"
+                "max_model_len=4096,max_num_seqs=32")
+
+
+# NOTE: Accuracy scores measured on GPUs.
+ACCURACY_CONFIGS = [
+    GSM8KAccuracyTestConfig(
+        model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
+        excepted_value=0.76),  # no bias
+    # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU,
+    # so only one of these tests can run in a single call to pytest. As
+    # a follow up, move this into the LM-EVAL section of the CI.
+    # GSM8KAccuracyTestConfig(
+    #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
+    #     excepted_value=0.66),  # bias in QKV layers
+]
+
+
+@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
+def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
+
+    results = lm_eval.simple_evaluate(
+        model="vllm",
+        model_args=config.get_model_args(),
+        tasks="gsm8k",
+        batch_size="auto",
+    )
+
+    EXPECTED_VALUE = config.excepted_value
+    measured_value = results["results"][TASK][FILTER]
+    assert (measured_value - RTOL < EXPECTED_VALUE
+            and measured_value + RTOL > EXPECTED_VALUE
+            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
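
For reference, the new test drives this checkpoint through lm_eval's vLLM backend. A minimal sketch of loading the same model through vLLM's offline API, using the arguments the test encodes in get_model_args (illustrative only, not part of the commit):

# Editorial sketch (not part of the diff): run the W8A8 checkpoint used by the
# accuracy test through vLLM's offline LLM API with the same model arguments.
from vllm import LLM, SamplingParams

llm = LLM(model="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
          max_model_len=4096,
          max_num_seqs=32)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)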
@@ -1,14 +1,13 @@
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Set
 
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
-from torch.nn import Parameter
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_int8_linear, convert_to_channelwise)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
@@ -18,6 +17,7 @@ logger = init_logger(__name__)
 
 
 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
+    _kernel_backends_being_used: Set[str] = set()
 
     def __init__(self, strategy: str, is_static_input_scheme: bool,
                  input_symmetric: bool):
@@ -30,74 +30,25 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         # turing and up
         return 75
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # WEIGHT
-        # Cutlass kernels need transposed weight.
-        weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-
-        # WEIGHT SCALE
-        # Cutlass kernels support only per-tensor and per-channel.
-        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
-        # scales being passed to the kernel), convert to the per-channel case.
-        is_fused_module = len(self.logical_widths) > 1
-        if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
-            ws_channelwise = convert_to_channelwise(layer.weight_scale,
-                                                    self.logical_widths)
-            layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
-        else:
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
-        # INPUT SCALE
-        if self.is_static_input_scheme:
-            if self.input_symmetric:
-                layer.input_scale = Parameter(layer.input_scale.max(),
-                                              requires_grad=False)
-                layer.input_zero_point = None
-            else:
-                # reconstruct the ranges
-                int8_traits = torch.iinfo(torch.int8)
-                azps = layer.input_zero_point.to(dtype=torch.int32)
-                range_max = (layer.input_scale *
-                             (int8_traits.max - azps)).max()
-                range_min = (layer.input_scale *
-                             (int8_traits.min - azps)).min()
-
-                scale = (range_max - range_min) / (int8_traits.max -
-                                                   int8_traits.min)
-                layer.input_scale = Parameter(scale, requires_grad=False)
-
-                # AZP loaded as int8 but used as int32
-                azp = (int8_traits.min -
-                       range_min / scale).to(dtype=torch.int32)
-                layer.input_zero_point = Parameter(azp, requires_grad=False)
-
-        else:
-            layer.input_scale = None
-            layer.input_zero_point = None
-
-        # azp_adj is the AZP adjustment term, used to account for weights.
-        # It does not depend on scales or azp, so it is the same for
-        # static and dynamic quantization.
-        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
-        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
-        if not self.input_symmetric:
-            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
-            if self.is_static_input_scheme:
-                # cutlass_w8a8 requires azp to be folded into azp_adj
-                # in the per-tensor case
-                azp_adj = layer.input_zero_point * azp_adj
-
-            layer.azp_adj = azp_adj
-        else:
-            layer.azp_adj = None
-
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        self.logical_widths = output_partition_sizes
+        layer.logical_widths = output_partition_sizes
+
+        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
+            is_static_input_scheme=self.is_static_input_scheme,
+            input_symmetric=self.input_symmetric)
+
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config)
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for CompressedTensorsW8A8Int8",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
 
         # WEIGHT
         weight = ModelWeightParameter(data=torch.empty(
@@ -140,12 +91,18 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                                               weight_loader=weight_loader)
         layer.register_parameter("input_zero_point", input_zero_point)
 
+        self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
+                                  w_q_param_name="weight",
+                                  w_s_param_name="weight_scale",
+                                  i_s_param_name="input_scale",
+                                  i_zp_param_name="input_zero_point",
+                                  azp_adj_param_name="azp_adj")
+
+    # Checkpoints are serialized in compressed-tensors format, which is
+    # different from the format the kernel may want. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
+
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return apply_int8_linear(input=x,
-                                 weight=layer.weight,
-                                 weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale,
-                                 input_zero_point=layer.input_zero_point,
-                                 azp_adj=layer.azp_adj,
-                                 bias=bias)
+        return self.kernel.apply_weights(layer, x, bias)
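
Both the removed path above and the CutlassScaledMMLinearKernel added later in this diff expand the N per-tensor weight scales of a fused module (QKV, MLP) into per-channel scales via convert_to_channelwise before reaching the kernel. A rough sketch of what that conversion amounts to; the shapes and helper name here are illustrative assumptions, not the vLLM implementation:

# Editorial sketch of per-tensor -> per-channel scale expansion for a fused
# module: each partition's single scale is repeated across that partition's
# output channels. Shapes are assumptions chosen for illustration.
import torch

def expand_to_channelwise(per_tensor_scales: torch.Tensor,
                          logical_widths: list[int]) -> torch.Tensor:
    # per_tensor_scales: [num_partitions]; logical_widths: rows per partition
    chunks = [scale.repeat(width)
              for scale, width in zip(per_tensor_scales, logical_widths)]
    return torch.cat(chunks).unsqueeze(-1)  # [sum(logical_widths), 1]

scales = torch.tensor([0.02, 0.03, 0.03])               # one scale per Q/K/V
print(expand_to_channelwise(scales, [8, 4, 4]).shape)   # torch.Size([16, 1])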
@@ -6,7 +6,7 @@ from compressed_tensors.quantization import ActivationOrdering
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.kernels import (
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_repeat_scales_on_all_ranks)
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.kernels import (
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -1,74 +0,0 @@
-from typing import List, Optional, Type
-
-import vllm.envs as envs
-from vllm.model_executor.layers.quantization.kernels.exllama import (
-    ExllamaLinearKernel)
-from vllm.model_executor.layers.quantization.kernels.machete import (
-    MacheteLinearKernel)
-from vllm.model_executor.layers.quantization.kernels.marlin import (
-    MarlinLinearKernel)
-from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
-    MPLinearKernel, MPLinearLayerConfig)
-from vllm.platforms import current_platform
-
-# in priority/performance order (when available)
-_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
-    MacheteLinearKernel,
-    MarlinLinearKernel,
-    ExllamaLinearKernel,
-]
-
-
-def choose_mp_linear_kernel(
-        config: MPLinearLayerConfig,
-        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
-    """
-    Choose an MPLinearKernel that can implement the given config for the given
-    compute capability. Attempts to choose the best kernel in terms of
-    performance.
-
-    Args:
-        config (MPLinearLayerConfig): Description of the linear layer to be
-            implemented.
-        compute_capability (Optional[int], optional): The compute capability of
-            the target device, if None uses `current_platform` to get the compute
-            capability. Defaults to None.
-
-    Raises:
-        ValueError: If no kernel can implement the given config.
-
-    Returns:
-        Type[MPLinearKernel]: Chosen kernel.
-    """
-    if compute_capability is None:
-        if current_platform is None:
-            raise ValueError("Cannot determine compute capability")
-        _cc = current_platform.get_device_capability()
-        compute_capability = _cc[0] * 10 + _cc[1]
-
-    failure_reasons = []
-    for kernel in _POSSIBLE_KERNELS:
-        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
-            failure_reasons.append(
-                f' {kernel.__name__} disabled by environment variable')
-            continue
-
-        if kernel.get_min_capability() > compute_capability:
-            failure_reasons.append(
-                f"{kernel.__name__} requires capability "
-                f"{kernel.get_min_capability()}, current compute capability "
-                f"is {compute_capability}")
-            continue
-
-        can_implement, failure_reason = kernel.can_implement(config)
-        if can_implement:
-            return kernel
-        else:
-            failure_reasons.append(
-                f' {kernel.__name__} cannot implement due to: {failure_reason}'
-            )
-
-    raise ValueError(
-        "Failed to find a kernel that can implement the "\
-        "WNA16 linear layer. Reasons: \n"
-        + '\n'.join(failure_reasons))
@@ -0,0 +1,74 @@
+from typing import List, Optional, Type
+
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
+    ExllamaLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
+    MacheteLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (  # noqa: E501
+    MarlinLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import (  # noqa: E501
+    MPLinearKernel, MPLinearLayerConfig)
+from vllm.platforms import current_platform
+
+# in priority/performance order (when available)
+_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
+    MacheteLinearKernel,
+    MarlinLinearKernel,
+    ExllamaLinearKernel,
+]
+
+
+def choose_mp_linear_kernel(
+        config: MPLinearLayerConfig,
+        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
+    """
+    Choose an MPLinearKernel that can implement the given config for the given
+    compute capability. Attempts to choose the best kernel in terms of
+    performance.
+
+    Args:
+        config (MPLinearLayerConfig): Description of the linear layer to be
+            implemented.
+        compute_capability (Optional[int], optional): The compute capability of
+            the target device, if None uses `current_platform` to get the compute
+            capability. Defaults to None.
+
+    Raises:
+        ValueError: If no kernel can implement the given config.
+
+    Returns:
+        Type[MPLinearKernel]: Chosen kernel.
+    """
+    if compute_capability is None:
+        if current_platform is None:
+            raise ValueError("Cannot determine compute capability")
+        _cc = current_platform.get_device_capability()
+        compute_capability = _cc[0] * 10 + _cc[1]
+
+    failure_reasons = []
+    for kernel in _POSSIBLE_KERNELS:
+        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
+            failure_reasons.append(
+                f' {kernel.__name__} disabled by environment variable')
+            continue
+
+        if kernel.get_min_capability() > compute_capability:
+            failure_reasons.append(
+                f"{kernel.__name__} requires capability "
+                f"{kernel.get_min_capability()}, current compute capability "
+                f"is {compute_capability}")
+            continue
+
+        can_implement, failure_reason = kernel.can_implement(config)
+        if can_implement:
+            return kernel
+        else:
+            failure_reasons.append(
+                f' {kernel.__name__} cannot implement due to: {failure_reason}'
+            )
+
+    raise ValueError(
+        "Failed to find a kernel that can implement the "\
+        "WNA16 linear layer. Reasons: \n"
+        + '\n'.join(failure_reasons))
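
Kernel selection here (and in the scaled_mm selector added below) can be steered without code changes through the VLLM_DISABLED_KERNELS environment variable. A hedged example, using a kernel name taken from the list above:

# Editorial sketch: skip Machete so selection falls through to the next
# kernel in priority order (assuming it can implement the config on the
# current device). The variable must be set before vLLM reads it.
import os

os.environ["VLLM_DISABLED_KERNELS"] = "MacheteLinearKernel"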
@@ -0,0 +1,64 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+
+
+@dataclass
+class ScaledMMLinearLayerConfig:
+    is_channelwise: bool
+    is_static_input_scheme: bool
+    input_symmetric: bool
+
+
+class ScaledMMLinearKernel(ABC):
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        raise NotImplementedError
+
+    def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
+                 w_s_param_name: str, i_s_param_name: str,
+                 i_zp_param_name: str, azp_adj_param_name: str) -> None:
+        assert self.can_implement(c)
+        self.config = c
+        self.w_q_name = w_q_param_name
+        self.w_s_name = w_s_param_name
+        self.i_s_name = i_s_param_name
+        self.i_zp_name = i_zp_param_name
+        self.azp_adj_name = azp_adj_param_name
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        raise NotImplementedError
+
+    def _get_weight_params(
+            self, layer: torch.nn.Module
+    ) -> Tuple[torch.Tensor,  # weight
+               torch.Tensor,  # weight_scale
+               Optional[torch.Tensor],  # input_scale,
+               Optional[torch.Tensor],  # input_zp
+               Optional[torch.Tensor],  # azp_adj
+               ]:
+        return (
+            getattr(layer, self.w_q_name),
+            getattr(layer, self.w_s_name),
+            getattr(layer, self.i_s_name),
+            getattr(layer, self.i_zp_name),
+            getattr(layer, self.azp_adj_name),
+        )
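
To make the contract above concrete, a skeletal backend might look like the sketch below. This is an illustrative assumption, not a kernel that exists in vLLM; the fallback matmul is a slow eager-mode reference path with an assumed [out, in] int8 weight layout and [out, 1] channelwise scales.

# Editorial sketch of a hypothetical ScaledMMLinearKernel backend.
from typing import Optional, Tuple

import torch

from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
    ScaledMMLinearKernel, ScaledMMLinearLayerConfig)


class MyScaledMMLinearKernel(ScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # no compute-capability gate for this hypothetical backend

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.is_static_input_scheme:
            return False, "only dynamic activation quantization is supported"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # A real backend would repack checkpoint tensors into its own layout.
        pass

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        w_q, w_s, _, _, _ = self._get_weight_params(layer)
        # Reference (slow) path: dequantize the int8 weight and run an fp
        # matmul; a real backend would call a fused W8A8 kernel instead.
        w = w_q.to(x.dtype) * w_s.to(x.dtype)   # assumed [out, in] layout
        out = torch.nn.functional.linear(x, w)
        return out if bias is None else out + bias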
@@ -0,0 +1,84 @@
+import os
+from typing import Dict, List, Optional, Type
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
+    CutlassScaledMMLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
+    ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
+# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
+#     TritonScaledMMLinear)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
+    XLAScaledMMLinearKernel)
+from vllm.platforms import PlatformEnum, current_platform
+
+# in priority/performance order (when available)
+_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
+    PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
+    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
+    # TODO(rob): Create TritonScaledMMLinear kernel. ROCM will
+    # incorrectly attempt to run AZP models if prompted to.
+    PlatformEnum.ROCM: [CutlassScaledMMLinearKernel],
+    PlatformEnum.TPU: [XLAScaledMMLinearKernel],
+}
+
+
+def choose_scaled_mm_linear_kernel(
+        config: ScaledMMLinearLayerConfig,
+        compute_capability: Optional[int] = None
+) -> Type[ScaledMMLinearKernel]:
+    """
+    Choose an ScalledMMLinearKernel that can implement the given config for the
+    given compute capability. Attempts to choose the best kernel in terms of
+    performance.
+
+    Args:
+        config (ScaledMMLinearLayerConfig): Description of the linear layer
+            to be implemented.
+        compute_capability (Optional[int], optional): The compute capability of
+            the target device, if None uses `current_platform` to get the
+            compute capability. Defaults to None.
+
+    Raises:
+        ValueError: If no kernel can implement the given config.
+
+    Returns:
+        Type[ScaledMMLinearKernel]: Chosen kernel.
+    """
+
+    if compute_capability is None:
+        _cc = current_platform.get_device_capability()
+        if _cc is not None:
+            compute_capability = _cc[0] * 10 + _cc[1]
+
+    failure_reasons = []
+    for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
+        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
+            .split(","):
+            failure_reasons.append(
+                f' {kernel.__name__} disabled by environment variable')
+            continue
+
+        # If the current platform uses compute_capability,
+        # make sure the kernel supports the compute cability.
+        if compute_capability is not None:
+            kernel_min_capability = kernel.get_min_capability()
+            if (kernel_min_capability is not None
+                    and kernel_min_capability > compute_capability):
+                failure_reasons.append(
+                    f"{kernel.__name__} requires capability "
+                    f"{kernel_min_capability}, current compute capability "
+                    f"is {compute_capability}")
+                continue
+
+        can_implement, failure_reason = kernel.can_implement(config)
+        if can_implement:
+            return kernel
+        else:
+            failure_reasons.append(
+                f' {kernel.__name__} cannot implement due to: {failure_reason}'
+            )
+
+    raise ValueError(
+        "Failed to find a kernel that can implement the "\
+        "ScaledMM linear layer. Reasons: \n"
+        + '\n'.join(failure_reasons))
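
A short sketch of how this selector is meant to be used, mirroring what CompressedTensorsW8A8Int8.create_weights does earlier in this diff (the printed kernel name depends on the platform):

# Editorial sketch: pick a ScaledMM kernel for a channelwise, dynamic,
# symmetric W8A8 layer, then bind it to a layer's parameter names.
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)

cfg = ScaledMMLinearLayerConfig(is_channelwise=True,
                                is_static_input_scheme=False,
                                input_symmetric=True)
kernel_cls = choose_scaled_mm_linear_kernel(cfg)
print(kernel_cls.__name__)  # CutlassScaledMMLinearKernel on CUDA/CPU,
                            # XLAScaledMMLinearKernel on TPU
kernel = kernel_cls(c=cfg,
                    w_q_param_name="weight",
                    w_s_param_name="weight_scale",
                    i_s_param_name="input_scale",
                    i_zp_param_name="input_zero_point",
                    azp_adj_param_name="azp_adj")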
@@ -0,0 +1,134 @@
+from typing import Optional, Tuple
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    convert_to_channelwise)
+from vllm.platforms import current_platform
+
+from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
+                                   ScaledMMLinearLayerConfig)
+
+
+class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+
+        if (not current_platform.is_cuda() and not current_platform.is_cpu()):
+            return False, "CutlassScaledMM requires running on CUDA or CPU."
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # WEIGHT
+        # Cutlass kernels need transposed weight.
+        weight = getattr(layer, self.w_q_name)
+        replace_parameter(
+            layer, self.w_q_name,
+            torch.nn.Parameter(weight.t().data, requires_grad=False))
+
+        # WEIGHT SCALE
+        # Cutlass kernels support only per-tensor and per-channel.
+        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
+        # scales being passed to the kernel), convert to the per-channel case.
+        is_fused_module = len(layer.logical_widths) > 1
+        weight_scale = getattr(layer, self.w_s_name)
+        if is_fused_module and not self.config.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale,
+                                                  layer.logical_widths)
+        replace_parameter(
+            layer, self.w_s_name,
+            torch.nn.Parameter(weight_scale.data, requires_grad=False))
+
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+
+            if self.config.input_symmetric:
+                replace_parameter(
+                    layer, self.i_s_name,
+                    torch.nn.Parameter(input_scale.max(), requires_grad=False))
+                setattr(layer, self.i_zp_name, None)
+            else:
+                input_zero_point = getattr(layer, self.i_zp_name)
+
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                azps = input_zero_point.to(dtype=torch.int32)
+                range_max = (input_scale * (int8_traits.max - azps)).max()
+                range_min = (input_scale * (int8_traits.min - azps)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max -
+                                                   int8_traits.min)
+                replace_parameter(
+                    layer, self.i_s_name,
+                    torch.nn.Parameter(scale, requires_grad=False))
+
+                # AZP loaded as int8 but used as int32
+                azp = (int8_traits.min -
+                       range_min / scale).to(dtype=torch.int32)
+                replace_parameter(layer, self.i_zp_name,
+                                  torch.nn.Parameter(azp, requires_grad=False))
+
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        # azp_adj is the AZP adjustment term, used to account for weights.
+        # It does not depend on scales or azp, so it is the same for
+        # static and dynamic quantization.
+        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        if not self.config.input_symmetric:
+            weight = getattr(layer, self.w_q_name)
+            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.config.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = getattr(layer, self.i_zp_name) * azp_adj
+            setattr(layer, self.azp_adj_name,
+                    torch.nn.Parameter(azp_adj, requires_grad=False))
+        else:
+            setattr(layer, self.azp_adj_name, None)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        # ops.scaled_int8_quant supports both dynamic and static quant:
+        # * dynamic, i_s is None and x_s computed from x.
+        # * static, i_s is scalar and x_s is i_s.
+        symmetric = azp_adj is None
+        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
+                                               i_s,
+                                               i_zp,
+                                               symmetric=symmetric)
+
+        if x_zp is not None:
+            # Currently, static is always per-tensor and dynamic is per-token
+            static = i_zp is not None
+            azp = None if static else x_zp
+            return ops.cutlass_scaled_mm_azp(x_q,
+                                             w_q,
+                                             scale_a=x_s,
+                                             scale_b=w_s,
+                                             out_dtype=x.dtype,
+                                             azp_adj=azp_adj,
+                                             azp=azp,
+                                             bias=bias)
+        return ops.cutlass_scaled_mm(x_q,
+                                     w_q,
+                                     scale_a=x_s,
+                                     scale_b=w_s,
+                                     out_dtype=x.dtype,
+                                     bias=bias)
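
The asymmetric static-input branch above folds the N per-partition (scale, zero_point) pairs of a fused module into a single per-tensor pair by reconstructing the real-valued range they represent. A small numeric sketch of that arithmetic, with invented values:

# Editorial sketch of the range reconstruction: recover each partition's real
# range from (scale, azp), take the union, and derive one per-tensor pair.
import torch

int8_traits = torch.iinfo(torch.int8)               # qmin=-128, qmax=127
input_scale = torch.tensor([0.010, 0.020])          # per-partition scales (invented)
azps = torch.tensor([10, -30], dtype=torch.int32)   # per-partition zero points (invented)

range_max = (input_scale * (int8_traits.max - azps)).max()  # largest representable value
range_min = (input_scale * (int8_traits.min - azps)).min()  # smallest representable value

scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)

print(scale.item(), azp.item())  # single per-tensor (scale, zero_point) pair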
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from functorch.experimental.control_flow import cond  # noqa: F401
+
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    convert_to_channelwise)
+from vllm.platforms import current_platform
+
+from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
+                                   ScaledMMLinearLayerConfig)
+
+
+class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError(
+            "TPU platform does have a concept of compute capability, "
+            "this method should not be called.")
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+
+        if not current_platform.is_tpu():
+            return False, "ScaledMMXLA requires running on TPU."
+
+        if c.is_static_input_scheme:
+            return False, "ScaledMMXLA requires dynamic activation scales."
+
+        if not c.input_symmetric:
+            return False, "ScaledMMXLA requires symmetric activation scales."
+
+        if not c.is_channelwise:
+            return False, "ScaledMMXLA requires channelwise weight scales"
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # WEIGHT
+        # [out, in] (different than cutlass_scaled_mm)
+        weight = getattr(layer, self.w_q_name)
+        replace_parameter(layer, self.w_q_name,
+                          torch.nn.Parameter(weight.data, requires_grad=False))
+
+        # WEIGHT SCALE
+        # XLA kernels support only per-tensor and per-channel.
+        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
+        # scales being passed to the kernel), convert to the per-channel case.
+        is_fused_module = len(layer.logical_widths) > 1
+        weight_scale = getattr(layer, self.w_s_name)
+        if is_fused_module and not self.config.is_channelwise:
+            weight_scale = convert_to_channelwise(weight_scale,
+                                                  layer.logical_widths)
+
+        # [out_channel,] (different than cutlass_scaled_mm)
+        weight_scale = weight_scale.squeeze(-1)
+        replace_parameter(
+            layer, self.w_s_name,
+            torch.nn.Parameter(weight_scale.data, requires_grad=False))
+
+        # Only support symmetric dynamic activation quantization.
+        setattr(layer, self.i_s_name, None)
+        setattr(layer, self.i_zp_name, None)
+        setattr(layer, self.azp_adj_name, None)
+
+        # Filter warning for cond usage in apply_weights. It is okay
+        # to specialize the graph since bias is not dynamic.
+        warnings.filterwarnings(
+            "ignore",
+            message=
+            "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches."  # noqa: E501
+        )
+
+    def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
+        return x
+
+    def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]):
+        return x + bias
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        w_q, w_s, _, _, _ = self._get_weight_params(layer)
+
+        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
+        out = torch.ops.xla.quantized_matmul(x,
+                                             w_q,
+                                             w_s,
+                                             zero_point=None,
+                                             block_size=-1,
+                                             int4_weight=False,
+                                             quantize_activation=True)
+
+        # Explicitly capture control flow to make dynamo happy.
+        # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method  # noqa: E501
+        return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
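
For intuition, the fused torch.ops.xla.quantized_matmul call above corresponds roughly to the eager-mode reference below: quantize activations symmetrically on the fly, run an integer matmul, then rescale by the activation and per-channel weight scales. The per-token granularity here is an assumption made for illustration, not something this diff states about the fused op.

# Editorial sketch: eager-mode reference for W8A8 matmul with dynamic,
# symmetric activation quantization (granularity is an assumption).
import torch

def w8a8_reference(x: torch.Tensor,       # [tokens, in] fp activations
                   w_q: torch.Tensor,     # [out, in] int8 weight
                   w_s: torch.Tensor) -> torch.Tensor:  # [out] channel scales
    # Dynamic symmetric per-token quantization of activations to int8.
    x_s = x.abs().amax(dim=-1, keepdim=True) / 127.0
    x_q = torch.clamp(torch.round(x / x_s), -128, 127).to(torch.int8)
    # Integer matmul accumulated in int32, then dequantize with both scales.
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    return acc.to(x.dtype) * x_s * w_s

out = w8a8_reference(torch.randn(4, 16),
                     torch.randint(-128, 127, (8, 16), dtype=torch.int8),
                     torch.full((8,), 0.05))
print(out.shape)  # torch.Size([4, 8])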
@@ -201,44 +201,6 @@ def apply_fp8_linear(
     return output.to(dtype=input.dtype).view(*output_shape)
 
 
-def apply_int8_linear(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_scale: Optional[torch.Tensor] = None,
-    input_zero_point: Optional[torch.Tensor] = None,
-    azp_adj: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-):
-    # ops.scaled_int8_quant supports both dynamic and static quant.
-    # * dynamic, layer.input_scale is None and x_scale computed from x.
-    # * static, layer.input_scale is scalar and x_scale is input_scale.
-    symmetric = azp_adj is None
-    x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
-                                               input_scale,
-                                               input_zero_point,
-                                               symmetric=symmetric)
-
-    if x_zp is not None:
-        # Currently, static is always per-tensor and dynamic is per-token
-        static = input_zero_point is not None
-        azp = None if static else x_zp
-        return ops.cutlass_scaled_mm_azp(x_q,
-                                         weight,
-                                         scale_a=x_scale,
-                                         scale_b=weight_scale,
-                                         out_dtype=input.dtype,
-                                         azp_adj=azp_adj,
-                                         azp=azp,
-                                         bias=bias)
-    return ops.cutlass_scaled_mm(x_q,
-                                 weight,
-                                 scale_a=x_scale,
-                                 scale_b=weight_scale,
-                                 out_dtype=input.dtype,
-                                 bias=bias)
-
-
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
@@ -6,6 +6,7 @@ from torch.nn import Parameter
 
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.logger import init_logger
+from vllm.model_executor.utils import _make_synced_weight_loader
 
 __all__ = [
     "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
@@ -37,6 +38,18 @@ class BasevLLMParameter(Parameter):
         :returns: a torch.nn.parameter
         """
 
+        # During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to the redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
+        from vllm.platforms import current_platform
+        if current_platform.is_tpu():
+            weight_loader = _make_synced_weight_loader(weight_loader)
+
         self._weight_loader = weight_loader
 
     @property
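
_make_synced_weight_loader itself is not shown in this diff; the comment block above only describes why it exists. One plausible shape for such a wrapper, stated as an assumption rather than the actual vLLM helper:

# Editorial sketch (assumption, not the vLLM implementation): wrap a weight
# loader so the lazily-narrowed copy is materialized right after loading,
# letting XLA release the temporary views instead of keeping them alive.
import torch

def make_synced_weight_loader(original_weight_loader):
    def synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        # Force pending lazy ops on this tensor to execute
        # (assumes a torch build where torch._sync is available for XLA).
        torch._sync(param)
    return synced_weight_loader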
@@ -19,7 +19,9 @@ class TpuPlatform(Platform):
     device_name: str = "tpu"
     device_type: str = "tpu"
     dispatch_key: str = "XLA"
-    supported_quantization: list[str] = ["tpu_int8"]
+    supported_quantization: list[str] = [
+        "tpu_int8", "compressed-tensors", "compressed_tensors"
+    ]
 
     @classmethod
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: