From 557b2e961d8e6fe47c67be61d089275102573590 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sun, 5 Oct 2025 17:18:14 +0100 Subject: [PATCH] Remove all cases of `fmt: on/off` (#26253) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .../benchmark_fp8_block_dense_gemm.py | 226 +++++++++--------- tests/kernels/mamba/test_mamba_ssm_ssd.py | 69 ++++-- .../v1/engine/test_fast_incdec_prefix_err.py | 8 +- vllm/model_executor/models/qwen2_5_vl.py | 10 +- vllm/model_executor/models/voxtral.py | 59 +++-- 5 files changed, 216 insertions(+), 156 deletions(-) diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 2010b8038563..ba31bc563829 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# fmt: off # ruff: noqa: E501 import time @@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import ( ) -def benchmark_shape(m: int, - n: int, - k: int, - warmup: int = 100, - repeat: int = 10000, - verbose: bool = False) -> dict: +def benchmark_shape( + m: int, + n: int, + k: int, + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False, +) -> dict: """Benchmark all implementations for a specific (m, n, k) shape.""" if verbose: print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") # Create test tensors - A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) # Reference result in BF16 torch.cuda.synchronize() @@ -49,34 +50,39 @@ def benchmark_shape(m: int, # Pre-quantize A for all implementations A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1]) A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) + A, block_size[1], column_major_scales=True + ) # === DeepGEMM Implementation === def deepgemm_gemm(): - fp8_gemm_nt((A_deepgemm, A_scale_deepgemm), - (B_deepgemm, B_scale_deepgemm), - C_deepgemm) + fp8_gemm_nt( + (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm + ) return C_deepgemm # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_triton_block_scaled_mm(A_vllm, - B_vllm, - A_scale_vllm, - B_scale_vllm, - block_size, - output_dtype=torch.bfloat16) + return w8a8_triton_block_scaled_mm( + A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16, + ) # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): - return ops.cutlass_scaled_mm(A_vllm_cutlass, - B_vllm.T, - scale_a=A_scale_vllm_cutlass, - scale_b=B_scale_vllm.T, - out_dtype=torch.bfloat16) + return ops.cutlass_scaled_mm( + A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16, + ) # Run correctness check first if verbose: @@ -93,26 +99,23 @@ def benchmark_shape(m: int, 
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") - print("vLLM Triton vs DeepGEMM difference: " - f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") - print("vLLM CUTLASS vs DeepGEMM difference: " - f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + print( + "vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}" + ) + print( + "vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}" + ) # Benchmark implementations implementations = { "DeepGEMM": deepgemm_gemm, "vLLM Triton": vllm_triton_gemm, - "vLLM CUTLASS": vllm_cutlass_gemm + "vLLM CUTLASS": vllm_cutlass_gemm, } - benchmark_results = { - "shape": { - "m": m, - "n": n, - "k": k - }, - "implementations": {} - } + benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}} for name, func in implementations.items(): # Warmup @@ -140,38 +143,36 @@ def benchmark_shape(m: int, "tflops": tflops, "gb_s": gb_s, "diff": { - "DeepGEMM": - 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), - "Reference": - deepgemm_diff if name == "DeepGEMM" else - (vllm_triton_diff - if name == "vLLM Triton" else vllm_cutlass_diff) - } + "DeepGEMM": 0.0 + if name == "DeepGEMM" + else calc_diff(func(), C_deepgemm), + "Reference": deepgemm_diff + if name == "DeepGEMM" + else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff), + }, } if verbose: - print( - f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" - ) + print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s") # Calculate speedups baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] for name, data in benchmark_results["implementations"].items(): if name != "DeepGEMM": speedup = baseline / data["time_ms"] - benchmark_results["implementations"][name][ - "speedup_vs_deepgemm"] = speedup + benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup if verbose: - print(f"DeepGEMM is {1/speedup:.2f}x " - f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + print( + f"DeepGEMM is {1 / speedup:.2f}x " + f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}" + ) - vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ - "time_ms"] - vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"] cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time - benchmark_results["implementations"]["vLLM CUTLASS"][ - "speedup_vs_triton"] = cutlass_vs_triton + benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = ( + cutlass_vs_triton + ) if verbose: print( f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " @@ -183,8 +184,7 @@ def benchmark_shape(m: int, def format_table_row(values, widths): """Format a row with specified column widths.""" - return "| " + " | ".join(f"{val:{w}}" - for val, w in zip(values, widths)) + " |" + return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |" def print_table(headers, rows, title=None): @@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False): for result in all_results: shape = result["shape"] impl_data = result["implementations"]["DeepGEMM"] - deepgemm_rows.append([ - shape["m"], shape["n"], 
shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" - ]) + deepgemm_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + ] + ) - print_table(deepgemm_headers, - deepgemm_rows, - title="DeepGEMM Implementation:") + print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:") # Print vLLM Triton table - triton_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" - ] + triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"] triton_rows = [] for result in all_results: shape = result["shape"] impl_data = result["implementations"]["vLLM Triton"] speedup = impl_data.get("speedup_vs_deepgemm", 1.0) - triton_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(speedup) - ]) + triton_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(speedup), + ] + ) - print_table(triton_headers, - triton_rows, - title="vLLM Triton Implementation:") + print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:") # Print vLLM CUTLASS table cutlass_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", - "vs Triton" + "m", + "n", + "k", + "Time (μs)", + "TFLOPS", + "GB/s", + "vs DeepGEMM", + "vs Triton", ] cutlass_rows = [] for result in all_results: @@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False): impl_data = result["implementations"]["vLLM CUTLASS"] vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) vs_triton = impl_data.get("speedup_vs_triton", 1.0) - cutlass_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(vs_deepgemm), - format_speedup(vs_triton) - ]) + cutlass_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton), + ] + ) - print_table(cutlass_headers, - cutlass_rows, - title="vLLM CUTLASS Implementation:") + print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:") # Calculate and print averages print("\n===== AVERAGE PERFORMANCE =====") implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] avg_metrics = { - impl: { - "tflops": 0, - "gb_s": 0, - "time_ms": 0 - } - for impl in implementations + impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations } for result in all_results: @@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False): avg_tflops = avg_metrics[impl]["tflops"] / num_shapes avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes avg_time = avg_metrics[impl]["time_ms"] / num_shapes - avg_rows.append([ - impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" - ]) + avg_rows.append( + [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"] + ) print_table(avg_headers, avg_rows) @@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False): avg_speedups = { "DeepGEMM vs vLLM Triton": 0, "DeepGEMM vs vLLM CUTLASS": 0, - "vLLM CUTLASS vs vLLM Triton": 0 + "vLLM CUTLASS vs vLLM Triton": 0, } for result in all_results: deepgemm_time = 
result["implementations"]["DeepGEMM"]["time_ms"] vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] - vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"] - avg_speedups[ - "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time - avg_speedups[ - "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time - avg_speedups[ - "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups["vLLM CUTLASS vs vLLM Triton"] += ( + vllm_triton_time / vllm_cutlass_time + ) print("\n===== AVERAGE SPEEDUPS =====") speedup_headers = ["Comparison", "Speedup"] @@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False): for result in all_results: for impl in implementations: - avg_diff[impl] += result["implementations"][impl]["diff"][ - "Reference"] + avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"] diff_headers = ["Implementation", "Avg Diff vs Reference"] diff_rows = [] diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index b4424b717d02..b068ea1ac49c 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -442,14 +442,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...] C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...] for i in range(num_sequences): - # fmt: off - chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + chunk_f = lambda x, i: x[ + cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ... + ] - X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 - dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 - B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 - C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501 - # fmt: on + X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + X, i + ) + dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + dt, i + ) + B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + B, i + ) + C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + C, i + ) cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size) @@ -481,27 +489,42 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): dim=0, ) remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] - # fmt: off - remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] + remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] 
+ remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] + remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] for i in range(num_sequences): - remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + remaining_chunk_f = lambda x, i: x[ + cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ... + ] - remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 - remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 - remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 - remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501 + remaining_X_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(X, i) + remaining_dt_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(dt, i) + remaining_B_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(B, i) + remaining_C_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(C, i) # assert input chunking is correct - concat_chunk_f = lambda pt1, pt2, i: torch.cat([ - pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], - pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + concat_chunk_f = lambda pt1, pt2, i: torch.cat( + [ + pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...], + pt2[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], + ..., + ], ], - dim=0) - concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501 - # fmt: on + dim=0, + ) + concat_batch_f = lambda pt1, pt2: torch.cat( + [concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0 + ) assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index fc1fc259f259..77e67d54e587 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -182,8 +182,9 @@ def test_fast_inc_detok_invalid_utf8_err_case(): finished = i == len(test_tokens) - 1 output += detokenizer.get_next_output_text(finished, delta=True) - # fmt: off - assert output == r'''[ + assert ( + output + == r"""[ { "source": "Résultats", "source_type": "CONCEPT", @@ -191,4 +192,5 @@ def test_fast_inc_detok_invalid_utf8_err_case(): "target": "Israël", "target_type": "ORGANIZATION", "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »", - "relationship": "Obtention d'un niveau de''' + "relationship": "Obtention d'un niveau de""" + ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ef70e37da366..f7d2f6c584ca 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -1398,12 +1398,10 @@ class Qwen2_5_VLForConditionalGeneration( # Tensors input_ids_t = torch.as_tensor(input_ids, 
device=device, dtype=torch.long) - # fmt: off - mm_embeddings_out = [mm[:, :-4] for mm in - multimodal_embeddings] - mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in - multimodal_embeddings] - # fmt: in + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] positions, mrope_positions_delta = recompute_mrope_positions( input_ids_t, diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 8525bdd5bfad..0d77b72675e2 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -516,14 +516,18 @@ class VoxtralForConditionalGeneration( ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # fmt: off remapping_rules = [ (r"mm_whisper_embeddings\.(.*)", r"\1"), (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"), - (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501 - (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501 + ( + r"audio_language_adapter\.0\.weight", + r"audio_language_adapter.w_in.weight", + ), + ( + r"audio_language_adapter\.2\.weight", + r"audio_language_adapter.w_out.weight", + ), ] - # fmt: on audio_params = dict( nn.ModuleDict( @@ -678,19 +682,44 @@ class AudioLanguageAdapter(nn.Module): class VoxtralEncoderModel(nn.Module): packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - # fmt: off mistral_remapping = [ - (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501 - (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501 + ( + r"whisper_encoder\.conv_layers\.0\.(weight|bias)", + r"whisper_encoder.conv1.\1", + ), + ( + r"whisper_encoder\.conv_layers\.1\.(weight|bias)", + r"whisper_encoder.conv2.\1", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn.\2_proj.\3", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn.out_proj.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn_layer_norm.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.mlp.fc1.\2", + 
), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.mlp.fc2.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", + r"whisper_encoder.layers.\1.final_layer_norm.\2", + ), + ( + r"whisper_encoder\.transformer\.norm\.(weight|bias)", + r"whisper_encoder.layer_norm.\1", + ), ] - # fmt: on def __init__( self,
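
Every hunk in this patch follows the same pattern: a region that was hand-wrapped under `# fmt: off` markers is rewritten in the formatter's own layout so the markers (and the related `# noqa: E501` suppressions) can be dropped. As a quick orientation aid, here is a minimal sketch of that transformation, modeled on the qwen2_5_vl.py hunk above. It is illustrative only and not part of the patch; the toy `multimodal_embeddings` input and the comment about its trailing four columns are assumptions made for the example.

# Illustrative sketch (not part of the patch): the kind of rewrite this change applies.
import torch

# Stand-in for `multimodal_embeddings`: each entry is (num_tokens, hidden + 4);
# the split below treats the trailing 4 columns as per-token position data
# (an assumption for this example, inferred from the slicing in the hunk).
multimodal_embeddings = [torch.randn(3, 8), torch.randn(5, 8)]

# Before the patch, these two statements were wrapped by hand and fenced with
# `# fmt: off` / `# fmt: on` so the formatter would not touch them. After the
# patch they are written the way the formatter lays them out, with no markers:
mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
mm_embeddings_pos = [
    mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
]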