Remove all cases of fmt: on/off (#26253)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 4e256cadc2
commit 557b2e961d
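Note: "# fmt: off" / "# fmt: on" are suppression markers recognized by the Black and Ruff formatters; code between them is left exactly as written. Removing them, as this commit does, lets the formatter reflow those regions. A minimal illustrative sketch (not taken from the diff below):

# With the markers, the formatter must keep this hand-aligned layout.
# fmt: off
IDENTITY_2X2 = [
    1, 0,
    0, 1,
]
# fmt: on
# Once the markers are removed, the formatter is free to rewrite the
# literal into its canonical layout, e.g. IDENTITY_2X2 = [1, 0, 0, 1].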
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# fmt: off
 # ruff: noqa: E501
 import time

@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
 )


-def benchmark_shape(m: int,
-                    n: int,
-                    k: int,
-                    warmup: int = 100,
-                    repeat: int = 10000,
-                    verbose: bool = False) -> dict:
+def benchmark_shape(
+    m: int,
+    n: int,
+    k: int,
+    warmup: int = 100,
+    repeat: int = 10000,
+    verbose: bool = False,
+) -> dict:
     """Benchmark all implementations for a specific (m, n, k) shape."""
     if verbose:
         print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

     # Create test tensors
-    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

     # Reference result in BF16
     torch.cuda.synchronize()
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
     # Pre-quantize A for all implementations
     A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
-    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
     A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
-        A, block_size[1], column_major_scales=True)
+        A, block_size[1], column_major_scales=True
+    )

     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
-                    (B_deepgemm, B_scale_deepgemm),
-                    C_deepgemm)
+        fp8_gemm_nt(
+            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
+        )
         return C_deepgemm

     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_triton_block_scaled_mm(A_vllm,
-                                           B_vllm,
-                                           A_scale_vllm,
-                                           B_scale_vllm,
-                                           block_size,
-                                           output_dtype=torch.bfloat16)
+        return w8a8_triton_block_scaled_mm(
+            A_vllm,
+            B_vllm,
+            A_scale_vllm,
+            B_scale_vllm,
+            block_size,
+            output_dtype=torch.bfloat16,
+        )

     # === vLLM CUTLASS Implementation ===
     def vllm_cutlass_gemm():
-        return ops.cutlass_scaled_mm(A_vllm_cutlass,
-                                     B_vllm.T,
-                                     scale_a=A_scale_vllm_cutlass,
-                                     scale_b=B_scale_vllm.T,
-                                     out_dtype=torch.bfloat16)
+        return ops.cutlass_scaled_mm(
+            A_vllm_cutlass,
+            B_vllm.T,
+            scale_a=A_scale_vllm_cutlass,
+            scale_b=B_scale_vllm.T,
+            out_dtype=torch.bfloat16,
+        )

     # Run correctness check first
     if verbose:
@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
         print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
         print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
         print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
-        print("vLLM Triton vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
-        print("vLLM CUTLASS vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
+        print(
+            "vLLM Triton vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
+        )
+        print(
+            "vLLM CUTLASS vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
+        )

     # Benchmark implementations
     implementations = {
         "DeepGEMM": deepgemm_gemm,
         "vLLM Triton": vllm_triton_gemm,
-        "vLLM CUTLASS": vllm_cutlass_gemm
+        "vLLM CUTLASS": vllm_cutlass_gemm,
     }

-    benchmark_results = {
-        "shape": {
-            "m": m,
-            "n": n,
-            "k": k
-        },
-        "implementations": {}
-    }
+    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}

     for name, func in implementations.items():
         # Warmup
@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
             "tflops": tflops,
             "gb_s": gb_s,
             "diff": {
-                "DeepGEMM":
-                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
-                "Reference":
-                deepgemm_diff if name == "DeepGEMM" else
-                (vllm_triton_diff
-                 if name == "vLLM Triton" else vllm_cutlass_diff)
-            }
+                "DeepGEMM": 0.0
+                if name == "DeepGEMM"
+                else calc_diff(func(), C_deepgemm),
+                "Reference": deepgemm_diff
+                if name == "DeepGEMM"
+                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
+            },
         }

         if verbose:
-            print(
-                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
-            )
+            print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")

     # Calculate speedups
     baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
     for name, data in benchmark_results["implementations"].items():
         if name != "DeepGEMM":
             speedup = baseline / data["time_ms"]
-            benchmark_results["implementations"][name][
-                "speedup_vs_deepgemm"] = speedup
+            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
             if verbose:
-                print(f"DeepGEMM is {1/speedup:.2f}x "
-                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
+                print(
+                    f"DeepGEMM is {1 / speedup:.2f}x "
+                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
+                )

-    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
-        "time_ms"]
-    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
-        "time_ms"]
+    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
+    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
     cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
-    benchmark_results["implementations"]["vLLM CUTLASS"][
-        "speedup_vs_triton"] = cutlass_vs_triton
+    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
+        cutlass_vs_triton
+    )
     if verbose:
         print(
             f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
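The hunks above only reformat benchmark_shape; its behavior is unchanged. As a hypothetical usage sketch (the shape values below are illustrative and not taken from this commit):

# Benchmark a single (m, n, k) shape and inspect the results dict that
# benchmark_shape builds above; "speedup_vs_deepgemm" is absent for the
# DeepGEMM baseline itself, hence the .get().
results = benchmark_shape(m=64, n=7168, k=2048, verbose=True)
for name, data in results["implementations"].items():
    print(name, data["time_ms"], data.get("speedup_vs_deepgemm"))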
@@ -183,8 +184,7 @@ def benchmark_shape(m: int,

 def format_table_row(values, widths):
     """Format a row with specified column widths."""
-    return "| " + " | ".join(f"{val:{w}}"
-                             for val, w in zip(values, widths)) + " |"
+    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"


 def print_table(headers, rows, title=None):
@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["DeepGEMM"]
-        deepgemm_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
-        ])
+        deepgemm_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+            ]
+        )

-    print_table(deepgemm_headers,
-                deepgemm_rows,
-                title="DeepGEMM Implementation:")
+    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")

     # Print vLLM Triton table
-    triton_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
-    ]
+    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
     triton_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM Triton"]
         speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
-        triton_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(speedup)
-        ])
+        triton_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(speedup),
+            ]
+        )

-    print_table(triton_headers,
-                triton_rows,
-                title="vLLM Triton Implementation:")
+    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")

     # Print vLLM CUTLASS table
     cutlass_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
-        "vs Triton"
+        "m",
+        "n",
+        "k",
+        "Time (μs)",
+        "TFLOPS",
+        "GB/s",
+        "vs DeepGEMM",
+        "vs Triton",
     ]
     cutlass_rows = []
     for result in all_results:
@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
         impl_data = result["implementations"]["vLLM CUTLASS"]
         vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
         vs_triton = impl_data.get("speedup_vs_triton", 1.0)
-        cutlass_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(vs_deepgemm),
-            format_speedup(vs_triton)
-        ])
+        cutlass_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(vs_deepgemm),
+                format_speedup(vs_triton),
+            ]
+        )

-    print_table(cutlass_headers,
-                cutlass_rows,
-                title="vLLM CUTLASS Implementation:")
+    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")

     # Calculate and print averages
     print("\n===== AVERAGE PERFORMANCE =====")

     implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
     avg_metrics = {
-        impl: {
-            "tflops": 0,
-            "gb_s": 0,
-            "time_ms": 0
-        }
-        for impl in implementations
+        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
     }

     for result in all_results:
@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
         avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
         avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
         avg_time = avg_metrics[impl]["time_ms"] / num_shapes
-        avg_rows.append([
-            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
-        ])
+        avg_rows.append(
+            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
+        )

     print_table(avg_headers, avg_rows)

@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
     avg_speedups = {
         "DeepGEMM vs vLLM Triton": 0,
         "DeepGEMM vs vLLM CUTLASS": 0,
-        "vLLM CUTLASS vs vLLM Triton": 0
+        "vLLM CUTLASS vs vLLM Triton": 0,
     }

     for result in all_results:
         deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
         vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
-        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
-            "time_ms"]
+        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]

-        avg_speedups[
-            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
-        avg_speedups[
-            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
-        avg_speedups[
-            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
+        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
+        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
+        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
+            vllm_triton_time / vllm_cutlass_time
+        )

     print("\n===== AVERAGE SPEEDUPS =====")
     speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):

     for result in all_results:
         for impl in implementations:
-            avg_diff[impl] += result["implementations"][impl]["diff"][
-                "Reference"]
+            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]

     diff_headers = ["Implementation", "Avg Diff vs Reference"]
     diff_rows = []

@@ -442,14 +442,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
     C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
     for i in range(num_sequences):
-        # fmt: off
-        chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
+        chunk_f = lambda x, i: x[
+            cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ...
+        ]

-        X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
-        dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
-        B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
-        C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
-        # fmt: on
+        X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            X, i
+        )
+        dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            dt, i
+        )
+        B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            B, i
+        )
+        C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            C, i
+        )

     cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
         compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
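For readers skimming the test above: cu_seqlens holds cumulative sequence lengths, so x[cu_seqlens[i] : cu_seqlens[i + 1]] selects the i-th sequence of a packed variable-length batch, which is exactly what chunk_f does. A small self-contained sketch (names and values are illustrative, not from the test):

import torch

seqlens = torch.tensor([3, 5, 2])
# cu_seqlens = [0, 3, 8, 10]: prefix sums with a leading zero.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])
packed = torch.arange(int(cu_seqlens[-1]))  # stand-in for the packed tokens

i = 1
seq_i = packed[cu_seqlens[i] : cu_seqlens[i + 1]]  # the i-th sequence
assert seq_i.tolist() == [3, 4, 5, 6, 7]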
@@ -481,27 +489,42 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         dim=0,
     )
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
-    # fmt: off
-    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]
+    remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]
+    remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]
+    remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]
     for i in range(num_sequences):
-        remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501
+        remaining_chunk_f = lambda x, i: x[
+            cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ...
+        ]

-        remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
-        remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
-        remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
-        remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501
+        remaining_X_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(X, i)
+        remaining_dt_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(dt, i)
+        remaining_B_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(B, i)
+        remaining_C_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(C, i)

     # assert input chunking is correct
-    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
-        pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
-        pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
+    concat_chunk_f = lambda pt1, pt2, i: torch.cat(
+        [
+            pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...],
+            pt2[
+                remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1],
+                ...,
+            ],
+        ],
-    ],
-    dim=0)
-    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0)  # noqa: E501
-    # fmt: on
+        dim=0,
+    )
+    concat_batch_f = lambda pt1, pt2: torch.cat(
+        [concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0
+    )

     assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
     assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)

@@ -182,8 +182,9 @@ def test_fast_inc_detok_invalid_utf8_err_case():
         finished = i == len(test_tokens) - 1
         output += detokenizer.get_next_output_text(finished, delta=True)

-    # fmt: off
-    assert output == r'''[
+    assert (
+        output
+        == r"""[
 {
   "source": "Résultats",
   "source_type": "CONCEPT",
@@ -191,4 +192,5 @@ def test_fast_inc_detok_invalid_utf8_err_case():
   "target": "Israël",
   "target_type": "ORGANIZATION",
   "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
-  "relationship": "Obtention d'un niveau de'''
+  "relationship": "Obtention d'un niveau de"""
+    )

@@ -1398,12 +1398,10 @@ class Qwen2_5_VLForConditionalGeneration(
         # Tensors
         input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)

-        # fmt: off
-        mm_embeddings_out = [mm[:, :-4] for mm in
-                             multimodal_embeddings]
-        mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in
-                             multimodal_embeddings]
-        # fmt: in
+        mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
+        mm_embeddings_pos = [
+            mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
+        ]

         positions, mrope_positions_delta = recompute_mrope_positions(
             input_ids_t,
@@ -516,14 +516,18 @@ class VoxtralForConditionalGeneration(
         )

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        # fmt: off
         remapping_rules = [
             (r"mm_whisper_embeddings\.(.*)", r"\1"),
             (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
-            (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"),  # noqa: E501
-            (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"),  # noqa: E501
+            (
+                r"audio_language_adapter\.0\.weight",
+                r"audio_language_adapter.w_in.weight",
+            ),
+            (
+                r"audio_language_adapter\.2\.weight",
+                r"audio_language_adapter.w_out.weight",
+            ),
         ]
-        # fmt: on

         audio_params = dict(
             nn.ModuleDict(
@@ -678,19 +682,44 @@ class AudioLanguageAdapter(nn.Module):
 class VoxtralEncoderModel(nn.Module):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

-    # fmt: off
     mistral_remapping = [
-        (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"),  # noqa: E501
-        (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"),  # noqa: E501
+        (
+            r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
+            r"whisper_encoder.conv1.\1",
+        ),
+        (
+            r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
+            r"whisper_encoder.conv2.\1",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.mlp.fc1.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.mlp.fc2.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
+            r"whisper_encoder.layers.\1.final_layer_norm.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.norm\.(weight|bias)",
+            r"whisper_encoder.layer_norm.\1",
+        ),
     ]
-    # fmt: on

     def __init__(
         self,

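The remapping tables above are plain (pattern, replacement) pairs. A hedged sketch of how such rules are typically applied when renaming checkpoint keys; the helper below is illustrative and not the function vLLM itself uses:

import re

def remap_name(name: str, rules: list[tuple[str, str]]) -> str:
    # Apply each regex rule in order; re.sub returns the name unchanged
    # when the pattern does not match.
    for pattern, replacement in rules:
        name = re.sub(pattern, replacement, name)
    return name

rules = [
    (r"mm_whisper_embeddings\.(.*)", r"\1"),
    (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
]
assert remap_name("mm_whisper_embeddings.conv1.weight", rules) == "conv1.weight"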