Remove all cases of fmt: on/off (#26253)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-10-05 17:18:14 +01:00 committed by GitHub
parent 4e256cadc2
commit 557b2e961d
5 changed files with 216 additions and 156 deletions
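For context, `# fmt: off` / `# fmt: on` are comment directives honored by formatters such as black and ruff format: the region between them is left exactly as written. Removing the directives, as this commit does, lets the formatter reflow those regions, which is what the per-file diffs below show. A minimal hypothetical illustration (not code from this commit):

# fmt: off
def benchmark(m: int,
              n: int,
              k: int) -> float:
    # Hand-aligned continuation lines: with the directives present,
    # the formatter leaves this region untouched.
    return float(m * n * k)
# fmt: on

def benchmark_reflowed(m: int, n: int, k: int) -> float:
    # With the directives removed, the formatter reflows the signature
    # to its own style (one line here, or one argument per line with a
    # trailing comma once the line gets too long).
    return float(m * n * k)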

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501
import time
@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
)
def benchmark_shape(m: int,
n: int,
k: int,
warmup: int = 100,
repeat: int = 10000,
verbose: bool = False) -> dict:
def benchmark_shape(
m: int,
n: int,
k: int,
warmup: int = 100,
repeat: int = 10000,
verbose: bool = False,
) -> dict:
"""Benchmark all implementations for a specific (m, n, k) shape."""
if verbose:
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
# Create test tensors
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16
torch.cuda.synchronize()
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
# Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
A, block_size[1], column_major_scales=True)
A, block_size[1], column_major_scales=True
)
# === DeepGEMM Implementation ===
def deepgemm_gemm():
fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
(B_deepgemm, B_scale_deepgemm),
C_deepgemm)
fp8_gemm_nt(
(A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
)
return C_deepgemm
# === vLLM Triton Implementation ===
def vllm_triton_gemm():
return w8a8_triton_block_scaled_mm(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
block_size,
output_dtype=torch.bfloat16)
return w8a8_triton_block_scaled_mm(
A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
block_size,
output_dtype=torch.bfloat16,
)
# === vLLM CUTLASS Implementation ===
def vllm_cutlass_gemm():
return ops.cutlass_scaled_mm(A_vllm_cutlass,
B_vllm.T,
scale_a=A_scale_vllm_cutlass,
scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16)
return ops.cutlass_scaled_mm(
A_vllm_cutlass,
B_vllm.T,
scale_a=A_scale_vllm_cutlass,
scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16,
)
# Run correctness check first
if verbose:
@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
print("vLLM Triton vs DeepGEMM difference: "
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
print("vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
print(
"vLLM Triton vs DeepGEMM difference: "
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
)
print(
"vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
)
# Benchmark implementations
implementations = {
"DeepGEMM": deepgemm_gemm,
"vLLM Triton": vllm_triton_gemm,
"vLLM CUTLASS": vllm_cutlass_gemm
"vLLM CUTLASS": vllm_cutlass_gemm,
}
benchmark_results = {
"shape": {
"m": m,
"n": n,
"k": k
},
"implementations": {}
}
benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
for name, func in implementations.items():
# Warmup
@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
"tflops": tflops,
"gb_s": gb_s,
"diff": {
"DeepGEMM":
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
"Reference":
deepgemm_diff if name == "DeepGEMM" else
(vllm_triton_diff
if name == "vLLM Triton" else vllm_cutlass_diff)
}
"DeepGEMM": 0.0
if name == "DeepGEMM"
else calc_diff(func(), C_deepgemm),
"Reference": deepgemm_diff
if name == "DeepGEMM"
else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
},
}
if verbose:
print(
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
)
print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
# Calculate speedups
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
for name, data in benchmark_results["implementations"].items():
if name != "DeepGEMM":
speedup = baseline / data["time_ms"]
benchmark_results["implementations"][name][
"speedup_vs_deepgemm"] = speedup
benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
if verbose:
print(f"DeepGEMM is {1/speedup:.2f}x "
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
print(
f"DeepGEMM is {1 / speedup:.2f}x "
f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
)
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
"time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
"time_ms"]
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
benchmark_results["implementations"]["vLLM CUTLASS"][
"speedup_vs_triton"] = cutlass_vs_triton
benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
cutlass_vs_triton
)
if verbose:
print(
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@@ -183,8 +184,7 @@ def benchmark_shape(m: int,
def format_table_row(values, widths):
"""Format a row with specified column widths."""
return "| " + " | ".join(f"{val:{w}}"
for val, w in zip(values, widths)) + " |"
return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
def print_table(headers, rows, title=None):
@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
for result in all_results:
shape = result["shape"]
impl_data = result["implementations"]["DeepGEMM"]
deepgemm_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
])
deepgemm_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
]
)
print_table(deepgemm_headers,
deepgemm_rows,
title="DeepGEMM Implementation:")
print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
# Print vLLM Triton table
triton_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
]
triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
triton_rows = []
for result in all_results:
shape = result["shape"]
impl_data = result["implementations"]["vLLM Triton"]
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
triton_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
format_speedup(speedup)
])
triton_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(speedup),
]
)
print_table(triton_headers,
triton_rows,
title="vLLM Triton Implementation:")
print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
# Print vLLM CUTLASS table
cutlass_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
"vs Triton"
"m",
"n",
"k",
"Time (μs)",
"TFLOPS",
"GB/s",
"vs DeepGEMM",
"vs Triton",
]
cutlass_rows = []
for result in all_results:
@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
impl_data = result["implementations"]["vLLM CUTLASS"]
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
cutlass_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton)
])
cutlass_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton),
]
)
print_table(cutlass_headers,
cutlass_rows,
title="vLLM CUTLASS Implementation:")
print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
# Calculate and print averages
print("\n===== AVERAGE PERFORMANCE =====")
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
avg_metrics = {
impl: {
"tflops": 0,
"gb_s": 0,
"time_ms": 0
}
for impl in implementations
impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
}
for result in all_results:
@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
avg_rows.append([
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
])
avg_rows.append(
[impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
)
print_table(avg_headers, avg_rows)
@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
avg_speedups = {
"DeepGEMM vs vLLM Triton": 0,
"DeepGEMM vs vLLM CUTLASS": 0,
"vLLM CUTLASS vs vLLM Triton": 0
"vLLM CUTLASS vs vLLM Triton": 0,
}
for result in all_results:
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
"time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
avg_speedups[
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
avg_speedups[
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups[
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
vllm_triton_time / vllm_cutlass_time
)
print("\n===== AVERAGE SPEEDUPS =====")
speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):
for result in all_results:
for impl in implementations:
avg_diff[impl] += result["implementations"][impl]["diff"][
"Reference"]
avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
diff_headers = ["Implementation", "Avg Diff vs Reference"]
diff_rows = []
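The timing loop behind `time_ms` / `time_us` falls outside the hunks shown above. A minimal sketch of how a warmup/repeat measurement over these GEMM callables might look (the helper name and structure are assumptions, not this file's actual code), using CUDA events:

import torch

def time_gpu(fn, warmup: int = 100, repeat: int = 10000) -> float:
    """Return the average runtime of fn() in milliseconds using CUDA events."""
    for _ in range(warmup):  # warm up caches and kernel compilation before measuring
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeat  # elapsed_time() reports milliseconds

# Hypothetical usage, mirroring the implementations dict above:
# avg_time_ms = time_gpu(deepgemm_gemm, warmup=100, repeat=10000)
# tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12  # 2*m*n*k FLOPs per GEMM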

View File

@@ -442,14 +442,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
chunk_f = lambda x, i: x[
cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ...
]
X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on
X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
X, i
)
dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
dt, i
)
B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
B, i
)
C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
C, i
)
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
@@ -481,27 +489,42 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
dim=0,
)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_chunk_f = lambda x, i: x[
cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ...
]
remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
remaining_X_chunked[
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
] = remaining_chunk_f(X, i)
remaining_dt_chunked[
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
] = remaining_chunk_f(dt, i)
remaining_B_chunked[
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
] = remaining_chunk_f(B, i)
remaining_C_chunked[
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
] = remaining_chunk_f(C, i)
# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
concat_chunk_f = lambda pt1, pt2, i: torch.cat(
[
pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...],
pt2[
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1],
...,
],
],
dim=0)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
# fmt: on
dim=0,
)
concat_batch_f = lambda pt1, pt2: torch.cat(
[concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0
)
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
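The test above packs variable-length sequences into one flat tensor and addresses each sequence through cumulative sequence lengths, then verifies that the prefill chunk plus the remainder reconstructs the original input. A small self-contained illustration of that pattern (toy sizes, not the test's actual data):

import torch

# Three sequences of lengths 3, 5, 2 packed along dim 0.
seqlens = torch.tensor([3, 5, 2])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])  # [0, 3, 8, 10]
X = torch.arange(10).unsqueeze(-1).float()  # (total_tokens, hidden)

# Take the first few tokens of each sequence (the "chunked" prefill), then the rest.
chunked_seqlens = torch.tensor([2, 2, 1])
chunks = [X[cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i]] for i in range(3)]
rest = [X[cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1]] for i in range(3)]

# Re-interleaving chunk + remainder per sequence reconstructs the original,
# which is what the equal() assertions above check.
reassembled = torch.cat([torch.cat([chunks[i], rest[i]], dim=0) for i in range(3)], dim=0)
assert reassembled.equal(X)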

View File

@@ -182,8 +182,9 @@ def test_fast_inc_detok_invalid_utf8_err_case():
finished = i == len(test_tokens) - 1
output += detokenizer.get_next_output_text(finished, delta=True)
# fmt: off
assert output == r'''[
assert (
output
== r"""[
{
"source": "Résultats",
"source_type": "CONCEPT",
@@ -191,4 +192,5 @@ def test_fast_inc_detok_invalid_utf8_err_case():
"target": "Israël",
"target_type": "ORGANIZATION",
"target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
"relationship": "Obtention d'un niveau de'''
"relationship": "Obtention d'un niveau de"""
)

View File

@@ -1398,12 +1398,10 @@ class Qwen2_5_VLForConditionalGeneration(
# Tensors
input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)
# fmt: off
mm_embeddings_out = [mm[:, :-4] for mm in
multimodal_embeddings]
mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in
multimodal_embeddings]
# fmt: in
mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
mm_embeddings_pos = [
mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
]
positions, mrope_positions_delta = recompute_mrope_positions(
input_ids_t,
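The two list comprehensions above split each multimodal embedding tensor into its feature part and four trailing position rows. A toy illustration of that slicing (the shapes are assumptions for illustration only):

import torch

# Suppose each multimodal embedding carries 4 extra trailing columns of position info.
mm = torch.randn(7, 16 + 4)                   # (num_tokens, hidden + 4)
features = mm[:, :-4]                         # (7, 16): the embedding fed to the model
positions = mm[:, -4:].permute(1, 0).long()   # (4, 7): position rows as integer indices
assert features.shape == (7, 16) and positions.shape == (4, 7)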

View File

@@ -516,14 +516,18 @@ class VoxtralForConditionalGeneration(
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# fmt: off
remapping_rules = [
(r"mm_whisper_embeddings\.(.*)", r"\1"),
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
(r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501
(r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501
(
r"audio_language_adapter\.0\.weight",
r"audio_language_adapter.w_in.weight",
),
(
r"audio_language_adapter\.2\.weight",
r"audio_language_adapter.w_out.weight",
),
]
# fmt: on
audio_params = dict(
nn.ModuleDict(
@@ -678,19 +682,44 @@ class AudioLanguageAdapter(nn.Module):
class VoxtralEncoderModel(nn.Module):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
# fmt: off
mistral_remapping = [
(r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501
(r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501
(r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501
(
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
r"whisper_encoder.conv1.\1",
),
(
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
r"whisper_encoder.conv2.\1",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc1.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc2.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
r"whisper_encoder.layers.\1.final_layer_norm.\2",
),
(
r"whisper_encoder\.transformer\.norm\.(weight|bias)",
r"whisper_encoder.layer_norm.\1",
),
]
# fmt: on
def __init__(
self,
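The remapping rules above are (pattern, replacement) pairs over checkpoint weight names. A minimal sketch of how such rules are typically applied with re.sub (the helper name is hypothetical, not vLLM's actual loader code):

import re

remapping_rules = [
    (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"),
    (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"),
]

def remap_name(name: str) -> str:
    # Apply each rule in order; re.sub leaves the name unchanged if the pattern does not match.
    for pattern, replacement in remapping_rules:
        name = re.sub(pattern, replacement, name)
    return name

assert remap_name("whisper_encoder.conv_layers.0.weight") == "whisper_encoder.conv1.weight"
assert remap_name("whisper_encoder.transformer.norm.bias") == "whisper_encoder.layer_norm.bias"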