Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:45:49 +08:00)

Remove all cases of fmt: on/off (#26253)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 4e256cadc2
commit 557b2e961d
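Note: "# fmt: off" / "# fmt: on" are suppression comments honored by Black and by ruff format; the formatter leaves everything between them untouched. Removing them, as this commit does, hands those regions back to the formatter, which is what produces every hunk below. A minimal sketch of the effect (hypothetical code, not taken from this diff):

# fmt: off
matrix = [
    1, 0,
    0, 1,
]
# fmt: on

# With the markers deleted, ruff format is free to rewrite the region,
# e.g. collapsing the literal onto one line because it fits the line limit:
matrix = [1, 0, 0, 1]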
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# fmt: off
 # ruff: noqa: E501
 import time

@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
 )


-def benchmark_shape(m: int,
-                    n: int,
-                    k: int,
-                    warmup: int = 100,
-                    repeat: int = 10000,
-                    verbose: bool = False) -> dict:
+def benchmark_shape(
+    m: int,
+    n: int,
+    k: int,
+    warmup: int = 100,
+    repeat: int = 10000,
+    verbose: bool = False,
+) -> dict:
     """Benchmark all implementations for a specific (m, n, k) shape."""
     if verbose:
         print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

     # Create test tensors
-    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

     # Reference result in BF16
     torch.cuda.synchronize()
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
     # Pre-quantize A for all implementations
     A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
-    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
     A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
-        A, block_size[1], column_major_scales=True)
+        A, block_size[1], column_major_scales=True
+    )

     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
-                    (B_deepgemm, B_scale_deepgemm),
-                    C_deepgemm)
+        fp8_gemm_nt(
+            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
+        )
         return C_deepgemm

     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_triton_block_scaled_mm(A_vllm,
-                                           B_vllm,
-                                           A_scale_vllm,
-                                           B_scale_vllm,
-                                           block_size,
-                                           output_dtype=torch.bfloat16)
+        return w8a8_triton_block_scaled_mm(
+            A_vllm,
+            B_vllm,
+            A_scale_vllm,
+            B_scale_vllm,
+            block_size,
+            output_dtype=torch.bfloat16,
+        )

     # === vLLM CUTLASS Implementation ===
     def vllm_cutlass_gemm():
-        return ops.cutlass_scaled_mm(A_vllm_cutlass,
-                                     B_vllm.T,
-                                     scale_a=A_scale_vllm_cutlass,
-                                     scale_b=B_scale_vllm.T,
-                                     out_dtype=torch.bfloat16)
+        return ops.cutlass_scaled_mm(
+            A_vllm_cutlass,
+            B_vllm.T,
+            scale_a=A_scale_vllm_cutlass,
+            scale_b=B_scale_vllm.T,
+            out_dtype=torch.bfloat16,
+        )

     # Run correctness check first
     if verbose:
@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
         print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
         print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
         print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
-        print("vLLM Triton vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
-        print("vLLM CUTLASS vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
+        print(
+            "vLLM Triton vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
+        )
+        print(
+            "vLLM CUTLASS vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
+        )

     # Benchmark implementations
     implementations = {
         "DeepGEMM": deepgemm_gemm,
         "vLLM Triton": vllm_triton_gemm,
-        "vLLM CUTLASS": vllm_cutlass_gemm
+        "vLLM CUTLASS": vllm_cutlass_gemm,
     }

-    benchmark_results = {
-        "shape": {
-            "m": m,
-            "n": n,
-            "k": k
-        },
-        "implementations": {}
-    }
+    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}

     for name, func in implementations.items():
         # Warmup
@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
             "tflops": tflops,
             "gb_s": gb_s,
             "diff": {
-                "DeepGEMM":
-                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
-                "Reference":
-                deepgemm_diff if name == "DeepGEMM" else
-                (vllm_triton_diff
-                 if name == "vLLM Triton" else vllm_cutlass_diff)
-            }
+                "DeepGEMM": 0.0
+                if name == "DeepGEMM"
+                else calc_diff(func(), C_deepgemm),
+                "Reference": deepgemm_diff
+                if name == "DeepGEMM"
+                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
+            },
         }

         if verbose:
-            print(
-                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
-            )
+            print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")

     # Calculate speedups
     baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
     for name, data in benchmark_results["implementations"].items():
         if name != "DeepGEMM":
             speedup = baseline / data["time_ms"]
-            benchmark_results["implementations"][name][
-                "speedup_vs_deepgemm"] = speedup
+            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
             if verbose:
-                print(f"DeepGEMM is {1/speedup:.2f}x "
-                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
+                print(
+                    f"DeepGEMM is {1 / speedup:.2f}x "
+                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
+                )

-    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
-        "time_ms"]
-    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
-        "time_ms"]
+    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
+    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
     cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
-    benchmark_results["implementations"]["vLLM CUTLASS"][
-        "speedup_vs_triton"] = cutlass_vs_triton
+    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
+        cutlass_vs_triton
+    )
     if verbose:
         print(
             f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@@ -183,8 +184,7 @@ def benchmark_shape(m: int,

 def format_table_row(values, widths):
     """Format a row with specified column widths."""
-    return "| " + " | ".join(f"{val:{w}}"
-                             for val, w in zip(values, widths)) + " |"
+    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"


 def print_table(headers, rows, title=None):
@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["DeepGEMM"]
-        deepgemm_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
-        ])
+        deepgemm_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+            ]
+        )

-    print_table(deepgemm_headers,
-                deepgemm_rows,
-                title="DeepGEMM Implementation:")
+    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")

     # Print vLLM Triton table
-    triton_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
-    ]
+    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
     triton_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM Triton"]
         speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
-        triton_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(speedup)
-        ])
+        triton_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(speedup),
+            ]
+        )

-    print_table(triton_headers,
-                triton_rows,
-                title="vLLM Triton Implementation:")
+    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")

     # Print vLLM CUTLASS table
     cutlass_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
-        "vs Triton"
+        "m",
+        "n",
+        "k",
+        "Time (μs)",
+        "TFLOPS",
+        "GB/s",
+        "vs DeepGEMM",
+        "vs Triton",
     ]
     cutlass_rows = []
     for result in all_results:
@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
         impl_data = result["implementations"]["vLLM CUTLASS"]
         vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
         vs_triton = impl_data.get("speedup_vs_triton", 1.0)
-        cutlass_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(vs_deepgemm),
-            format_speedup(vs_triton)
-        ])
+        cutlass_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(vs_deepgemm),
+                format_speedup(vs_triton),
+            ]
+        )

-    print_table(cutlass_headers,
-                cutlass_rows,
-                title="vLLM CUTLASS Implementation:")
+    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")

     # Calculate and print averages
     print("\n===== AVERAGE PERFORMANCE =====")

     implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
     avg_metrics = {
-        impl: {
-            "tflops": 0,
-            "gb_s": 0,
-            "time_ms": 0
-        }
-        for impl in implementations
+        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
     }

     for result in all_results:
@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
         avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
         avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
         avg_time = avg_metrics[impl]["time_ms"] / num_shapes
-        avg_rows.append([
-            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
-        ])
+        avg_rows.append(
+            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
+        )

     print_table(avg_headers, avg_rows)

@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
     avg_speedups = {
         "DeepGEMM vs vLLM Triton": 0,
         "DeepGEMM vs vLLM CUTLASS": 0,
-        "vLLM CUTLASS vs vLLM Triton": 0
+        "vLLM CUTLASS vs vLLM Triton": 0,
     }

     for result in all_results:
         deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
         vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
-        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
-            "time_ms"]
+        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]

-        avg_speedups[
-            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
-        avg_speedups[
-            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
-        avg_speedups[
-            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
+        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
+        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
+        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
+            vllm_triton_time / vllm_cutlass_time
+        )

     print("\n===== AVERAGE SPEEDUPS =====")
     speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):

     for result in all_results:
         for impl in implementations:
-            avg_diff[impl] += result["implementations"][impl]["diff"][
-                "Reference"]
+            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]

     diff_headers = ["Implementation", "Avg Diff vs Reference"]
     diff_rows = []
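The hunks above are behavior-neutral: yapf-style continuation lines aligned under the opening parenthesis become the black-style layout that ruff format emits, where a call is either collapsed onto one line (if it fits the 88-column limit) or exploded to one argument per line with a trailing comma. A hedged before/after sketch with a hypothetical function:

def f(alpha, beta, gamma=0):
    return alpha + beta + gamma

# Left side of the diff (yapf-style alignment):
result = f(1,
           2,
           gamma=3)

# Right side of the diff (ruff format collapses the call since it fits):
result = f(1, 2, gamma=3)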
@@ -442,14 +442,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
     B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
     C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
     for i in range(num_sequences):
-        # fmt: off
-        chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
+        chunk_f = lambda x, i: x[
+            cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ...
+        ]

-        X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
-        dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
-        B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
-        C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
-        # fmt: on
+        X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            X, i
+        )
+        dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            dt, i
+        )
+        B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            B, i
+        )
+        C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
+            C, i
+        )

     cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
         compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
@@ -481,27 +489,42 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
         dim=0,
     )
     remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
-    # fmt: off
-    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
-    remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]
+    remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]
+    remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]
+    remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]
     for i in range(num_sequences):
-        remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501
+        remaining_chunk_f = lambda x, i: x[
+            cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ...
+        ]

-        remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
-        remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
-        remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
-        remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501
+        remaining_X_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(X, i)
+        remaining_dt_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(dt, i)
+        remaining_B_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(B, i)
+        remaining_C_chunked[
+            remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
+        ] = remaining_chunk_f(C, i)

     # assert input chunking is correct
-    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
-        pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
-        pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
-    ],
-    dim=0)
-    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0)  # noqa: E501
-    # fmt: on
+    concat_chunk_f = lambda pt1, pt2, i: torch.cat(
+        [
+            pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...],
+            pt2[
+                remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1],
+                ...,
+            ],
+        ],
+        dim=0,
+    )
+    concat_batch_f = lambda pt1, pt2: torch.cat(
+        [concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0
+    )

     assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
     assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
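The slice rewrites in this file follow the PEP 8 rule that ruff format applies: when a slice bound is a compound expression, the colon is treated like a binary operator and gets a space on each side, while simple literal bounds keep the tight form. A small runnable illustration (hypothetical tensors, not from the test):

import torch

t = torch.arange(10)
i, n = 2, 3
a = t[i : i + n]  # spaced colon: the bounds are expressions
b = t[2:5]        # literal bounds: no spaces
assert a.equal(b)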
@@ -182,8 +182,9 @@ def test_fast_inc_detok_invalid_utf8_err_case():
         finished = i == len(test_tokens) - 1
         output += detokenizer.get_next_output_text(finished, delta=True)

-    # fmt: off
-    assert output == r'''[
+    assert (
+        output
+        == r"""[
 {
   "source": "Résultats",
   "source_type": "CONCEPT",
@@ -191,4 +192,5 @@ def test_fast_inc_detok_invalid_utf8_err_case():
   "target": "Israël",
   "target_type": "ORGANIZATION",
   "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
-  "relationship": "Obtention d'un niveau de'''
+  "relationship": "Obtention d'un niveau de"""
+    )
@@ -1398,12 +1398,10 @@ class Qwen2_5_VLForConditionalGeneration(
         # Tensors
         input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)

-        # fmt: off
-        mm_embeddings_out = [mm[:, :-4] for mm in
-                             multimodal_embeddings]
-        mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in
-                             multimodal_embeddings]
-        # fmt: in
+        mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
+        mm_embeddings_pos = [
+            mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
+        ]

         positions, mrope_positions_delta = recompute_mrope_positions(
             input_ids_t,
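One detail worth flagging in the hunk above: the deleted closing marker was written as "# fmt: in", which neither Black nor Ruff recognizes (the directive is "# fmt: on"), so the suppressed region never actually ended. Removing the pair fixes that latent typo as a side effect. A sketch of the failure mode (hypothetical lines):

# fmt: off
x = [1,    0]
# fmt: in   <- unrecognized, so formatting stays off from here on
y = [0,    1]  # still never touched by the formatter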
@@ -516,14 +516,18 @@ class VoxtralForConditionalGeneration(
         )

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        # fmt: off
         remapping_rules = [
             (r"mm_whisper_embeddings\.(.*)", r"\1"),
             (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
-            (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"),  # noqa: E501
-            (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"),  # noqa: E501
+            (
+                r"audio_language_adapter\.0\.weight",
+                r"audio_language_adapter.w_in.weight",
+            ),
+            (
+                r"audio_language_adapter\.2\.weight",
+                r"audio_language_adapter.w_out.weight",
+            ),
         ]
-        # fmt: on

         audio_params = dict(
             nn.ModuleDict(
@@ -678,19 +682,44 @@ class AudioLanguageAdapter(nn.Module):
 class VoxtralEncoderModel(nn.Module):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

-    # fmt: off
     mistral_remapping = [
-        (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"),  # noqa: E501
-        (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"),  # noqa: E501
-        (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"),  # noqa: E501
+        (
+            r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
+            r"whisper_encoder.conv1.\1",
+        ),
+        (
+            r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
+            r"whisper_encoder.conv2.\1",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.mlp.fc1.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)",  # noqa: E501
+            r"whisper_encoder.layers.\1.mlp.fc2.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
+            r"whisper_encoder.layers.\1.final_layer_norm.\2",
+        ),
+        (
+            r"whisper_encoder\.transformer\.norm\.(weight|bias)",
+            r"whisper_encoder.layer_norm.\1",
+        ),
     ]
-    # fmt: on

     def __init__(
         self,
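Many rewrites in this commit hinge on ruff format's magic trailing comma, inherited from Black: a collection that fits on one line is collapsed unless its source already ends with a trailing comma, in which case it stays exploded with one element per line. A hypothetical sketch:

short = ["m", "n", "k"]  # fits and has no trailing comma: collapsed

rules = [
    (r"conv_layers\.0\.(weight|bias)", r"conv1.\1"),
    (r"conv_layers\.1\.(weight|bias)", r"conv2.\1"),
]  # trailing comma after the last tuple: kept one element per line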