[ Misc ] Support Fp8 via llm-compressor (#6110)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

Repository: https://git.datalinker.icu/vllm-project/vllm.git (mirror)
Commit: abfe705a02 (parent: 333306a252)
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.752
+  - name: "exact_match,flexible-extract"
+    value: 0.752
+limit: 250
+num_fewshot: 5
@@ -1,4 +1,4 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8 -b 32 -l 250 -f 5 -t 1
 model_name: "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
 tasks:
 - name: "gsm8k"
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.728
+  - name: "exact_match,flexible-extract"
+    value: 0.728
+limit: 250
+num_fewshot: 5
@@ -1,2 +1,4 @@
 Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8.yaml
+Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
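Note (not part of the diff): a minimal sketch of how a config like the ones above could be consumed and checked against a measured GSM8K score. It assumes PyYAML is installed; the tolerance value and the `measured` dict are placeholders, not values from the commit.

    import yaml

    TOL = 0.02  # assumed comparison tolerance, for illustration only

    config_text = """
    model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    tasks:
    - name: "gsm8k"
      metrics:
      - name: "exact_match,strict-match"
        value: 0.752
    limit: 250
    num_fewshot: 5
    """

    config = yaml.safe_load(config_text)
    measured = {"gsm8k": {"exact_match,strict-match": 0.76}}  # fake lm-eval result

    for task in config["tasks"]:
        for metric in task["metrics"]:
            got = measured[task["name"]][metric["name"]]
            assert abs(got - metric["value"]) < TOL, (task["name"], metric["name"])
    print("ground truth matched within tolerance")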
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done

 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE
@@ -24,7 +24,8 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)


 def launch_lm_eval(eval_config):
     model_args = f"pretrained={eval_config['model_name']}," \
-                 f"tensor_parallel_size={TP_SIZE}"
+                 f"tensor_parallel_size={TP_SIZE}," \
+                 f"add_bos_token=true"

     results = lm_eval.simple_evaluate(
         model="vllm",
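For reference, the string the patched helper now builds (values are placeholders; the construction is taken directly from the hunk above):

    eval_config = {"model_name": "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"}
    TP_SIZE = 1

    model_args = f"pretrained={eval_config['model_name']}," \
                 f"tensor_parallel_size={TP_SIZE}," \
                 f"add_bos_token=true"

    print(model_args)
    # pretrained=neuralmagic/Meta-Llama-3-8B-Instruct-FP8,tensor_parallel_size=1,add_bos_token=true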
@@ -9,7 +9,8 @@ import torch
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8, CompressedTensorsWNA16)
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)

@@ -37,12 +38,11 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                           CompressedTensorsLinearMethod)
        assert isinstance(down_proj.quant_method,
                          CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.is_static_input_scheme
-        expected_type = (torch.int8 if quant_type == QuantizationType.INT else
-                         torch.float8_e4m3fn)
+        expected_type = torch.int8

        assert qkv_proj.weight.dtype is expected_type
        assert o_proj.weight.dtype is expected_type
@@ -79,7 +79,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
        assert not qkv_proj.scheme.is_static_input_scheme
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.weight.dtype is torch.int8
@@ -123,3 +123,25 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
        sampling_params = SamplingParams()
        output = llm.generate("Hello world!", sampling_params=sampling_params)
        assert output
+
+
+def test_compressed_tensors_fp8(vllm_runner):
+    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
+    with vllm_runner(model_path) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
+        assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+        assert qkv_proj.input_scale.dtype is torch.float32
+        assert qkv_proj.weight_scale.dtype is torch.float32
+        # should be scalars after processing
+        assert len(qkv_proj.input_scale.shape) == 0
+        assert len(qkv_proj.weight_scale.shape) == 0
+
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
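The scalar-shape assertions in the new test rely on the fact that reducing a 1-D scale tensor with .max() yields a 0-dim tensor. A standalone PyTorch check, illustrative only (values are made up):

    import torch

    # N per-shard scales as loaded for a fused QKV module (example values).
    weight_scale = torch.tensor([0.01, 0.02, 0.015], dtype=torch.float32)

    collapsed = weight_scale.max()     # what process_weights_after_loading keeps
    assert collapsed.dtype is torch.float32
    assert len(collapsed.shape) == 0   # 0-dim scalar, as the test expects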
@@ -9,10 +9,11 @@ from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8, CompressedTensorsWNA16)
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    find_first_name_or_class_match)
+    QuantizationType, find_first_name_or_class_match)
 from vllm.platforms import current_platform

@@ -117,6 +118,40 @@ class CompressedTensorsConfig(QuantizationConfig):

         return is_8_bits and is_token and is_symmetric and is_dynamic

+    def _is_fp8_w8a8(self, weight_quant: BaseModel,
+                     input_quant: BaseModel) -> bool:
+        # Confirm weights and activations quantized.
+        if weight_quant is None or input_quant is None:
+            return False
+
+        # Confirm we have floating points.
+        if not (weight_quant.type == QuantizationType.FLOAT
+                and input_quant.type == QuantizationType.FLOAT):
+            return False
+
+        # Confirm weight scheme is supported.
+        is_symmetric_weight = weight_quant.symmetric
+        is_static_weight = not weight_quant.dynamic
+        is_per_tensor_weight = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR)
+        if not (is_symmetric_weight and is_static_weight
+                and is_per_tensor_weight):
+            return False
+
+        # Dynamic quantization is always supported if weights supported.
+        if input_quant.dynamic:
+            return True
+
+        # Confirm activation scheme is supported.
+        is_symmetric_activation = input_quant.symmetric
+        is_per_tensor_activation = (
+            input_quant.strategy == QuantizationStrategy.TENSOR)
+        if not (is_symmetric_activation and is_per_tensor_activation):
+            return False
+
+        # All conditions satisfied.
+        return True
+
     def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
@@ -147,14 +182,21 @@ class CompressedTensorsConfig(QuantizationConfig):
                 strategy=weight_quant.strategy,
                 group_size=weight_quant.group_size)

-        if self.quant_format == CompressionFormat.int_quantized.value:
+        if (self.quant_format == CompressionFormat.int_quantized.value or
+                self.quant_format == CompressionFormat.float_quantized.value):
+            if self._is_fp8_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8Fp8(
+                    input_dynamic=input_quant.dynamic)
+
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
-                                             is_static_input_scheme=True)
+                return CompressedTensorsW8A8Int8(
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=True)
+
             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
-                                             is_static_input_scheme=False)
+                return CompressedTensorsW8A8Int8(
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=False)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
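A self-contained sketch of the dispatch rule encoded by _is_fp8_w8a8 above, using plain dataclasses in place of the pydantic BaseModel quantization args; field names and defaults here are assumptions for illustration, not the real config schema.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class QuantArgs:                 # stand-in for the weight/input quant args
        type: str = "float"          # "float" or "int"
        symmetric: bool = True
        dynamic: bool = False
        strategy: str = "tensor"     # "tensor" or "channel"

    def is_fp8_w8a8(weight: Optional[QuantArgs], inp: Optional[QuantArgs]) -> bool:
        if weight is None or inp is None:
            return False
        if not (weight.type == "float" and inp.type == "float"):
            return False
        # Weights must be symmetric, static, per-tensor.
        if not (weight.symmetric and not weight.dynamic
                and weight.strategy == "tensor"):
            return False
        # Dynamic activations are always fine once the weights qualify.
        if inp.dynamic:
            return True
        # Static activations must be symmetric and per-tensor.
        return inp.symmetric and inp.strategy == "tensor"

    assert is_fp8_w8a8(QuantArgs(), QuantArgs(dynamic=True))
    assert not is_fp8_w8a8(QuantArgs(type="int"), QuantArgs())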
@@ -187,7 +229,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         self.quantization_config = quantization_config

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        return layer.scheme.process_weights_after_loading(layer)
+        layer.scheme.process_weights_after_loading(layer)

     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
@@ -1,8 +1,19 @@
-from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
-from .compressed_tensors_unquantized import (  # noqa: F401
-    CompressedTensorsUnquantized)
-from .compressed_tensors_w4a16_24 import (  # noqa: F401
-    W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
-from .compressed_tensors_w8a8 import CompressedTensorsW8A8  # noqa: F401
-from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
-from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401
+from .compressed_tensors_scheme import CompressedTensorsScheme
+from .compressed_tensors_unquantized import CompressedTensorsUnquantized
+from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
+                                          CompressedTensorsW4A16Sparse24)
+from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
+from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
+                                       CompressedTensorsWNA16)
+
+__all__ = [
+    "CompressedTensorsScheme",
+    "CompressedTensorsUnquantized",
+    "CompressedTensorsWNA16",
+    "CompressedTensorsW4A16Sparse24",
+    "CompressedTensorsW8A8Int8",
+    "CompressedTensorsW8A8Fp8",
+    "WNA16_SUPPORTED_BITS",
+    "W4A16SPARSE24_SUPPORTED_BITS",
+]
@@ -1,109 +0,0 @@
-from typing import Callable, List, Tuple, Union
-
-import torch
-from torch.nn import Parameter
-
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
-from vllm.model_executor.utils import set_weight_attrs
-
-
-class CompressedTensorsW8A8(CompressedTensorsScheme):
-
-    def __init__(self, strategy: str, is_static_input_scheme: bool):
-        self.strategy = strategy
-        self.is_static_input_scheme = is_static_input_scheme
-
-    # Cutlass kernels support only per-tensor and per-channel cases.
-    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
-    # scales being passed to the kernel), we convert to the per-channel case.
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if (self.strategy == QuantizationStrategy.TENSOR
-                and len(self.logical_widths) > 1):
-
-            # Load the N per-tensor scales into the channelwise buffer.
-            weight_scale_channel = torch.empty(
-                (sum(self.logical_widths), 1),
-                dtype=torch.float32,
-                device=layer.weight_scale.device)
-            start = 0
-            for idx, logical_width in enumerate(self.logical_widths):
-                end = start + logical_width
-                weight_scale_channel[start:end, :] = layer.weight_scale[idx]
-                start = end
-
-            layer.weight_scale = Parameter(weight_scale_channel,
-                                           requires_grad=False)
-
-        # transpose weights for cutlass.
-        weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-        self.logical_widths = output_partition_sizes
-
-        # WEIGHT SCALE
-        shape: Union[Tuple[int], Tuple[int, int]]
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (sum(self.logical_widths), 1)
-        else:
-            shape = (len(self.logical_widths), )
-
-        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
-                                 requires_grad=False)
-        layer.register_parameter("weight_scale", weight_scale)
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "weight_loader": weight_loader,
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(weight_scale, {
-                "weight_loader": weight_loader,
-                "needs_scalar_to_array": True,
-            })
-
-        # WEIGHT
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "input_dim": 1,
-            "output_dim": 0,
-            "weight_loader": weight_loader,
-        })
-
-        # INPUT SCALE
-        # Static quantization: load from disk.
-        if self.is_static_input_scheme:
-            input_scale = Parameter(torch.empty(1, dtype=torch.float32),
-                                    requires_grad=False)
-            layer.register_parameter("input_scale", input_scale)
-            set_weight_attrs(input_scale, {
-                "weight_loader": weight_loader,
-                "ignore_warning": True,
-            })
-        # Dynamic quantization: set to None.
-        else:
-            layer.input_scale = None
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        # ops.scaled_int8_quant supports both dynamic and static quant.
-        # * dynamic, layer.input_scale is None and x_scale computed from x.
-        # * static, layer.input_scale is scalar and x_scale is input_scale.
-        x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
-
-        return ops.cutlass_scaled_mm(x_q,
-                                     layer.weight,
-                                     scale_a=x_scale,
-                                     scale_b=layer.weight_scale,
-                                     out_dtype=x.dtype)
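The deleted scheme's per-tensor-to-per-channel scale handling survives in the new convert_to_channelwise helper. A minimal PyTorch sketch of that expansion, with illustrative shapes only:

    import torch

    # Example: a fused QKV projection made of three logical matrices.
    logical_widths = [1024, 256, 256]
    per_tensor_scales = torch.tensor([0.01, 0.02, 0.03])  # one scale per shard

    channelwise = torch.empty((sum(logical_widths), 1), dtype=torch.float32)
    start = 0
    for idx, width in enumerate(logical_widths):
        channelwise[start:start + width, :] = per_tensor_scales[idx]
        start += width

    assert channelwise.shape == (1536, 1)  # one scale per output channel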
@@ -0,0 +1,87 @@
+from typing import Callable, List, Optional
+
+import torch
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
+    requantize_with_max_scale)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW8A8Fp8"]
+
+
+class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
+
+    def __init__(self, input_dynamic: bool):
+        self.input_dynamic = input_dynamic
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
+    # scales being passed to the kernel), we requantize with a single scale.
+    def process_weights_after_loading(self, layer) -> None:
+        # Dequant -> Quant with max scale.
+        max_w_scale, weight = requantize_with_max_scale(
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            logical_widths=layer.logical_widths,
+        )
+
+        # Update layer with new values.
+        layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(max_w_scale,
+                                                requires_grad=False)
+        if self.input_dynamic:
+            layer.input_scale = None
+        else:
+            layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
+                                                   requires_grad=False)
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        del params_dtype
+
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
+                                                input_size_per_partition,
+                                                dtype=torch.float8_e4m3fn),
+                                    requires_grad=False)
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+            "weight_loader": weight_loader,
+        })
+
+        # WEIGHT SCALE
+        weight_scale = create_per_tensor_scale_param(
+            output_partition_sizes, weight_loader=weight_loader)
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if not self.input_dynamic:
+            input_scale = create_per_tensor_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported)
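A numeric sketch of the dequant-then-requant step that requantize_with_max_scale performs, emulated with plain tensor math rather than the custom ops; it assumes a recent PyTorch with the float8_e4m3fn dtype and uses made-up shapes.

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

    def per_tensor_quant(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    # Two logical shards quantized with different scales in the checkpoint.
    shards = [torch.randn(4, 8), torch.randn(4, 8) * 3]
    scales = [s.abs().max() / FP8_MAX for s in shards]
    q_shards = [per_tensor_quant(s, sc) for s, sc in zip(shards, scales)]

    # Requantize everything with the single max scale so one fused GEMM works.
    max_scale = torch.stack(scales).max()
    requant = [per_tensor_quant(q.to(torch.float16) * sc, max_scale)
               for q, sc in zip(q_shards, scales)]
    fused_weight = torch.cat(requant, dim=0)  # single fp8 tensor, single scale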
@@ -0,0 +1,85 @@
+from typing import Callable, List
+
+import torch
+from torch.nn import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param,
+    create_per_tensor_scale_param)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # WEIGHT
+        # Cutlass kernels need transposed weight.
+        weight = layer.weight
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+
+        # WEIGHT SCALE
+        # Cutlass kernels support only per-tensor and per-channel.
+        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
+        # scales being passed to the kernel), convert to the per-channel case.
+        is_fused_module = len(self.logical_widths) > 1
+        if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
+            ws_channelwise = convert_to_channelwise(layer.weight_scale,
+                                                    self.logical_widths)
+            layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
+        else:
+            layer.input_scale = None
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        self.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+            "weight_loader": weight_loader,
+        })
+
+        # WEIGHT SCALE
+        layer_kwargs = {"weight_loader": weight_loader}
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            scale = create_per_channel_scale_param(output_partition_sizes,
+                                                   **layer_kwargs)
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            scale = create_per_tensor_scale_param(output_partition_sizes,
+                                                  **layer_kwargs)
+        layer.register_parameter("weight_scale", scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            scale = create_per_tensor_scale_param(output_partition_sizes,
+                                                  **layer_kwargs)
+            layer.register_parameter("input_scale", scale)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        return apply_int8_linear(input=x,
+                                 weight=layer.weight,
+                                 weight_scale=layer.weight_scale,
+                                 input_scale=layer.input_scale)
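For intuition, the int8 W8A8 path computes roughly the following quantize-multiply-rescale flow. This is a pure-PyTorch emulation with made-up shapes; the real kernel is cutlass_scaled_mm and accumulates in int32 before applying the scales in its epilogue.

    import torch

    def int8_symmetric_quant(t: torch.Tensor):
        scale = t.abs().max() / 127.0
        return torch.clamp(torch.round(t / scale), -128, 127).to(torch.int8), scale

    x = torch.randn(16, 64)          # activations
    w = torch.randn(32, 64) * 0.05   # weight, [out, in]

    x_q, x_scale = int8_symmetric_quant(x)   # dynamic per-tensor activation quant
    w_q, w_scale = int8_symmetric_quant(w)   # per-tensor weight quant

    # Emulate the accumulation in float for portability.
    acc = x_q.to(torch.float32) @ w_q.to(torch.float32).t()
    y = acc * (x_scale * w_scale)

    ref = x @ w.t()
    print((y - ref).abs().max())  # small quantization error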
@@ -9,6 +9,7 @@ from torch.nn import Module
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    float_quantized = "float-quantized"
     int_quantized = "int-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional

 import torch
 from torch.nn import Module
@@ -11,11 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
-    marlin_permute_scales)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    pack_fp8_to_int32)
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
+    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import print_warning_once
@@ -25,13 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)


-def cutlass_fp8_supported() -> bool:
-    capability = current_platform.get_device_capability()
-    capability = capability[0] * 10 + capability[1]
-
-    return ops.cutlass_scaled_mm_supports_fp8(capability)
-
-
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""

@@ -117,23 +110,6 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89

-    def _create_scale_param(
-        self,
-        scale_name: str,
-        layer: torch.nn.Module,
-        output_partition_sizes: List[int],
-        **extra_weight_attrs,
-    ) -> None:
-        scale = Parameter(torch.empty(len(output_partition_sizes),
-                                      dtype=torch.float32),
-                          requires_grad=False)
-        scale[:] = torch.finfo(torch.float8_e4m3fn).min
-        layer.register_parameter(scale_name, scale)
-        set_weight_attrs(scale, {
-            **extra_weight_attrs,
-            "needs_scalar_to_array": True,
-        })
-
     def create_weights(
         self,
         layer: torch.nn.Module,
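The capability < 89 gate above encodes "no native FP8 tensor cores". A standalone sketch of the same check (assumes a CUDA build; for example 90 on H100, 89 on Ada, 80 on A100):

    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        capability = major * 10 + minor
        use_marlin = capability < 89   # fall back to weight-only FP8 via Marlin
        print(f"sm_{capability}: use_marlin={use_marlin}")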
@@ -147,7 +123,6 @@ class Fp8LinearMethod(LinearMethodBase):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)

-        layer.process_after_load = True
         layer.logical_widths = output_partition_sizes

         layer.input_size_per_partition = input_size_per_partition
@@ -173,144 +148,50 @@ class Fp8LinearMethod(LinearMethodBase):
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            self._create_scale_param(
-                scale_name="weight_scale",
-                layer=layer,
-                output_partition_sizes=output_partition_sizes,
-                **extra_weight_attrs)
+            scale = create_per_tensor_scale_param(output_partition_sizes,
+                                                  **extra_weight_attrs)
+            layer.register_parameter("weight_scale", scale)

             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                self._create_scale_param(
-                    scale_name="input_scale",
-                    layer=layer,
-                    output_partition_sizes=output_partition_sizes,
-                    **extra_weight_attrs)
-
-        # For GPUs without FP8 hardware support, we use Marlin for fast
-        # fused dequantization
-        if self.use_marlin:
-            layer.marlin_state = GPTQMarlinState.REPACK
-
-    def prepare_layer_for_marlin(self, layer: Module) -> None:
-        print_warning_once(
-            "Your GPU does not have native support for FP8 computation but "
-            "FP8 quantization is being used. Weight-only FP8 compression will "
-            "be used leveraging the Marlin kernel. This may degrade "
-            "performance for compute-heavy workloads.")
-
-        part_size_n = layer.output_size_per_partition
-        part_size_k = layer.input_size_per_partition
-
-        assert layer.marlin_state == GPTQMarlinState.REPACK
-        layer.marlin_state = GPTQMarlinState.READY
-
-        device = layer.weight.device
-
-        # WEIGHTS
-        # Repack weights to gptq format (packed int32 elements)
-        packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
-
-        # Repack weights to marlin format
-        marlin_qweight = ops.gptq_marlin_repack(
-            b_q_weight=packed_gptq_qweight,
-            perm=torch.empty(0, dtype=torch.int, device=device),
-            size_k=part_size_k,
-            size_n=part_size_n,
-            num_bits=8,
-        )
-        layer.weight = Parameter(marlin_qweight, requires_grad=False)
-
-        # WEIGHT SCALES
-        # Currently Marlin doesn't support per-tensor scales, so we
-        # expand it to channelwise
-        scales = layer.weight_scale.repeat(1, part_size_n).to(
-            layer.orig_dtype).to(device)
-        # Permute scales
-        marlin_scales = marlin_permute_scales(
-            s=scales,
-            size_k=part_size_k,
-            size_n=part_size_n,
-            group_size=-1,
-            num_bits=8,
-        )
-        layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
-
-        # Allocate marlin workspace
-        max_workspace_size = (
-            part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
-        workspace = torch.zeros(max_workspace_size,
-                                dtype=torch.int,
-                                device=device,
-                                requires_grad=False)
-
-        layer.workspace = workspace
+                scale = create_per_tensor_scale_param(output_partition_sizes,
+                                                      **extra_weight_attrs)
+                layer.register_parameter("input_scale", scale)

     def process_weights_after_loading(self, layer: Module) -> None:
-        if (not hasattr(layer, "process_after_load")
-                or not layer.process_after_load):
-            return
-
-        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
+        # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                          scale=None)
+
+            # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            layer.logical_widths = None
             layer.input_scale = None
-            if self.use_marlin:
-                self.prepare_layer_for_marlin(layer)
-            return

         # If checkpoint is fp8, requantize the separately quantized logical
         # weights into a single fp8 weight with a single weight scale.
         else:
-            # WEIGHT_SCALE / WEIGHT
-            # Loop over logical weights, requantizing with single scale.
-            max_w_scale = layer.weight_scale.max()
-
-            # QKV / MLP is fused in the on disk checkpoint if any of the
-            # weight scales are still set to the default since we initialize
-            # N weight scales for N shards but we only load 1 weight scale
-            # from disk in this case. As a result, we skip dequant -> requant
-            # since we already have quantized QKV together.
-            # Sample Model with fused checkpoint:
-            # * nm-testing/Phi-3-mini-128k-instruct-FP8
-            unfused_module_in_checkpoint = (
-                layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
-
-            if unfused_module_in_checkpoint:
-                start = 0
-                for idx, logical_width in enumerate(layer.logical_widths):
-                    end = start + logical_width
-                    weight_dq = per_tensor_dequantize(
-                        layer.weight[start:end, :], layer.weight_scale[idx])
-
-                    layer.weight[start:end, :] = per_tensor_quantize(
-                        weight_dq, layer.weight_scale.max())
-                    start = end
-            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
-
-            # WEIGHT
-            # Transpose weight for passing to torch._scaled_mm
-            weight = layer.weight
+            # Dequant -> Quant with max scale.
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
+
+            # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

-            # INPUT ACTIVATION SCALE
-            # Dynamic: set to None (required input to ops.scaled_fp8_quant).
-            # Static: set to max of the input_scales (since they are equal).
-            if self.quant_config.activation_scheme == "dynamic":
-                layer.input_scale = None
-            elif self.quant_config.activation_scheme == "static":
+            if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
-                raise ValueError(
-                    f"Unknown scheme {self.quant_config.activation_scheme}")
+                layer.input_scale = None

         if self.use_marlin:
-            self.prepare_layer_for_marlin(layer)
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale

     def apply(self,
               layer: torch.nn.Module,
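The requantize_with_max_scale helper this hunk switches to detects an already-fused checkpoint by checking whether the last loaded scale is still at its initialization value. A small sketch of that heuristic with made-up numbers (requires a PyTorch with the float8_e4m3fn dtype):

    import torch

    FP32_MIN = torch.finfo(torch.float32).min        # scales are initialized to this
    FP8_MIN = torch.finfo(torch.float8_e4m3fn).min   # threshold used by the check

    # Three slots allocated for fused QKV. A fused checkpoint stores one scale, so
    # only slot 0 gets overwritten by the loader; the rest stay at FP32_MIN.
    fused = torch.tensor([0.02, FP32_MIN, FP32_MIN])
    unfused = torch.tensor([0.02, 0.03, 0.01])

    def unfused_module_in_checkpoint(weight_scale: torch.Tensor) -> bool:
        return bool(weight_scale[-1] > FP8_MIN)

    assert not unfused_module_in_checkpoint(fused)  # already quantized together
    assert unfused_module_in_checkpoint(unfused)    # dequant -> requant with max scale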
@@ -318,65 +199,22 @@ class Fp8LinearMethod(LinearMethodBase):
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:

         if self.use_marlin:
-            # For GPUs that lack FP8 hardware support, we can leverage the
-            # Marlin kernel for fast weight-only FP8 quantization
-
-            reshaped_x = x.reshape(-1, x.shape[-1])
-            out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
-
-            output = ops.fp8_marlin_gemm(
-                a=reshaped_x,
-                b_q_weight=layer.weight,
-                b_scales=layer.weight_scale,
-                workspace=layer.workspace,
-                num_bits=8,
-                size_m=reshaped_x.shape[0],
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-            )
-
-            if bias is not None:
-                output.add_(bias)  # In-place add
-
-            return output.reshape(out_shape)
-
-        else:
-
-            # ops.scaled_fp8_quant supports both dynamic and static quant.
-            # If dynamic, layer.input_scale is None and x_scale computed from x
-            # If static, layer.input_scale is scalar and x_scale is input_scale
-
-            if bias is None and self.cutlass_fp8_supported:
-                qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
-
-                # Fused GEMM_DQ
-                output = ops.cutlass_scaled_mm(
-                    qinput,
-                    layer.weight,
-                    out_dtype=x.dtype,
-                    scale_a=x_scale,
-                    scale_b=layer.weight_scale,
-                )
-
-            else:
-                qinput, x_scale = ops.scaled_fp8_quant(x,
-                                                       layer.input_scale,
-                                                       batch_dim_padding=17)
-
-                # Fused GEMM_DQ -- note we padded the input above because
-                # torch._scaled_mm is more performant for matrices with
-                # batch dimension > 16. Note that this could change
-                # in the future.
-                output, _ = torch._scaled_mm(
-                    qinput,
-                    layer.weight,
-                    out_dtype=x.dtype,
-                    scale_a=x_scale,
-                    scale_b=layer.weight_scale,
-                    bias=bias,
-                )
-
-        return torch.narrow(output, 0, 0, x.shape[0])
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported)


 class Fp8MoEMethod(FusedMoEMethodBase):
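Behind apply_fp8_linear, the only difference between dynamic and static activation quantization is where the scale comes from. An emulated sketch without the custom ops (the 0.05 static scale is an arbitrary example):

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

    def quantize_activation(x: torch.Tensor, input_scale=None):
        # Static: the checkpoint provides input_scale. Dynamic: derive it from x.
        scale = input_scale if input_scale is not None else x.abs().max() / FP8_MAX
        x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return x_q, scale

    x = torch.randn(8, 16)
    x_q_dyn, s_dyn = quantize_activation(x)                      # dynamic
    x_q_sta, s_sta = quantize_activation(x, torch.tensor(0.05))  # static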
@@ -399,8 +237,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                        intermediate_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):

-        layer.process_after_load = True
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -465,9 +301,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.a2_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if (not hasattr(layer, "process_after_load")
-                or not layer.process_after_load):
-            return
-
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -531,7 +364,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                            shard_size, :],
                         layer.w13_scale[expert_id][shard_id])
                     layer.w13_weight[expert_id][
-                        start:start + shard_size, :] = per_tensor_quantize(
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                             dq_weight, max_w13_scales[expert_id])
                     start += shard_size

@@ -596,23 +429,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
                 "cause accuracy issues. Please make sure kv-cache scaling "
                 "factor is available in the fp8 checkpoint.")
             del layer.kv_scale
-
-
-def per_tensor_quantize(tensor: torch.Tensor,
-                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
-    return qweight.to(torch.float8_e4m3fn)
-
-
-def per_tensor_dequantize(
-        tensor: torch.Tensor, inv_scale: Union[float,
-                                               torch.Tensor]) -> torch.Tensor:
-    fake_qweight = tensor.to(torch.float16)
-    dq_weight = fake_qweight * inv_scale
-    return dq_weight
-
-
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
@@ -11,20 +11,16 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K,
+    GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES,
+    GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM,
+    GPTQ_MARLIN_TILE)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.platforms import current_platform

 logger = init_logger(__name__)

-GPTQ_MARLIN_TILE = 16
-GPTQ_MARLIN_MIN_THREAD_N = 64
-GPTQ_MARLIN_MIN_THREAD_K = 128
-GPTQ_MARLIN_MAX_PARALLEL = 16
-
-GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
-GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
-GPTQ_MARLIN_SUPPORTED_SYM = [True]
-
-
 # Permutations for Marlin scale shuffling
 def get_scale_perms(num_bits: int):
@@ -1,9 +1,11 @@
 """This file is used for /tests and /benchmarks"""
 import random
+from typing import Optional

 import numpy
 import torch

+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.format_24 import (
     mask_creator, sparse_semi_structured_from_dense_cutlass)
 from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
@@ -13,8 +15,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
 from vllm.platforms import current_platform
+from vllm.utils import print_warning_once

-MARLIN_TILE = 16
+GPTQ_MARLIN_TILE = 16
+GPTQ_MARLIN_MIN_THREAD_N = 64
+GPTQ_MARLIN_MIN_THREAD_K = 128
+GPTQ_MARLIN_MAX_PARALLEL = 16
+
+GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
+GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+GPTQ_MARLIN_SUPPORTED_SYM = [True]


 def is_marlin_supported():
@@ -22,7 +32,92 @@ def is_marlin_supported():
     return capability[0] >= 8


-def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
+def apply_fp8_marlin_linear(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        workspace: torch.Tensor,
+        size_n: int,
+        size_k: int,
+        bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n, )
+
+    output = ops.fp8_marlin_gemm(
+        a=reshaped_x,
+        b_q_weight=weight,
+        b_scales=weight_scale,
+        workspace=workspace,
+        num_bits=8,
+        size_m=reshaped_x.shape[0],
+        size_n=size_n,
+        size_k=size_k,
+    )
+
+    if bias is not None:
+        output.add_(bias)  # In-place add
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
+    print_warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads.")
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+
+    device = layer.weight.device
+
+    # WEIGHTS
+    # Repack weights to gptq format (packed int32 elements)
+    packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
+
+    # Repack weights to marlin format
+    marlin_qweight = ops.gptq_marlin_repack(
+        b_q_weight=packed_gptq_qweight,
+        perm=torch.empty(0, dtype=torch.int, device=device),
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=8,
+    )
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Currently Marlin doesn't support per-tensor scales, so we
+    # expand it to channelwise
+    scales = layer.weight_scale.repeat(1, part_size_n).to(
+        layer.orig_dtype).to(device)
+    # Permute scales
+    num_bits = 8
+    marlin_scales = marlin_permute_scales(
+        s=scales,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        group_size=-1,
+        scale_perm=marlin_scale_perm[num_bits],
+        scale_perm_single=marlin_scale_perm_single[num_bits])
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+    # Allocate marlin workspace
+    max_workspace_size = (part_size_n //
+                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+    workspace = torch.zeros(max_workspace_size,
+                            dtype=torch.int,
+                            device=device,
+                            requires_grad=False)
+
+    layer.workspace = workspace
+
+
+def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
     assert q_w.shape == (size_k, size_n)
     assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
     assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
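One step worth seeing in isolation is the per-tensor-to-channelwise scale expansion done before marlin_permute_scales. A tiny sketch with assumed shapes (the real tensor layout may differ):

    import torch

    size_n = 8                             # output channels of this shard
    weight_scale = torch.tensor([[0.02]])  # single per-tensor scale, shape (1, 1)

    # Marlin has no per-tensor mode, so repeat the scale across all N channels.
    channelwise = weight_scale.repeat(1, size_n)
    assert channelwise.shape == (1, size_n)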
vllm/model_executor/layers/quantization/utils/w8a8_utils.py (new file, 163 lines)
@@ -0,0 +1,163 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+
+
+def cutlass_fp8_supported() -> bool:
+    capability = current_platform.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
+
+
+def per_tensor_dequantize(
+        tensor: torch.Tensor, inv_scale: Union[float,
+                                               torch.Tensor]) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+def create_per_tensor_scale_param(
+    output_partition_sizes: List[int],
+    **extra_weight_attrs,
+) -> Parameter:
+    scale = Parameter(torch.empty(len(output_partition_sizes),
+                                  dtype=torch.float32),
+                      requires_grad=False)
+    scale[:] = torch.finfo(torch.float32).min
+    set_weight_attrs(scale, {
+        "needs_scalar_to_array": True,
+        **extra_weight_attrs
+    })
+    return scale
+
+
+def create_per_channel_scale_param(output_partition_sizes: List[int],
+                                   **extra_weight_attrs) -> Parameter:
+    scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
+                                  dtype=torch.float32),
+                      requires_grad=False)
+    scale[:] = torch.finfo(torch.float32).min
+    set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
+    return scale
+
+
+def convert_to_channelwise(
+        weight_scale: torch.Tensor,
+        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Create channelwise buffer
+    weight_scale_channel = torch.empty((sum(logical_widths), 1),
+                                       dtype=torch.float32,
+                                       device=weight_scale.device)
+
+    # Expand each scale to match the size of each logical matrix.
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_scale_channel[start:end, :] = weight_scale[idx]
+        start = end
+
+    return weight_scale_channel
+
+
+def requantize_with_max_scale(
+        weight: torch.Tensor, weight_scale: torch.Tensor,
+        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requanitzation.
+    max_w_scale = weight_scale.max()
+
+    # QKV / MLP is fused in the on disk checkpoint if any of the
+    # weight scales are still set to the default since we initialize
+    # N weight scales for N shards but we only load 1 weight scale
+    # from disk in this case. Skip requantization in this case (since)
+    # we already are quantized with the single scale.
+    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
+    unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
+        torch.float8_e4m3fn).min)
+
+    # If unfused checkpoint, need requanize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                              weight_scale[idx])
+            weight[start:end, :], _ = ops.scaled_fp8_quant(
+                weight_dq, max_w_scale)
+            start = end
+
+    return max_w_scale, weight
+
+
+def apply_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_fp8_supported: bool = True,
+) -> torch.Tensor:
+    # ops.scaled_fp8_quant supports both dynamic and static quant.
+    # If dynamic, layer.input_scale is None and x_scale computed from x.
+    # If static, layer.input_scale is scalar and x_scale is input_scale.
+
+    if bias is None and cutlass_fp8_supported:
+        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+
+        # Fused GEMM_DQ
+        output = ops.cutlass_scaled_mm(qinput,
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale)
+
+    else:
+        qinput, x_scale = ops.scaled_fp8_quant(input,
+                                               input_scale,
+                                               batch_dim_padding=17)
+
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
+        output, _ = torch._scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)
+
+    return torch.narrow(output, 0, 0, input.shape[0])
+
+
+def apply_int8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+):
+    if bias is not None:
+        raise NotImplementedError("W8A8 with int8 does not yet support bias.")
+
+    # ops.scaled_int8_quant supports both dynamic and static quant.
+    # * dynamic, layer.input_scale is None and x_scale computed from x.
+    # * static, layer.input_scale is scalar and x_scale is input_scale.
+    x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
+
+    return ops.cutlass_scaled_mm(x_q,
+                                 weight,
+                                 scale_a=x_scale,
+                                 scale_b=weight_scale,
+                                 out_dtype=input.dtype)
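To close the loop, the whole FP8 W8A8 GEMM reduces to quantize-multiply-rescale. An end-to-end emulation in plain PyTorch with made-up shapes; it stands in for cutlass_scaled_mm / torch._scaled_mm, which the real path uses, and assumes a PyTorch build with the float8_e4m3fn dtype.

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

    def fp8_quant(t, scale):
        return (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    x = torch.randn(4, 64)
    w = torch.randn(32, 64) * 0.02                 # [out, in]

    w_scale = w.abs().max() / FP8_MAX
    x_scale = x.abs().max() / FP8_MAX
    w_q = fp8_quant(w, w_scale)
    x_q = fp8_quant(x, x_scale)

    # GEMM in higher precision, then apply both scales (the fused kernel does
    # this in its epilogue and adds bias there as well).
    y = (x_q.to(torch.float32) @ w_q.to(torch.float32).t()) * (x_scale * w_scale)
    print((y - x @ w.t()).abs().max())             # quantization error only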