[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)
Signed-off-by: zhoukz <me@zhoukz.com>
parent edbaadd91f
commit 8616300ae2
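The fix targets quantized checkpoints of MiDashengLM. A minimal loading sketch for context; the model id below is a placeholder and is not a checkpoint named by this commit:

from vllm import LLM

# Hypothetical quantized checkpoint id, shown only for illustration; any
# MiDashengLM checkpoint that carries a supported quantization config
# (e.g. AWQ or GPTQ) exercises the packed_modules_mapping and prefix
# handling changed below.
llm = LLM(model="your-org/MiDashengLM-7B-AWQ", trust_remote_code=True)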
@@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""

import collections
import collections.abc
from collections.abc import Iterable, Mapping, Sequence
@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
import numpy as np
import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
import torchaudio.functional as F
from torch.nn.functional import scaled_dot_product_attention
from transformers import BatchFeature

from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems)
@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = ColumnParallelLinear(input_size=in_features,
                                        output_size=hidden_features,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc1 = ColumnParallelLinear(
            input_size=in_features,
            output_size=hidden_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn("gelu")
        self.fc2 = RowParallelLinear(input_size=hidden_features,
                                     output_size=out_features,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")
        self.fc2 = RowParallelLinear(
            input_size=hidden_features,
            output_size=out_features,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc1(x)
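For context on the `x, _ = self.fc1(x)` unpacking used throughout this file: vLLM's ColumnParallelLinear and RowParallelLinear return an (output, output_bias) tuple by default. A single-device sketch of the same MLP shape flow, using plain nn.Linear as a stand-in for the parallel layers:

from typing import Optional

import torch
import torch.nn as nn


class MlpSketch(nn.Module):
    """Single-GPU stand-in for DashengMlp: fc1 -> GELU -> fc2."""

    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


print(MlpSketch(768)(torch.randn(2, 10, 768)).shape)  # torch.Size([2, 10, 768])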
@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        causal: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.attn = MultiHeadAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads,
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )
        self.causal = causal

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape

        qkv_out, _ = self.qkv(x)
        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        qkv, _ = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn_out = self.attn(q, k, v)
        C_local = attn_out.numel() // (B * N)  # C_local for parallel
        attn_out = attn_out.view(B, N, C_local)

        x, _ = self.proj(attn_out)
        x = scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask[:, None, None, :] if mask is not None else None,
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x, _ = self.proj(x)
        return x

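The rewritten forward follows the standard multi-head pattern around torch's scaled_dot_product_attention. A self-contained shape walk-through with made-up dimensions; the boolean padding mask broadcasts to (batch, heads, query, key) exactly as in the code above:

import torch
from torch.nn.functional import scaled_dot_product_attention

B, N, num_heads, head_dim = 2, 7, 8, 64
C = num_heads * head_dim

qkv = torch.randn(B, N, 3 * C)                         # fused QKV projection output
qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)                       # (3, B, heads, N, head_dim)
q, k, v = qkv.unbind(0)

mask = torch.ones(B, N, dtype=torch.bool)              # True marks valid tokens
out = scaled_dot_product_attention(
    q, k, v,
    attn_mask=mask[:, None, None, :],                  # (B, 1, 1, N), broadcast over heads/queries
)
out = out.transpose(1, 2).reshape(B, N, C)             # back to (B, N, C)
print(out.shape)                                       # torch.Size([2, 7, 512])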
@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
        return x


class DashengFrontend(nn.Module):

    def __init__(self, config: DashengConfig):
        super().__init__()
        self.config = config

        spectrogram_window = torch.hann_window(self.config.win_length)
        self.register_buffer(
            "spectrogram_window",
            spectrogram_window,
            persistent=False,
        )
        self.spectrogram_window: torch.Tensor

        melscale_fbanks = F.melscale_fbanks(
            n_freqs=self.config.n_fft // 2 + 1,
            f_min=self.config.f_min,
            f_max=self.config.f_max,
            n_mels=self.config.n_mels,
            sample_rate=self.config.sample_rate,
        )
        self.register_buffer("melscale_fbanks",
                             melscale_fbanks,
                             persistent=False)
        self.melscale_fbanks: torch.Tensor

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        spectrogram = F.spectrogram(
            waveform=waveform.to(torch.float32),
            pad=0,
            window=self.spectrogram_window,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.win_length,
            power=2,
            normalized=False,
            center=self.config.center,
        )
        mel_spectrogram = (
            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
        # mel_spectrogram has shape [batch, freq, time].
        # F.amplitude_to_DB accepts inputs shaped as:
        #   - [freq, time]
        #   - [channel, freq, time]
        #   - [..., channel, freq, time]
        # Here we insert a channel dimension of size 1 before calling it,
        # then remove that extra dimension afterward.
        log_mel_spectrogram = F.amplitude_to_DB(
            mel_spectrogram.unsqueeze(1),
            multiplier=10,
            amin=1e-10,
            db_multiplier=0,
            top_db=120,
        ).squeeze(1)
        return log_mel_spectrogram.to(waveform.dtype)

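The removed _init_front_end (further down in this diff) built torchaudio.transforms modules and round-tripped their buffers through bfloat16; the new DashengFrontend keeps float32 buffers and calls torchaudio.functional directly. A standalone sketch of the correspondence between the two pipelines, with assumed config values (sample rate, FFT size, hop length, and mel count below are illustrative, not the model's real settings):

import torch
import torchaudio.functional as F
import torchaudio.transforms as T

sr, n_fft, win, hop, n_mels = 16000, 512, 512, 160, 64  # assumed values
wave = torch.randn(1, sr)                                # one second of noise

# Functional pipeline, mirroring DashengFrontend.forward.
window = torch.hann_window(win)
fbanks = F.melscale_fbanks(n_freqs=n_fft // 2 + 1, f_min=0.0, f_max=sr / 2,
                           n_mels=n_mels, sample_rate=sr)
spec = F.spectrogram(waveform=wave, pad=0, window=window, n_fft=n_fft,
                     hop_length=hop, win_length=win, power=2,
                     normalized=False, center=True)
mel = (spec.mT @ fbanks).mT
log_mel = F.amplitude_to_DB(mel.unsqueeze(1), multiplier=10, amin=1e-10,
                            db_multiplier=0, top_db=120).squeeze(1)

# Module pipeline equivalent to the removed nn.Sequential frontend.
ref = T.AmplitudeToDB(stype="power", top_db=120)(
    T.MelSpectrogram(sample_rate=sr, n_fft=n_fft, win_length=win,
                     hop_length=hop, f_min=0.0, f_max=sr / 2, n_mels=n_mels,
                     center=True)(wave))
print((log_mel - ref).abs().max())  # expected to be ~0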
class DashengAudioTransformer(nn.Module):

    def __init__(
@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
        self.target_length = config.target_length
        self.hop_length = config.hop_length

        self._init_front_end(config)
        self.front_end = DashengFrontend(config)

        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)

@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
                qkv_bias=config.qkv_bias,
                init_values=config.init_values,
                quant_config=quant_config,
                prefix=f"{prefix}.block{i}",
                prefix=f"{prefix}.blocks.{i}",
            ) for i in range(config.depth))
        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

    def _init_front_end(self, config):
        with set_default_torch_dtype(torch.float32):
            self.front_end = nn.Sequential(
                audio_transforms.MelSpectrogram(
                    f_min=config.f_min,
                    f_max=config.f_max,
                    center=config.center,
                    win_length=config.win_length,
                    hop_length=config.hop_length,
                    sample_rate=config.sample_rate,
                    n_fft=config.n_fft,
                    n_mels=config.n_mels,
                ),
                audio_transforms.AmplitudeToDB(top_db=120),
            )

        mel_spectrogram = self.front_end[0]
        fb = mel_spectrogram.mel_scale.fb
        win = mel_spectrogram.spectrogram.window
        mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
            torch.float32)
        mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
            torch.float32)

    def forward_features(
        self,
        x: torch.Tensor,
@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
                quant_config=quant_config,
                prefix=f"{prefix}.net.0",
                return_bias=False,
            ), get_act_fn("gelu"),
            ),
            get_act_fn("gelu"),
            RowParallelLinear(
                input_size=out_dim,
                output_size=out_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.net.2",
                return_bias=False,
            ))
            ),
        )

    def forward(self, x, mask=None):
        batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
        # + Padding
        min_audio_len = self.info.get_min_audio_len()
        processed_audios = [
            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
                   mode='constant',
                   constant_values=0) if isinstance(audio, np.ndarray)
            np.pad(
                audio,
                (0, min_audio_len - audio.shape[-1]),
                mode="constant",
                constant_values=0,
            ) if isinstance(audio, np.ndarray)
            and audio.shape[-1] < min_audio_len else audio for audio in audios
        ]

@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
        if audio_length is None:
            audio_output_lengths = []
        else:
            audio_length_np = audio_length.cpu().numpy() if isinstance(
                audio_length, torch.Tensor) else audio_length
            audio_length_np = (audio_length.cpu().numpy() if isinstance(
                audio_length, torch.Tensor) else audio_length)
            audio_output_lengths = [
                max(1, calculate_mel_frames_dasheng(
                    int(length)))  # at least one frame
@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
    dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
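packed_modules_mapping tells vLLM's quantization-aware weight loading that the checkpoint's q_proj/k_proj/v_proj and gate_proj/up_proj shards land in a single fused module, so per-layer quantization parameters attach to qkv_proj and gate_up_proj. A rough illustration of the renaming convention, not vLLM's actual loader; the weight name below is only an example:

# (fused param name, checkpoint shard name, shard id)
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(checkpoint_name: str):
    # e.g. "...self_attn.q_proj.weight" -> ("...self_attn.qkv_proj.weight", "q")
    for fused, shard, shard_id in stacked_params_mapping:
        if shard in checkpoint_name:
            return checkpoint_name.replace(shard, fused), shard_id
    return checkpoint_name, None


print(remap("decoder.model.layers.0.self_attn.q_proj.weight"))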
@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
            raise ValueError(
                f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return mm_input.reshape(-1, *mm_input.shape[2:])

@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
            audio_input["input_values"].dtype)
        batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

        audio_length_np = audio_length.cpu().numpy() if isinstance(
            audio_length, torch.Tensor) else audio_length
        audio_length_np = (audio_length.cpu().numpy() if isinstance(
            audio_length, torch.Tensor) else audio_length)
        audio_output_lengths = [
            max(1, calculate_mel_frames_dasheng(
                int(length)))  # at least one frame
@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
        audio_output_lengths = torch.tensor(audio_output_lengths).to(
            audio_embeddings.device)

        audio_feature_mask = (torch.arange(
        audio_feature_mask = torch.arange(
            max_audio_tokens,
            device=audio_embeddings.device).unsqueeze(0).expand(
                batch_size, max_audio_tokens)
            < audio_output_lengths.unsqueeze(1))
                batch_size,
                max_audio_tokens) < audio_output_lengths.unsqueeze(1)

        masked_audio_features = audio_embeddings[audio_feature_mask].view(
            -1, embed_dim)
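The arange/expand comparison above builds a boolean mask that keeps only the non-padded audio embeddings of each sample. A small concrete example:

import torch

audio_embeddings = torch.randn(2, 5, 4)       # (batch, max_audio_tokens, embed_dim)
audio_output_lengths = torch.tensor([3, 5])   # valid tokens per sample

audio_feature_mask = torch.arange(5).unsqueeze(0).expand(
    2, 5) < audio_output_lengths.unsqueeze(1)
# row 0 -> [True, True, True, False, False]; row 1 -> all True

masked = audio_embeddings[audio_feature_mask].view(-1, 4)
print(masked.shape)  # torch.Size([8, 4]) == (3 + 5, embed_dim)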
@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
            )
            input_ids = None

        return self.decoder.model(input_ids,
                                  positions,
                                  intermediate_tensors,
                                  inputs_embeds=inputs_embeds)
        return self.decoder.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,