[RFC] [Mistral] FP8 format (#10130)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Patrick von Platen 2025-02-08 22:12:53 +01:00 committed by GitHub
parent 870c37481e
commit d366ccc4e3
4 changed files with 55 additions and 12 deletions


@@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
@@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         modules = name.split(".")

         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)

-        for item in modules:
-            if item in mapping and mapping[item] not in name:
-                name = name.replace(item, mapping[item])
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = (f"{item}.{next_item}"
+                             if next_item is not None else None)
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])

         return name, loaded_weight
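
A minimal sketch of the new two-token lookup, using only the mapping entries visible in this diff (the sample checkpoint keys are hypothetical). The added modules[-1] == "weight" guard keeps the rotary permutation from being applied to the new scale tensors, e.g. names ending in qscale_weight:

# Reduced to the entries shown above; the real mistral_mapping has more.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "wq": "q_proj",
}

def remap(name: str) -> str:
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None
        combined_item = (f"{item}.{next_item}"
                         if next_item is not None else None)
        # Two-token keys take priority, so the KV-cache scale maps to
        # "kv_scale" rather than "kv_fake_quantizer.input_scale".
        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name

print(remap("layers.0.attention.kv_fake_quantizer.qscale_act"))
# model.layers.0.self_attn.kv_scale
print(remap("layers.0.attention.wq.qscale_weight"))
# model.layers.0.self_attn.q_proj.weight_scale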


@@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
         tokenizer_mode=ctx.model_config.tokenizer_mode)
     mm_encoder = tokenizer.instruct.mm_encoder

-    max_image_size = mm_encoder.mm_config.max_image_size
-    image_patch_size = mm_encoder.mm_config.image_patch_size
+    image_config = mm_encoder.mm_config if hasattr(
+        mm_encoder, "mm_config") else mm_encoder.image_config
+
+    max_image_size = image_config.max_image_size
+    image_patch_size = image_config.image_patch_size

     return ((max_image_size // image_patch_size)**2)
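
For intuition, the returned cap is simply the patch count of the largest allowed image; with Pixtral-style defaults (max_image_size=1024 and image_patch_size=16, assumed here rather than stated in this diff):

max_image_size = 1024     # assumption for illustration
image_patch_size = 16     # assumption for illustration
max_image_tokens = (max_image_size // image_patch_size) ** 2
print(max_image_tokens)   # 4096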


@@ -4,7 +4,7 @@ import enum
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, Literal, Optional, Type, Union

 import huggingface_hub
 from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
@@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
             for key, value in elem.items():
                 key = config_mapping.get(key, key)
                 config_dict[key] = recurse_elems(value)
-            return PretrainedConfig(**config_dict)
+
+            return config_dict
         else:
             return elem
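
The behavioral change here: recurse_elems previously wrapped the renamed dict in a PretrainedConfig at every level; it now returns a plain dict so the caller can decide which sub-dicts become nested configs. A self-contained sketch, assuming config_mapping renames keys such as "dim" to "hidden_size" (an assumed subset of the real table):

config_mapping = {"dim": "hidden_size", "norm_eps": "rms_norm_eps"}

def recurse_elems(elem):
    if isinstance(elem, dict):
        # Rename keys recursively but keep the result a plain dict.
        return {config_mapping.get(k, k): recurse_elems(v)
                for k, v in elem.items()}
    return elem

print(recurse_elems({"dim": 4096, "norm_eps": 1e-5}))
# {'hidden_size': 4096, 'rms_norm_eps': 1e-05}
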
@@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)

+    if config_dict.get("quantization") is not None:
+        quantization = config_dict.get("quantization", {})
+        if quantization.get("qformat_weight") == "fp8_e4m3":
+            # This maps to the FP8 static per-tensor quantization scheme
+            quantization_config = {
+                "quant_method": "fp8",
+                "activation_scheme": "static"
+            }
+        else:
+            raise ValueError(
+                f"Found unknown quantization='{quantization}' in config")
+
+        config_dict["quantization_config"] = quantization_config
+
+    config_type: Literal["text",
+                         "multimodal"] = "multimodal" if config_dict.get(
+                             "vision_encoder") is not None else "text"
+
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
         config_dict["architectures"] = ["MistralForCausalLM"]

-    if config_dict.get("vision_encoder") is not None:
+    if config_type == "multimodal":
         multimodal_config = config_dict.pop("vision_encoder")

         config_dict = {
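
In params.json terms, this block performs the following translation; the input layout is inferred from the code above, not shown elsewhere in the diff:

# What an FP8 Mistral params.json is expected to carry (inferred):
params = {"quantization": {"qformat_weight": "fp8_e4m3"}}

quantization = params.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
    # FP8 weights with static per-tensor activation scales.
    params["quantization_config"] = {
        "quant_method": "fp8",
        "activation_scheme": "static",
    }
else:
    raise ValueError(f"Found unknown quantization='{quantization}' in config")

print(params["quantization_config"])
# {'quant_method': 'fp8', 'activation_scheme': 'static'}
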
@@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],

     config_dict.update(kwargs)

-    config = recurse_elems(config_dict)
-    return config
+    config_dict = recurse_elems(config_dict)
+
+    # transform to HF config format
+    if config_type == "multimodal":
+        config_dict["text_config"] = PretrainedConfig(
+            **config_dict["text_config"])
+        config_dict["vision_config"] = PretrainedConfig(
+            **config_dict["vision_config"])
+
+    return PretrainedConfig(**config_dict)


 def get_hf_image_processor_config(
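
End to end, a multimodal params.json now becomes a PretrainedConfig whose text_config and vision_config fields are themselves PretrainedConfig objects, matching the HF nested-config convention. A rough sketch; the keys and values are illustrative only, not taken from this diff:

from transformers import PretrainedConfig

config_dict = {
    "text_config": {"hidden_size": 4096},      # illustrative
    "vision_config": {"hidden_size": 1024},    # illustrative
}

config_dict["text_config"] = PretrainedConfig(**config_dict["text_config"])
config_dict["vision_config"] = PretrainedConfig(**config_dict["vision_config"])
config = PretrainedConfig(**config_dict)

print(config.text_config.hidden_size)   # 4096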


@@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:

 def find_tokenizer_file(files: List[str]):
-    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+    file_pattern = re.compile(
+        r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

     matched_files = [file for file in files if file_pattern.match(file)]

     if len(matched_files) > 1:
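
The widened pattern additionally matches multimodal tokenizer files; a quick sanity check (the filenames are examples):

import re

file_pattern = re.compile(
    r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")

files = ["tokenizer.model.v3", "tekken.json",
         "tokenizer.mm.model.v1", "params.json"]
print([f for f in files if file_pattern.match(f)])
# ['tokenizer.model.v3', 'tekken.json', 'tokenizer.mm.model.v1']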