[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)

Signed-off-by: zhoukz <me@zhoukz.com>
This commit is contained in:
Zhou Jiahao 2025-09-29 18:59:04 +08:00 committed by GitHub
parent edbaadd91f
commit 8616300ae2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiDashengLM model compatible with HuggingFace weights."""
import collections
import collections.abc
from collections.abc import Iterable, Mapping, Sequence
@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
import numpy as np
import torch
import torch.nn as nn
import torchaudio.transforms as audio_transforms
import torchaudio.functional as F
from torch.nn.functional import scaled_dot_product_attention
from transformers import BatchFeature
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = ColumnParallelLinear(input_size=in_features,
output_size=hidden_features,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc1 = ColumnParallelLinear(
input_size=in_features,
output_size=hidden_features,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.act = get_act_fn("gelu")
self.fc2 = RowParallelLinear(input_size=hidden_features,
output_size=out_features,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
self.fc2 = RowParallelLinear(
input_size=hidden_features,
output_size=out_features,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x)
@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
causal: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.attn = MultiHeadAttention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
)
self.proj = RowParallelLinear(
input_size=dim,
output_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.causal = causal
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv_out, _ = self.qkv(x)
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
qkv, _ = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn_out = self.attn(q, k, v)
C_local = attn_out.numel() // (B * N) # C_local for parallel
attn_out = attn_out.view(B, N, C_local)
x, _ = self.proj(attn_out)
x = scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask[:, None, None, :] if mask is not None else None,
)
x = x.transpose(1, 2).reshape(B, N, C)
x, _ = self.proj(x)
return x
@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
return x
class DashengFrontend(nn.Module):
def __init__(self, config: DashengConfig):
super().__init__()
self.config = config
spectrogram_window = torch.hann_window(self.config.win_length)
self.register_buffer(
"spectrogram_window",
spectrogram_window,
persistent=False,
)
self.spectrogram_window: torch.Tensor
melscale_fbanks = F.melscale_fbanks(
n_freqs=self.config.n_fft // 2 + 1,
f_min=self.config.f_min,
f_max=self.config.f_max,
n_mels=self.config.n_mels,
sample_rate=self.config.sample_rate,
)
self.register_buffer("melscale_fbanks",
melscale_fbanks,
persistent=False)
self.melscale_fbanks: torch.Tensor
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
spectrogram = F.spectrogram(
waveform=waveform.to(torch.float32),
pad=0,
window=self.spectrogram_window,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
win_length=self.config.win_length,
power=2,
normalized=False,
center=self.config.center,
)
mel_spectrogram = (
spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
# x has shape [batch, freq, time].
# F.amplitude_to_DB accepts inputs shaped as:
# - [freq, time]
# - [channel, freq, time]
# - [..., channel, freq, time]
# Here we insert a channel dimension of size 1 before calling it,
# then remove that extra dimension afterward.
log_mel_spectrogram = F.amplitude_to_DB(
mel_spectrogram.unsqueeze(1),
multiplier=10,
amin=1e-10,
db_multiplier=0,
top_db=120,
).squeeze(1)
return log_mel_spectrogram.to(waveform.dtype)
class DashengAudioTransformer(nn.Module):
def __init__(
@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
self.target_length = config.target_length
self.hop_length = config.hop_length
self._init_front_end(config)
self.front_end = DashengFrontend(config)
self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
qkv_bias=config.qkv_bias,
init_values=config.init_values,
quant_config=quant_config,
prefix=f"{prefix}.block{i}",
prefix=f"{prefix}.blocks.{i}",
) for i in range(config.depth))
self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
def _init_front_end(self, config):
with set_default_torch_dtype(torch.float32):
self.front_end = nn.Sequential(
audio_transforms.MelSpectrogram(
f_min=config.f_min,
f_max=config.f_max,
center=config.center,
win_length=config.win_length,
hop_length=config.hop_length,
sample_rate=config.sample_rate,
n_fft=config.n_fft,
n_mels=config.n_mels,
),
audio_transforms.AmplitudeToDB(top_db=120),
)
mel_spectrogram = self.front_end[0]
fb = mel_spectrogram.mel_scale.fb
win = mel_spectrogram.spectrogram.window
mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
torch.float32)
mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
torch.float32)
def forward_features(
self,
x: torch.Tensor,
@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.net.0",
return_bias=False,
), get_act_fn("gelu"),
),
get_act_fn("gelu"),
RowParallelLinear(
input_size=out_dim,
output_size=out_dim,
quant_config=quant_config,
prefix=f"{prefix}.net.2",
return_bias=False,
))
),
)
def forward(self, x, mask=None):
batch_size, seq_len, dim = x.shape
@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
# + Padding
min_audio_len = self.info.get_min_audio_len()
processed_audios = [
np.pad(audio, (0, min_audio_len - audio.shape[-1]),
mode='constant',
constant_values=0) if isinstance(audio, np.ndarray)
np.pad(
audio,
(0, min_audio_len - audio.shape[-1]),
mode="constant",
constant_values=0,
) if isinstance(audio, np.ndarray)
and audio.shape[-1] < min_audio_len else audio for audio in audios
]
@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
if audio_length is None:
audio_output_lengths = []
else:
audio_length_np = audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length
audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame
@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
raise ValueError(
f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:])
@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
audio_length_np = audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length
audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame
@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_output_lengths = torch.tensor(audio_output_lengths).to(
audio_embeddings.device)
audio_feature_mask = (torch.arange(
audio_feature_mask = torch.arange(
max_audio_tokens,
device=audio_embeddings.device).unsqueeze(0).expand(
batch_size, max_audio_tokens)
< audio_output_lengths.unsqueeze(1))
batch_size,
max_audio_tokens) < audio_output_lengths.unsqueeze(1)
masked_audio_features = audio_embeddings[audio_feature_mask].view(
-1, embed_dim)
@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
)
input_ids = None
return self.decoder.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return self.decoder.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
def compute_logits(
self,