Expand MLA to support most types of quantization (#13181)

This commit is contained in:
Michael Goin 2025-02-14 01:19:22 -05:00 committed by GitHub
parent f2b20fe491
commit f0b2da72a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 131 deletions

View File

@ -26,7 +26,7 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@ -220,16 +220,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype):
def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
@ -239,7 +229,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant is not None:
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
@ -267,41 +257,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale
def get_layer_weight(layer):
if hasattr(layer, "weight"):
return layer.weight
elif hasattr(layer, "qweight"):
return layer.qweight
else:
raise AttributeError(
f"Layer '{layer}' has neither weight nor qweight")
def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
return scaled_dequantize(weight, scales,
weight_scale_group_shape)
else:
return layer.weight
if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")
weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
weight_dtype = get_layer_weight(self.kv_b_proj).dtype
assert get_layer_weight(self.o_proj).dtype == weight_dtype
assert get_layer_weight(self.q_proj).dtype == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (

View File

@ -991,37 +991,7 @@ class ModelConfig:
@property
def use_mla(self) -> bool:
if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE:
return False
if self.quantization is not None and self.quantization not in [\
"fp8", "compressed-tensors"]:
logger.warning(
"MLA is not supported with %s quantization. "
"Disabling MLA.", self.quantization)
return False
# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if self.quantization == "compressed-tensors":
quant_config = self._parse_quant_hf_config()
for group_name, cfg in quant_config.get("config_groups", {
"": {}
}).items():
act_cfg = cfg.get("input_activations", {})
act_type = None if act_cfg is None else act_cfg.get("type", "")
w_cfg = cfg.get("weights", {})
w_type = None if w_cfg is None else w_cfg.get("type", "")
if act_type != "fp8" or w_type != "fp8":
logger.warning(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.\n "
"Full config: %s", group_name, act_type, w_type,
quant_config)
return False
return True
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property
def supported_runner_types(self) -> Set[RunnerType]:

View File

@ -153,6 +153,30 @@ def _initialize_model(
return model_class(**kwargs)
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
target_device: torch.device) -> None:
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Currently only used by MLA.
# NOTE: This intentionally happens after other modules so we can easily
# decompress the weights for MLA.
for _, module in model.named_modules():
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
@ -376,7 +400,6 @@ class DefaultModelLoader(BaseModelLoader):
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
@ -394,23 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()
@ -429,29 +437,15 @@ class DummyModelLoader(BaseModelLoader):
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()
@ -632,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader):
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
@ -640,18 +635,10 @@ class ShardedStateLoader(BaseModelLoader):
model_config.revision)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(
model_config.dtype)
_process_weights_after_loading(model, model_config,
target_device)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
@ -1401,16 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
self._get_weights_iterator(model_weights,
model_config.revision))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
_process_weights_after_loading(model, model_config, target_device)
return model.eval()