mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 07:11:21 +08:00
[RFC] [Mistral] FP8 format (#10130)
Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
870c37481e
commit
d366ccc4e3
@ -467,6 +467,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
mistral_mapping = {
|
mistral_mapping = {
|
||||||
"layers": "model.layers",
|
"layers": "model.layers",
|
||||||
"attention": "self_attn",
|
"attention": "self_attn",
|
||||||
|
"qscale_act": "input_scale",
|
||||||
|
"qscale_weight": "weight_scale",
|
||||||
|
"kv_fake_quantizer.qscale_act": "kv_scale",
|
||||||
"wq": "q_proj",
|
"wq": "q_proj",
|
||||||
"wk": "k_proj",
|
"wk": "k_proj",
|
||||||
"wv": "v_proj",
|
"wv": "v_proj",
|
||||||
@ -590,15 +593,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
modules = name.split(".")
|
modules = name.split(".")
|
||||||
|
|
||||||
# rotary embeds should be sliced
|
# rotary embeds should be sliced
|
||||||
if "wk" in modules:
|
if "wk" in modules and modules[-1] == "weight":
|
||||||
loaded_weight = permute(loaded_weight,
|
loaded_weight = permute(loaded_weight,
|
||||||
self.config.num_key_value_heads)
|
self.config.num_key_value_heads)
|
||||||
elif "wq" in modules:
|
elif "wq" in modules and modules[-1] == "weight":
|
||||||
loaded_weight = permute(loaded_weight,
|
loaded_weight = permute(loaded_weight,
|
||||||
self.config.num_attention_heads)
|
self.config.num_attention_heads)
|
||||||
|
|
||||||
for item in modules:
|
num_modules = len(modules)
|
||||||
if item in mapping and mapping[item] not in name:
|
for i in range(num_modules):
|
||||||
|
item = modules[i]
|
||||||
|
next_item = modules[i + 1] if i < num_modules - 1 else None
|
||||||
|
|
||||||
|
combined_item = (f"{item}.{next_item}"
|
||||||
|
if next_item is not None else None)
|
||||||
|
|
||||||
|
if combined_item in mapping:
|
||||||
|
name = name.replace(combined_item, mapping[combined_item])
|
||||||
|
elif item in mapping and mapping[item] not in name:
|
||||||
name = name.replace(item, mapping[item])
|
name = name.replace(item, mapping[item])
|
||||||
|
|
||||||
return name, loaded_weight
|
return name, loaded_weight
|
||||||
|
|||||||
@ -54,8 +54,11 @@ def get_max_pixtral_image_tokens(ctx: InputContext):
|
|||||||
tokenizer_mode=ctx.model_config.tokenizer_mode)
|
tokenizer_mode=ctx.model_config.tokenizer_mode)
|
||||||
mm_encoder = tokenizer.instruct.mm_encoder
|
mm_encoder = tokenizer.instruct.mm_encoder
|
||||||
|
|
||||||
max_image_size = mm_encoder.mm_config.max_image_size
|
image_config = mm_encoder.mm_config if hasattr(
|
||||||
image_patch_size = mm_encoder.mm_config.image_patch_size
|
mm_encoder, "mm_config") else mm_encoder.image_config
|
||||||
|
|
||||||
|
max_image_size = image_config.max_image_size
|
||||||
|
image_patch_size = image_config.image_patch_size
|
||||||
|
|
||||||
return ((max_image_size // image_patch_size)**2)
|
return ((max_image_size // image_patch_size)**2)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import enum
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Type, Union
|
from typing import Any, Dict, Literal, Optional, Type, Union
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
|
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
|
||||||
@ -554,7 +554,8 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
|
|||||||
for key, value in elem.items():
|
for key, value in elem.items():
|
||||||
key = config_mapping.get(key, key)
|
key = config_mapping.get(key, key)
|
||||||
config_dict[key] = recurse_elems(value)
|
config_dict[key] = recurse_elems(value)
|
||||||
return PretrainedConfig(**config_dict)
|
|
||||||
|
return config_dict
|
||||||
else:
|
else:
|
||||||
return elem
|
return elem
|
||||||
|
|
||||||
@ -566,12 +567,30 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
|
|||||||
config_dict["max_position_embeddings"] = config_dict.get(
|
config_dict["max_position_embeddings"] = config_dict.get(
|
||||||
"max_position_embeddings", 128_000)
|
"max_position_embeddings", 128_000)
|
||||||
|
|
||||||
|
if config_dict.get("quantization") is not None:
|
||||||
|
quantization = config_dict.get("quantization", {})
|
||||||
|
if quantization.get("qformat_weight") == "fp8_e4m3":
|
||||||
|
# This maps to the FP8 static per-tensor quantization scheme
|
||||||
|
quantization_config = {
|
||||||
|
"quant_method": "fp8",
|
||||||
|
"activation_scheme": "static"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found unknown quantization='{quantization}' in config")
|
||||||
|
|
||||||
|
config_dict["quantization_config"] = quantization_config
|
||||||
|
|
||||||
|
config_type: Literal["text",
|
||||||
|
"multimodal"] = "multimodal" if config_dict.get(
|
||||||
|
"vision_encoder") is not None else "text"
|
||||||
|
|
||||||
if config_dict.get("moe") is not None:
|
if config_dict.get("moe") is not None:
|
||||||
config_dict["architectures"] = ["MixtralForCausalLM"]
|
config_dict["architectures"] = ["MixtralForCausalLM"]
|
||||||
else:
|
else:
|
||||||
config_dict["architectures"] = ["MistralForCausalLM"]
|
config_dict["architectures"] = ["MistralForCausalLM"]
|
||||||
|
|
||||||
if config_dict.get("vision_encoder") is not None:
|
if config_type == "multimodal":
|
||||||
multimodal_config = config_dict.pop("vision_encoder")
|
multimodal_config = config_dict.pop("vision_encoder")
|
||||||
|
|
||||||
config_dict = {
|
config_dict = {
|
||||||
@ -583,8 +602,16 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
|
|||||||
|
|
||||||
config_dict.update(kwargs)
|
config_dict.update(kwargs)
|
||||||
|
|
||||||
config = recurse_elems(config_dict)
|
config_dict = recurse_elems(config_dict)
|
||||||
return config
|
|
||||||
|
# transform to HF config format
|
||||||
|
if config_type == "multimodal":
|
||||||
|
config_dict["text_config"] = PretrainedConfig(
|
||||||
|
**config_dict["text_config"])
|
||||||
|
config_dict["vision_config"] = PretrainedConfig(
|
||||||
|
**config_dict["vision_config"])
|
||||||
|
|
||||||
|
return PretrainedConfig(**config_dict)
|
||||||
|
|
||||||
|
|
||||||
def get_hf_image_processor_config(
|
def get_hf_image_processor_config(
|
||||||
|
|||||||
@ -88,7 +88,8 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def find_tokenizer_file(files: List[str]):
|
def find_tokenizer_file(files: List[str]):
|
||||||
file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
|
file_pattern = re.compile(
|
||||||
|
r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
|
||||||
|
|
||||||
matched_files = [file for file in files if file_pattern.match(file)]
|
matched_files = [file for file in files if file_pattern.match(file)]
|
||||||
if len(matched_files) > 1:
|
if len(matched_files) > 1:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user